From a3244ea748ddeffa24452f71484ee8c9c5f72d4e Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Mon, 25 Nov 2024 22:01:28 +0100 Subject: [PATCH 1/6] on condition requirement for join --- src/parser/mod.rs | 23 +++++++++++++++++-- tests/sqlparser_common.rs | 48 +++++++++++++++++++++++++++++++++++++-- tests/sqlparser_hive.rs | 2 +- 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 1bf173169..f7437d867 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10102,10 +10102,30 @@ impl<'a> Parser<'a> { }; let relation = self.parse_table_factor()?; let join_constraint = self.parse_join_constraint(natural)?; + let join_operator = join_operator_type(join_constraint); + + let requires_constraint = match join_operator { + JoinOperator::Inner(JoinConstraint::None) + | JoinOperator::LeftOuter(JoinConstraint::None) + | JoinOperator::RightOuter(JoinConstraint::None) + | JoinOperator::FullOuter(JoinConstraint::None) + | JoinOperator::LeftSemi(JoinConstraint::None) + | JoinOperator::RightSemi(JoinConstraint::None) + | JoinOperator::LeftAnti(JoinConstraint::None) + | JoinOperator::RightAnti(JoinConstraint::None) + | JoinOperator::Semi(JoinConstraint::None) + | JoinOperator::Anti(JoinConstraint::None) => !natural, + _ => false, + }; + + if requires_constraint { + self.expected("ON, or USING after JOIN", self.peek_token())? + } + Join { relation, global, - join_operator: join_operator_type(join_constraint), + join_operator, } }; joins.push(join); @@ -10914,7 +10934,6 @@ impl<'a> Parser<'a> { Ok(JoinConstraint::Using(columns)) } else { Ok(JoinConstraint::None) - //self.expected("ON, or USING after JOIN", self.peek_token()) } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index b41063859..9e0f9e805 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -7444,7 +7444,7 @@ fn lateral_derived() { #[test] fn lateral_function() { - let sql = "SELECT * FROM customer LEFT JOIN LATERAL generate_series(1, customer.id)"; + let sql = "SELECT * FROM customer CROSS JOIN LATERAL generate_series(1, customer.id)"; let actual_select_only = verified_only_select(sql); let expected = Select { distinct: None, @@ -7485,7 +7485,7 @@ fn lateral_function() { alias: None, }, global: false, - join_operator: JoinOperator::LeftOuter(JoinConstraint::None), + join_operator: JoinOperator::CrossJoin, }], }], lateral_views: vec![], @@ -12198,3 +12198,47 @@ fn parse_create_table_select() { ); } } + +#[test] +fn parse_no_condition_join_strategy() { + let dialects = all_dialects_where(|d| d.supports_create_table_select()); + + let join_types = vec![ + "JOIN", + "INNER JOIN", + "LEFT JOIN", + "LEFT OUTER JOIN", + "RIGHT JOIN", + "RIGHT OUTER JOIN", + "FULL JOIN", + "FULL OUTER JOIN", + "CROSS JOIN", + "NATURAL JOIN", + "LEFT SEMI JOIN", + "RIGHT SEMI JOIN", + "LEFT ANTI JOIN", + "RIGHT ANTI JOIN", + "SEMI JOIN", + "ANTI JOIN", + ]; + + for join in join_types { + let sql = format!( + "SELECT * FROM (SELECT 1 AS id, 'Foo' AS name) AS l {} (SELECT 1 AS id, 'Bar' AS name) AS r", + join + ); + let result = dialects.parse_sql_statements(&sql); + if join.starts_with("CROSS") || join.starts_with("NATURAL") { + // CROSS JOIN and NATURAL JOIN don't require ON or USING clauses + assert!(result.is_ok()); + } else { + // Other joins require ON or USING clauses + assert_eq!( + result.unwrap_err(), + ParserError::ParserError( + "Expected: ON, or USING after JOIN, found: EOF".to_string() + ) + ); + } + } +} diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 8d4f7a680..02754513b 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -285,7 +285,7 @@ fn test_distribute_by() { #[test] fn no_join_condition() { - let join = "SELECT a, b FROM db.table_name JOIN a"; + let join = "SELECT a, b FROM db.table_name CROSS JOIN a"; hive().verified_stmt(join); } From c10f0f7a811cfd92442e1387a23bebc3beb46e30 Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Mon, 25 Nov 2024 22:10:00 +0100 Subject: [PATCH 2/6] minor changes --- tests/sqlparser_common.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 07bd26365..d31da38c7 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -12409,8 +12409,10 @@ fn parse_no_condition_join_strategy() { ) ); } + } } +#[test] fn test_reserved_keywords_for_identifiers() { let dialects = all_dialects_where(|d| d.is_reserved_for_identifier(Keyword::INTERVAL)); // Dialects that reserve the word INTERVAL will not allow it as an unquoted identifier From f19fd13aa8fe63438e87bbbb12dec26bf18b8dde Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Thu, 28 Nov 2024 16:24:30 +0100 Subject: [PATCH 3/6] introduce verify_join_constraint --- src/ast/query.rs | 21 +++++++++++++++++++++ src/dialect/mod.rs | 32 +++++++++++++++++++++++++++++++- src/dialect/mysql.rs | 37 ++++++++++++++++++++++++++++++++++++- src/parser/mod.rs | 16 +--------------- tests/sqlparser_common.rs | 31 +++++++++++++++++++++++++------ 5 files changed, 114 insertions(+), 23 deletions(-) diff --git a/src/ast/query.rs b/src/ast/query.rs index bf36c626f..bf1be8ffe 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -1827,6 +1827,27 @@ pub enum JoinOperator { }, } +impl JoinOperator { + pub fn constraint(&self) -> JoinConstraint { + match self { + JoinOperator::Inner(constraint) + | JoinOperator::LeftOuter(constraint) + | JoinOperator::RightOuter(constraint) + | JoinOperator::FullOuter(constraint) + | JoinOperator::Semi(constraint) + | JoinOperator::LeftSemi(constraint) + | JoinOperator::RightSemi(constraint) + | JoinOperator::Anti(constraint) + | JoinOperator::LeftAnti(constraint) + | JoinOperator::RightAnti(constraint) => constraint.clone(), + JoinOperator::AsOf { constraint, .. } => constraint.clone(), + JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply => { + JoinConstraint::None + } + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index b622c1da3..c1f208476 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -49,7 +49,7 @@ pub use self::postgresql::PostgreSqlDialect; pub use self::redshift::RedshiftSqlDialect; pub use self::snowflake::SnowflakeDialect; pub use self::sqlite::SQLiteDialect; -use crate::ast::{ColumnOption, Expr, Statement}; +use crate::ast::{ColumnOption, Expr, JoinConstraint, JoinOperator, Statement}; pub use crate::keywords; use crate::keywords::Keyword; use crate::parser::{Parser, ParserError}; @@ -687,6 +687,36 @@ pub trait Dialect: Debug + Any { fn is_reserved_for_identifier(&self, kw: Keyword) -> bool { keywords::RESERVED_FOR_IDENTIFIER.contains(&kw) } + + /// Verifies if the given `JoinOperator`'s constraint is valid for this SQL dialect. + /// Returns `true` if the join constraint is valid, otherwise `false`. + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { + let constraint = join_operator.constraint(); + + match constraint { + JoinConstraint::Natural => true, + JoinConstraint::On(_) | JoinConstraint::Using(_) => match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::AsOf { .. } => true, + _ => false, + }, + JoinConstraint::None => match join_operator { + JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply => { + true + } + _ => false, + }, + } + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 197ce48d4..29149ba5b 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -19,7 +19,9 @@ use alloc::boxed::Box; use crate::{ - ast::{BinaryOperator, Expr, LockTable, LockTableType, Statement}, + ast::{ + BinaryOperator, Expr, JoinConstraint, JoinOperator, LockTable, LockTableType, Statement, + }, dialect::Dialect, keywords::Keyword, parser::{Parser, ParserError}, @@ -102,6 +104,39 @@ impl Dialect for MySqlDialect { fn supports_create_table_select(&self) -> bool { true } + + /// Verifies if the given `JoinOperator`'s constraint is valid for this SQL dialect. + /// Returns `true` if the join constraint is valid, otherwise `false`. + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { + let constraint = join_operator.constraint(); + + match constraint { + JoinConstraint::Natural => true, + JoinConstraint::On(_) | JoinConstraint::Using(_) => match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::AsOf { .. } => true, + _ => false, + }, + JoinConstraint::None => match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::CrossApply + | JoinOperator::OuterApply => true, + _ => false, + }, + } + } } /// `LOCK TABLES` diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 9b58a107c..ca3c6b25a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10157,21 +10157,7 @@ impl<'a> Parser<'a> { let join_constraint = self.parse_join_constraint(natural)?; let join_operator = join_operator_type(join_constraint); - let requires_constraint = match join_operator { - JoinOperator::Inner(JoinConstraint::None) - | JoinOperator::LeftOuter(JoinConstraint::None) - | JoinOperator::RightOuter(JoinConstraint::None) - | JoinOperator::FullOuter(JoinConstraint::None) - | JoinOperator::LeftSemi(JoinConstraint::None) - | JoinOperator::RightSemi(JoinConstraint::None) - | JoinOperator::LeftAnti(JoinConstraint::None) - | JoinOperator::RightAnti(JoinConstraint::None) - | JoinOperator::Semi(JoinConstraint::None) - | JoinOperator::Anti(JoinConstraint::None) => !natural, - _ => false, - }; - - if requires_constraint { + if !self.dialect.verify_join_constraint(&join_operator) { self.expected("ON, or USING after JOIN", self.peek_token())? } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index d31da38c7..5e8571771 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -12370,7 +12370,8 @@ fn parse_create_table_select() { #[test] fn parse_no_condition_join_strategy() { - let dialects = all_dialects_where(|d| d.supports_create_table_select()); + let mysql_dialect = TestedDialects::new(vec![Box::new(MySqlDialect {})]); + let generic_dialect = TestedDialects::new(vec![Box::new(GenericDialect {})]); let join_types = vec![ "JOIN", @@ -12396,14 +12397,32 @@ fn parse_no_condition_join_strategy() { "SELECT * FROM (SELECT 1 AS id, 'Foo' AS name) AS l {} (SELECT 1 AS id, 'Bar' AS name) AS r", join ); - let result = dialects.parse_sql_statements(&sql); + let result_generic = generic_dialect.parse_sql_statements(&sql); if join.starts_with("CROSS") || join.starts_with("NATURAL") { - // CROSS JOIN and NATURAL JOIN don't require ON or USING clauses - assert!(result.is_ok()); + assert!(result_generic.is_ok()); + } else { + assert_eq!( + result_generic.unwrap_err(), + ParserError::ParserError( + "Expected: ON, or USING after JOIN, found: EOF".to_string() + ) + ); + } + + let result_mysql = mysql_dialect.parse_sql_statements(&sql); + if join.starts_with("CROSS") + || join.starts_with("NATURAL") + || join.starts_with("INNER") + || join.starts_with("JOIN") + || join.starts_with("LEFT JOIN") + || join.starts_with("LEFT OUTER") + || join.starts_with("RIGHT JOIN") + || join.starts_with("RIGHT OUTER") + { + assert!(result_mysql.is_ok()); } else { - // Other joins require ON or USING clauses assert_eq!( - result.unwrap_err(), + result_mysql.unwrap_err(), ParserError::ParserError( "Expected: ON, or USING after JOIN, found: EOF".to_string() ) From 62f8b0de3332dfae9606f032715e268a5280544f Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Thu, 28 Nov 2024 16:32:37 +0100 Subject: [PATCH 4/6] minor changes --- src/dialect/mod.rs | 40 +++++++++++++++++--------------------- src/dialect/mysql.rs | 46 +++++++++++++++++++++----------------------- 2 files changed, 40 insertions(+), 46 deletions(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index c1f208476..1fd1c1502 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -691,30 +691,26 @@ pub trait Dialect: Debug + Any { /// Verifies if the given `JoinOperator`'s constraint is valid for this SQL dialect. /// Returns `true` if the join constraint is valid, otherwise `false`. fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { - let constraint = join_operator.constraint(); - - match constraint { + match join_operator.constraint() { JoinConstraint::Natural => true, - JoinConstraint::On(_) | JoinConstraint::Using(_) => match join_operator { + JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) - | JoinOperator::RightSemi(_) - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::RightAnti(_) - | JoinOperator::AsOf { .. } => true, - _ => false, - }, - JoinConstraint::None => match join_operator { - JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply => { - true - } - _ => false, - }, + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::AsOf { .. } + ), + JoinConstraint::None => matches!( + join_operator, + JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply + ), } } } diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 29149ba5b..f32985d26 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -108,33 +108,31 @@ impl Dialect for MySqlDialect { /// Verifies if the given `JoinOperator`'s constraint is valid for this SQL dialect. /// Returns `true` if the join constraint is valid, otherwise `false`. fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { - let constraint = join_operator.constraint(); - - match constraint { + match join_operator.constraint() { JoinConstraint::Natural => true, - JoinConstraint::On(_) | JoinConstraint::Using(_) => match join_operator { + JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) - | JoinOperator::RightSemi(_) - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::RightAnti(_) - | JoinOperator::AsOf { .. } => true, - _ => false, - }, - JoinConstraint::None => match join_operator { + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::AsOf { .. } + ), + JoinConstraint::None => matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::CrossJoin - | JoinOperator::CrossApply - | JoinOperator::OuterApply => true, - _ => false, - }, + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::CrossApply + | JoinOperator::OuterApply + ), } } } From cc1c2a4b9c58035ef1b0819f8c72ddd0a81afa1e Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Fri, 29 Nov 2024 15:25:57 +0100 Subject: [PATCH 5/6] add verify_join_operator to dialect --- src/dialect/ansi.rs | 13 +- src/dialect/bigquery.rs | 31 +++- src/dialect/clickhouse.rs | 21 ++- src/dialect/databricks.rs | 18 ++- src/dialect/duckdb.rs | 21 ++- src/dialect/hive.rs | 38 ++++- src/dialect/mod.rs | 11 +- src/dialect/mssql.rs | 36 ++++- src/dialect/mysql.rs | 26 ++-- src/dialect/postgresql.rs | 14 +- src/dialect/redshift.rs | 27 ++++ src/dialect/sqlite.rs | 14 +- src/parser/mod.rs | 16 ++- tests/sqlparser_common.rs | 292 +++++++++++++++++++++++++++----------- 14 files changed, 470 insertions(+), 108 deletions(-) diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs index 32ba7b32a..615bac29d 100644 --- a/src/dialect/ansi.rs +++ b/src/dialect/ansi.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ast::JoinOperator, dialect::Dialect}; /// A [`Dialect`] for [ANSI SQL](https://en.wikipedia.org/wiki/SQL:2011). #[derive(Debug)] @@ -33,4 +33,15 @@ impl Dialect for AnsiDialect { fn require_interval_qualifier(&self) -> bool { true } + + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin => true, + _ => false, + } + } } diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index 96633552b..6cd8546da 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ + ast::{JoinConstraint, JoinOperator}, + dialect::Dialect, +}; /// A [`Dialect`] for [Google Bigquery](https://cloud.google.com/bigquery/) #[derive(Debug, Default)] @@ -72,4 +75,30 @@ impl Dialect for BigQueryDialect { fn require_interval_qualifier(&self) -> bool { true } + + // https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#join_types + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin => true, + _ => false, + } + } + + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { + match join_operator.constraint() { + JoinConstraint::Natural => false, + JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( + join_operator, + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + ), + JoinConstraint::None => matches!(join_operator, JoinOperator::CrossJoin), + } + } } diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index 0c8f08040..64ea08efb 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ast::JoinOperator, dialect::Dialect}; // A [`Dialect`] for [ClickHouse](https://clickhouse.com/). #[derive(Debug)] @@ -50,4 +50,23 @@ impl Dialect for ClickHouseDialect { fn supports_limit_comma(&self) -> bool { true } + + // https://clickhouse.com/docs/en/sql-reference/statements/select/join + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::AsOf { .. } => true, + _ => false, + } + } } diff --git a/src/dialect/databricks.rs b/src/dialect/databricks.rs index 4924e8077..aabce9a5f 100644 --- a/src/dialect/databricks.rs +++ b/src/dialect/databricks.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ast::JoinOperator, dialect::Dialect}; /// A [`Dialect`] for [Databricks SQL](https://www.databricks.com/) /// @@ -59,4 +59,20 @@ impl Dialect for DatabricksDialect { fn require_interval_qualifier(&self) -> bool { true } + + // https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-join.html + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) => true, + _ => false, + } + } } diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs index a2699d850..30630a263 100644 --- a/src/dialect/duckdb.rs +++ b/src/dialect/duckdb.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ast::JoinOperator, dialect::Dialect}; /// A [`Dialect`] for [DuckDB](https://duckdb.org/) #[derive(Debug, Default)] @@ -75,4 +75,23 @@ impl Dialect for DuckDbDialect { fn supports_load_extension(&self) -> bool { true } + + // https://duckdb.org/docs/sql/query_syntax/from.html#joins + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::AsOf { .. } => true, + _ => false, + } + } } diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 571f9b9ba..ff5384322 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ + ast::{JoinConstraint, JoinOperator}, + dialect::Dialect, +}; /// A [`Dialect`] for [Hive](https://hive.apache.org/). #[derive(Debug)] @@ -61,4 +64,37 @@ impl Dialect for HiveDialect { fn supports_load_data(&self) -> bool { true } + + // https://cwiki.apache.org/confluence/display/hive/languagemanual+joins + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) => true, + _ => false, + } + } + + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { + match join_operator.constraint() { + JoinConstraint::Natural => false, + JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( + join_operator, + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + ), + JoinConstraint::None => matches!( + join_operator, + JoinOperator::Inner(_) | JoinOperator::CrossJoin + ), + } + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 1fd1c1502..4e160afee 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -688,6 +688,12 @@ pub trait Dialect: Debug + Any { keywords::RESERVED_FOR_IDENTIFIER.contains(&kw) } + /// Verifies whether the provided `JoinOperator` is supported by this SQL dialect. + /// Returns `true` if the `JoinOperator` is supported, otherwise `false`. + fn verify_join_operator(&self, _join_operator: &JoinOperator) -> bool { + true + } + /// Verifies if the given `JoinOperator`'s constraint is valid for this SQL dialect. /// Returns `true` if the join constraint is valid, otherwise `false`. fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { @@ -709,7 +715,10 @@ pub trait Dialect: Debug + Any { ), JoinConstraint::None => matches!( join_operator, - JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply + JoinOperator::CrossJoin + | JoinOperator::CrossApply + | JoinOperator::OuterApply + | JoinOperator::AsOf { .. } ), } } diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index 2d0ef027f..f4bd693cb 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::dialect::Dialect; +use crate::{ + ast::{JoinConstraint, JoinOperator}, + dialect::Dialect, +}; /// A [`Dialect`] for [Microsoft SQL Server](https://www.microsoft.com/en-us/sql-server/) #[derive(Debug)] @@ -78,4 +81,35 @@ impl Dialect for MsSqlDialect { fn supports_named_fn_args_with_rarrow_operator(&self) -> bool { false } + + // https://learn.microsoft.com/en-us/sql/relational-databases/performance/joins?view=sql-server-ver16 + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::CrossApply + | JoinOperator::OuterApply => true, + _ => false, + } + } + + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { + match join_operator.constraint() { + JoinConstraint::Natural => false, + JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( + join_operator, + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + ), + JoinConstraint::None => matches!( + join_operator, + JoinOperator::CrossJoin | JoinOperator::CrossApply | JoinOperator::OuterApply + ), + } + } } diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index f32985d26..c81ab4e6f 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -105,24 +105,22 @@ impl Dialect for MySqlDialect { true } - /// Verifies if the given `JoinOperator`'s constraint is valid for this SQL dialect. - /// Returns `true` if the join constraint is valid, otherwise `false`. + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::CrossJoin => true, + _ => false, + } + } + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { match join_operator.constraint() { JoinConstraint::Natural => true, JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( join_operator, - JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) - | JoinOperator::RightSemi(_) - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::RightAnti(_) - | JoinOperator::AsOf { .. } + JoinOperator::Inner(_) | JoinOperator::LeftOuter(_) | JoinOperator::RightOuter(_) ), JoinConstraint::None => matches!( join_operator, @@ -130,8 +128,6 @@ impl Dialect for MySqlDialect { | JoinOperator::LeftOuter(_) | JoinOperator::RightOuter(_) | JoinOperator::CrossJoin - | JoinOperator::CrossApply - | JoinOperator::OuterApply ), } } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index dcdcc88c1..2544c17f9 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -28,7 +28,7 @@ // limitations under the License. use log::debug; -use crate::ast::{ObjectName, Statement, UserDefinedTypeRepresentation}; +use crate::ast::{JoinOperator, ObjectName, Statement, UserDefinedTypeRepresentation}; use crate::dialect::{Dialect, Precedence}; use crate::keywords::Keyword; use crate::parser::{Parser, ParserError}; @@ -231,6 +231,18 @@ impl Dialect for PostgreSqlDialect { fn supports_named_fn_args_with_expr_name(&self) -> bool { true } + + // https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-JOIN + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin => true, + _ => false, + } + } } pub fn parse_create(parser: &mut Parser) -> Option> { diff --git a/src/dialect/redshift.rs b/src/dialect/redshift.rs index 48eb00ab1..7a3026c3d 100644 --- a/src/dialect/redshift.rs +++ b/src/dialect/redshift.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::ast::{JoinConstraint, JoinOperator}; use crate::dialect::Dialect; use core::iter::Peekable; use core::str::Chars; @@ -79,4 +80,30 @@ impl Dialect for RedshiftSqlDialect { fn supports_partiql(&self) -> bool { true } + + // https://docs.aws.amazon.com/redshift/latest/dg/r_Join_examples.html + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin => true, + _ => false, + } + } + + fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { + match join_operator.constraint() { + JoinConstraint::Natural => false, + JoinConstraint::On(_) | JoinConstraint::Using(_) => matches!( + join_operator, + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + ), + JoinConstraint::None => matches!(join_operator, JoinOperator::CrossJoin), + } + } } diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index 95717f9fd..830de26a4 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::ast::Statement; +use crate::ast::{JoinOperator, Statement}; use crate::dialect::Dialect; use crate::keywords::Keyword; use crate::parser::{Parser, ParserError}; @@ -81,4 +81,16 @@ impl Dialect for SQLiteDialect { fn supports_asc_desc_in_column_definition(&self) -> bool { true } + + // https://www.sqlite.org/lang_select.html + fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { + match join_operator { + JoinOperator::Inner(_) + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin => true, + _ => false, + } + } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 055fe83d5..6c62008dc 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10186,16 +10186,24 @@ impl<'a> Parser<'a> { let join_constraint = self.parse_join_constraint(natural)?; let join_operator = join_operator_type(join_constraint); - if !self.dialect.verify_join_constraint(&join_operator) { - self.expected("ON, or USING after JOIN", self.peek_token())? - } - Join { relation, global, join_operator, } }; + + if !self.dialect.verify_join_constraint(&join.join_operator) { + self.expected("ON, or USING after JOIN", self.peek_token())? + } + + if !self.dialect.verify_join_operator(&join.join_operator) { + Err(ParserError::ParserError(format!( + "Unsupported join type {} in the current dialect", + join + )))? + } + joins.push(join); } Ok(TableWithJoins { relation, joins }) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index f51870ec5..a99aeabc0 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6115,9 +6115,11 @@ fn parse_nullary_table_valued_function() { } #[test] -fn parse_implicit_join() { +fn gen_dialect_parse_implicit_join() { + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + let sql = "SELECT * FROM t1, t2"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); assert_eq!( vec![ TableWithJoins { @@ -6151,7 +6153,7 @@ fn parse_implicit_join() { ); let sql = "SELECT * FROM t1a NATURAL JOIN t1b, t2a NATURAL JOIN t2b"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); assert_eq!( vec![ TableWithJoins { @@ -6235,7 +6237,7 @@ fn parse_cross_join() { } #[test] -fn parse_joins_on() { +fn gen_dialect_parse_joins_on() { fn join_with_constraint( relation: impl Into, alias: Option, @@ -6261,9 +6263,17 @@ fn parse_joins_on() { })), } } + + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + // Test parsing of aliases assert_eq!( - only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", table_alias("foo"), @@ -6271,17 +6281,27 @@ fn parse_joins_on() { JoinOperator::Inner, )] ); - one_statement_parses_to( + dialects.one_statement_parses_to( "SELECT * FROM t1 JOIN t2 foo ON c1 = c2", "SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2", ); // Test parsing of different join operators assert_eq!( - only(&verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint("t2", None, false, JoinOperator::Inner)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6290,7 +6310,12 @@ fn parse_joins_on() { )] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6299,11 +6324,21 @@ fn parse_joins_on() { )] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 SEMI JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 SEMI JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint("t2", None, false, JoinOperator::Semi)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 LEFT SEMI JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 LEFT SEMI JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6312,7 +6347,12 @@ fn parse_joins_on() { )] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 RIGHT SEMI JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 RIGHT SEMI JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6321,11 +6361,21 @@ fn parse_joins_on() { )] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 ANTI JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 ANTI JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint("t2", None, false, JoinOperator::Anti)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 LEFT ANTI JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 LEFT ANTI JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6334,7 +6384,12 @@ fn parse_joins_on() { )] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 RIGHT ANTI JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 RIGHT ANTI JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6343,7 +6398,12 @@ fn parse_joins_on() { )] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6353,7 +6413,12 @@ fn parse_joins_on() { ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 GLOBAL FULL JOIN t2 ON c1 = c2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 GLOBAL FULL JOIN t2 ON c1 = c2") + .from + ) + .joins, vec![join_with_constraint( "t2", None, @@ -6364,7 +6429,7 @@ fn parse_joins_on() { } #[test] -fn parse_joins_using() { +fn gen_dialect_parse_joins_using() { fn join_with_constraint( relation: impl Into, alias: Option, @@ -6385,64 +6450,121 @@ fn parse_joins_using() { join_operator: f(JoinConstraint::Using(vec!["c1".into()])), } } + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + // Test parsing of aliases assert_eq!( - only(&verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)") + .from + ) + .joins, vec![join_with_constraint( "t2", table_alias("foo"), JoinOperator::Inner, )] ); - one_statement_parses_to( + dialects.one_statement_parses_to( "SELECT * FROM t1 JOIN t2 foo USING(c1)", "SELECT * FROM t1 JOIN t2 AS foo USING(c1)", ); // Test parsing of different join operators assert_eq!( - only(&verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 SEMI JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 SEMI JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::Semi)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 LEFT SEMI JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 LEFT SEMI JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::LeftSemi)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 RIGHT SEMI JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 RIGHT SEMI JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::RightSemi)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 ANTI JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 ANTI JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::Anti)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 LEFT ANTI JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 LEFT ANTI JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::LeftAnti)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 RIGHT ANTI JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 RIGHT ANTI JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::RightAnti)] ); assert_eq!( - only(&verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)") + .from + ) + .joins, vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } #[test] -fn parse_natural_join() { +fn gen_dialect_parse_natural_join() { fn natural_join(f: impl Fn(JoinConstraint) -> JoinOperator, alias: Option) -> Join { Join { relation: TableFactor::Table { @@ -6460,32 +6582,59 @@ fn parse_natural_join() { } } + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + // if not specified, inner join as default assert_eq!( - only(&verified_only_select("SELECT * FROM t1 NATURAL JOIN t2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 NATURAL JOIN t2") + .from + ) + .joins, vec![natural_join(JoinOperator::Inner, None)] ); // left join explicitly assert_eq!( - only(&verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 NATURAL LEFT JOIN t2") + .from + ) + .joins, vec![natural_join(JoinOperator::LeftOuter, None)] ); // right join explicitly assert_eq!( - only(&verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 NATURAL RIGHT JOIN t2") + .from + ) + .joins, vec![natural_join(JoinOperator::RightOuter, None)] ); // full join explicitly assert_eq!( - only(&verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 NATURAL FULL JOIN t2") + .from + ) + .joins, vec![natural_join(JoinOperator::FullOuter, None)] ); // natural join another table with alias assert_eq!( - only(&verified_only_select("SELECT * FROM t1 NATURAL JOIN t2 AS t3").from).joins, + only( + dialects + .verified_only_select("SELECT * FROM t1 NATURAL JOIN t2 AS t3") + .from + ) + .joins, vec![natural_join(JoinOperator::Inner, table_alias("t3"))] ); @@ -6503,11 +6652,13 @@ fn parse_complex_join() { } #[test] -fn parse_join_nesting() { +fn gen_dialect_parse_join_nesting() { + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + let sql = "SELECT * FROM a NATURAL JOIN (b NATURAL JOIN (c NATURAL JOIN d NATURAL JOIN e)) \ NATURAL JOIN (f NATURAL JOIN (g NATURAL JOIN h))"; assert_eq!( - only(&verified_only_select(sql).from).joins, + only(dialects.verified_only_select(sql).from).joins, vec![ join(nest!(table("b"), nest!(table("c"), table("d"), table("e")))), join(nest!(table("f"), nest!(table("g"), table("h")))), @@ -6515,19 +6666,19 @@ fn parse_join_nesting() { ); let sql = "SELECT * FROM (a NATURAL JOIN b) NATURAL JOIN c"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); let from = only(select.from); assert_eq!(from.relation, nest!(table("a"), table("b"))); assert_eq!(from.joins, vec![join(table("c"))]); let sql = "SELECT * FROM (((a NATURAL JOIN b)))"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); let from = only(select.from); assert_eq!(from.relation, nest!(nest!(nest!(table("a"), table("b"))))); assert_eq!(from.joins, vec![]); let sql = "SELECT * FROM a NATURAL JOIN (((b NATURAL JOIN c)))"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); let from = only(select.from); assert_eq!(from.relation, table("a")); assert_eq!( @@ -6536,7 +6687,7 @@ fn parse_join_nesting() { ); let sql = "SELECT * FROM (a NATURAL JOIN b) AS c"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); let from = only(select.from); assert_eq!( from.relation, @@ -6552,20 +6703,22 @@ fn parse_join_nesting() { } #[test] -fn parse_join_syntax_variants() { - one_statement_parses_to( +fn gen_dialect_parse_join_syntax_variants() { + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + + dialects.one_statement_parses_to( "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", "SELECT c1 FROM t1 JOIN t2 USING(c1)", ); - one_statement_parses_to( + dialects.one_statement_parses_to( "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", ); - one_statement_parses_to( + dialects.one_statement_parses_to( "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", ); - one_statement_parses_to( + dialects.one_statement_parses_to( "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", ); @@ -6682,27 +6835,29 @@ fn parse_recursive_cte() { } #[test] -fn parse_derived_tables() { +fn gen_dialect_parse_derived_tables() { + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); + let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b"; - let _ = verified_only_select(sql); + let _ = dialects.verified_only_select(sql); //TODO: add assertions let sql = "SELECT a.x, b.y \ FROM (SELECT x FROM foo) AS a (x) \ CROSS JOIN (SELECT y FROM bar) AS b (y)"; - let _ = verified_only_select(sql); + let _ = dialects.verified_only_select(sql); //TODO: add assertions let sql = "SELECT * FROM (((SELECT 1)))"; - let _ = verified_only_select(sql); + let _ = dialects.verified_only_select(sql); // TODO: add assertions let sql = "SELECT * FROM t NATURAL JOIN (((SELECT 1)))"; - let _ = verified_only_select(sql); + let _ = dialects.verified_only_select(sql); // TODO: add assertions let sql = "SELECT * FROM (((SELECT 1) UNION (SELECT 2)) AS t1 NATURAL JOIN t2)"; - let select = verified_only_select(sql); + let select = dialects.verified_only_select(sql); let from = only(select.from); assert_eq!( from.relation, @@ -12424,9 +12579,8 @@ fn parse_create_table_select() { } #[test] -fn parse_no_condition_join_strategy() { - let mysql_dialect = TestedDialects::new(vec![Box::new(MySqlDialect {})]); - let generic_dialect = TestedDialects::new(vec![Box::new(GenericDialect {})]); +fn gen_dialect_parse_no_condition_join_strategy() { + let dialects = TestedDialects::new(vec![Box::new(GenericDialect {})]); let join_types = vec![ "JOIN", @@ -12452,32 +12606,12 @@ fn parse_no_condition_join_strategy() { "SELECT * FROM (SELECT 1 AS id, 'Foo' AS name) AS l {} (SELECT 1 AS id, 'Bar' AS name) AS r", join ); - let result_generic = generic_dialect.parse_sql_statements(&sql); + let result = dialects.parse_sql_statements(&sql); if join.starts_with("CROSS") || join.starts_with("NATURAL") { - assert!(result_generic.is_ok()); - } else { - assert_eq!( - result_generic.unwrap_err(), - ParserError::ParserError( - "Expected: ON, or USING after JOIN, found: EOF".to_string() - ) - ); - } - - let result_mysql = mysql_dialect.parse_sql_statements(&sql); - if join.starts_with("CROSS") - || join.starts_with("NATURAL") - || join.starts_with("INNER") - || join.starts_with("JOIN") - || join.starts_with("LEFT JOIN") - || join.starts_with("LEFT OUTER") - || join.starts_with("RIGHT JOIN") - || join.starts_with("RIGHT OUTER") - { - assert!(result_mysql.is_ok()); + assert!(result.is_ok()); } else { assert_eq!( - result_mysql.unwrap_err(), + result.unwrap_err(), ParserError::ParserError( "Expected: ON, or USING after JOIN, found: EOF".to_string() ) From 40b38f6419985cf27d40c0760ac110d730bf645f Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Fri, 29 Nov 2024 15:51:39 +0100 Subject: [PATCH 6/6] minor changes --- src/dialect/ansi.rs | 14 +++++++------- src/dialect/bigquery.rs | 14 +++++++------- src/dialect/clickhouse.rs | 28 ++++++++++++++-------------- src/dialect/databricks.rs | 22 +++++++++++----------- src/dialect/duckdb.rs | 28 ++++++++++++++-------------- src/dialect/hive.rs | 18 +++++++++--------- src/dialect/mssql.rs | 18 +++++++++--------- src/dialect/mysql.rs | 12 ++++++------ src/dialect/postgresql.rs | 14 +++++++------- src/dialect/redshift.rs | 14 +++++++------- src/dialect/sqlite.rs | 14 +++++++------- tests/sqlparser_hive.rs | 2 +- 12 files changed, 99 insertions(+), 99 deletions(-) diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs index 615bac29d..d294f6189 100644 --- a/src/dialect/ansi.rs +++ b/src/dialect/ansi.rs @@ -35,13 +35,13 @@ impl Dialect for AnsiDialect { } fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + ) } } diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index 6cd8546da..d72495a9c 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -78,14 +78,14 @@ impl Dialect for BigQueryDialect { // https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#join_types fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + ) } fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index 64ea08efb..a341a3088 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -53,20 +53,20 @@ impl Dialect for ClickHouseDialect { // https://clickhouse.com/docs/en/sql-reference/statements/select/join fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) - | JoinOperator::RightSemi(_) - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::RightAnti(_) - | JoinOperator::AsOf { .. } => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::AsOf { .. } + ) } } diff --git a/src/dialect/databricks.rs b/src/dialect/databricks.rs index aabce9a5f..7c394da9f 100644 --- a/src/dialect/databricks.rs +++ b/src/dialect/databricks.rs @@ -62,17 +62,17 @@ impl Dialect for DatabricksDialect { // https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-join.html fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + ) } } diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs index 30630a263..c6571384c 100644 --- a/src/dialect/duckdb.rs +++ b/src/dialect/duckdb.rs @@ -78,20 +78,20 @@ impl Dialect for DuckDbDialect { // https://duckdb.org/docs/sql/query_syntax/from.html#joins fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::RightAnti(_) - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) - | JoinOperator::RightSemi(_) - | JoinOperator::AsOf { .. } => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightAnti(_) + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::RightSemi(_) + | JoinOperator::AsOf { .. } + ) } } diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index ff5384322..ce73c73c6 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -67,16 +67,16 @@ impl Dialect for HiveDialect { // https://cwiki.apache.org/confluence/display/hive/languagemanual+joins fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin - | JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + ) } fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs index f4bd693cb..ae7e8a58e 100644 --- a/src/dialect/mssql.rs +++ b/src/dialect/mssql.rs @@ -84,16 +84,16 @@ impl Dialect for MsSqlDialect { // https://learn.microsoft.com/en-us/sql/relational-databases/performance/joins?view=sql-server-ver16 fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin - | JoinOperator::CrossApply - | JoinOperator::OuterApply => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + | JoinOperator::CrossApply + | JoinOperator::OuterApply + ) } fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index c81ab4e6f..09f2e055d 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -106,13 +106,13 @@ impl Dialect for MySqlDialect { } fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::CrossJoin => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::CrossJoin + ) } fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 2544c17f9..dccceeaf6 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -234,14 +234,14 @@ impl Dialect for PostgreSqlDialect { // https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-JOIN fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + ) } } diff --git a/src/dialect/redshift.rs b/src/dialect/redshift.rs index 7a3026c3d..d8eee6867 100644 --- a/src/dialect/redshift.rs +++ b/src/dialect/redshift.rs @@ -83,14 +83,14 @@ impl Dialect for RedshiftSqlDialect { // https://docs.aws.amazon.com/redshift/latest/dg/r_Join_examples.html fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + ) } fn verify_join_constraint(&self, join_operator: &JoinOperator) -> bool { diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index 830de26a4..813c68d82 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -84,13 +84,13 @@ impl Dialect for SQLiteDialect { // https://www.sqlite.org/lang_select.html fn verify_join_operator(&self, join_operator: &JoinOperator) -> bool { - match join_operator { + matches!( + join_operator, JoinOperator::Inner(_) - | JoinOperator::LeftOuter(_) - | JoinOperator::RightOuter(_) - | JoinOperator::FullOuter(_) - | JoinOperator::CrossJoin => true, - _ => false, - } + | JoinOperator::LeftOuter(_) + | JoinOperator::RightOuter(_) + | JoinOperator::FullOuter(_) + | JoinOperator::CrossJoin + ) } } diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 02754513b..8d4f7a680 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -285,7 +285,7 @@ fn test_distribute_by() { #[test] fn no_join_condition() { - let join = "SELECT a, b FROM db.table_name CROSS JOIN a"; + let join = "SELECT a, b FROM db.table_name JOIN a"; hive().verified_stmt(join); }