diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 5b8e23259..c5920f699 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -74,7 +74,7 @@ pub use self::ddl::{ pub use self::dml::{Delete, Insert}; pub use self::operator::{BinaryOperator, UnaryOperator}; pub use self::query::{ - AfterMatchSkip, ConnectBy, Cte, CteAsMaterialized, Distinct, EmptyMatchesMode, + AfterMatchSkip, ConnectBy, Cse, Cte, CteAsMaterialized, Distinct, EmptyMatchesMode, ExceptSelectItem, ExcludeSelectItem, ExprWithAlias, ExprWithAliasAndOrderBy, Fetch, ForClause, ForJson, ForXml, FormatClause, GroupByExpr, GroupByWithModifier, IdentWithAlias, IlikeSelectItem, InputFormatClause, Interpolate, InterpolateExpr, Join, JoinConstraint, @@ -90,8 +90,9 @@ pub use self::query::{ TableIndexType, TableSample, TableSampleBucket, TableSampleKind, TableSampleMethod, TableSampleModifier, TableSampleQuantity, TableSampleSeed, TableSampleSeedModifier, TableSampleUnit, TableVersion, TableWithJoins, Top, TopQuantity, UpdateTableFromKind, - ValueTableMode, Values, WildcardAdditionalOptions, With, WithFill, XmlNamespaceDefinition, - XmlPassingArgument, XmlPassingClause, XmlTableColumn, XmlTableColumnOption, + ValueTableMode, Values, WildcardAdditionalOptions, With, WithExpression, WithFill, + XmlNamespaceDefinition, XmlPassingArgument, XmlPassingClause, XmlTableColumn, + XmlTableColumnOption, }; pub use self::trigger::{ diff --git a/src/ast/query.rs b/src/ast/query.rs index ffc9ce666..84cb62e93 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -603,7 +603,7 @@ pub struct With { /// Token for the "WITH" keyword pub with_token: AttachedToken, pub recursive: bool, - pub cte_tables: Vec, + pub cte_tables: Vec, } impl fmt::Display for With { @@ -641,7 +641,71 @@ impl fmt::Display for CteAsMaterialized { } } -/// A single CTE (used after `WITH`): ` [(col1, col2, ...)] AS ( )` +/// `WITH` clause in `SELECT`. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum WithExpression { + /// Common table expression. + Cte(Cte), + /// Common scalar expression. + Cse(Cse), +} + +impl WithExpression { + pub fn cte(&self) -> Option<&Cte> { + match self { + Self::Cte(cte) => Some(cte), + Self::Cse(_) => None, + } + } + + pub fn cse(&self) -> Option<&Cse> { + match self { + Self::Cte(_) => None, + Self::Cse(cse) => Some(cse), + } + } +} + +impl fmt::Display for WithExpression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Cte(cte) => cte.fmt(f), + Self::Cse(cse) => cse.fmt(f), + } + } +} + +/// A common scalar expression (CSE). +/// +/// ```sql +/// [WITH] AS [,] +/// ``` +/// +/// See https://clickhouse.com/docs/sql-reference/statements/select/with#common-scalar-expressions +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct Cse { + pub expr: Expr, + pub ident: Ident, +} + +impl fmt::Display for Cse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.expr.fmt(f)?; + f.write_str(" AS ")?; + self.ident.fmt(f)?; + Ok(()) + } +} +/// A common table expression (CTE). +/// +/// ```sql +/// [WITH] [(col1, col2, ...)] AS ( ) [,] +/// ``` +/// /// The names in the column list before `AS`, when specified, replace the names /// of the columns returned by the query. The parser does not validate that the /// number of columns in the query matches the number of columns in the query. diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 5d3694be8..830fed8da 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -176,6 +176,23 @@ impl Spanned for With { } } +impl Spanned for super::query::WithExpression { + fn span(&self) -> Span { + match self { + super::query::WithExpression::Cte(cte) => cte.span(), + super::query::WithExpression::Cse(cse) => cse.span(), + } + } +} + +impl Spanned for super::query::Cse { + fn span(&self) -> Span { + let super::query::Cse { expr, ident } = self; + + union_spans(core::iter::once(expr.span()).chain(core::iter::once(ident.span))) + } +} + impl Spanned for Cte { fn span(&self) -> Span { let Cte { @@ -2560,7 +2577,11 @@ pub mod tests { let query = test.0.parse_query().unwrap(); let cte_span = query.clone().with.unwrap().cte_tables[0].span(); - let cte_query_span = query.clone().with.unwrap().cte_tables[0].query.span(); + let cte_query_span = query.clone().with.unwrap().cte_tables[0] + .cte() + .unwrap() + .query + .span(); let body_span = query.body.span(); // the WITH keyboard is part of the query diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index f5e70c309..ac68326ed 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -80,6 +80,11 @@ impl Dialect for ClickHouseDialect { true } + /// See . + fn supports_common_scalar_expressions(&self) -> bool { + true + } + /// See fn supports_order_by_all(&self) -> bool { true diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 3c7be8f7e..c11a71d90 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -596,6 +596,20 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports Common Scalar Expressions in `SELECT`. + /// + /// For example: + /// ```sql + /// WITH + /// toDate('2000-01-01') AS start_date + /// SELECT * from tbl WHERE col1 > start_date; + /// ``` + /// + /// [ClickHouse](https://clickhouse.com/docs/sql-reference/statements/select/with#common-scalar-expressions) + fn supports_common_scalar_expressions(&self) -> bool { + false + } + /// Return true if the dialect supports specifying multiple options /// in a `CREATE TABLE` statement for the structure of the new table. For example: /// `CREATE TABLE t (a INT, b INT) AS SELECT 1 AS b, 2 AS a` diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 54db367d8..81017654f 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11801,7 +11801,7 @@ impl<'a> Parser<'a> { Some(With { with_token: with_token.clone().into(), recursive: self.parse_keyword(Keyword::RECURSIVE), - cte_tables: self.parse_comma_separated(Parser::parse_cte)?, + cte_tables: self.parse_comma_separated(Parser::parse_with_expression)?, }) } else { None @@ -12260,7 +12260,28 @@ impl<'a> Parser<'a> { }) } - /// Parse a CTE (`alias [( col1, col2, ... )] AS (subquery)`) + /// Parse the expression in a `WITH` clause. + pub fn parse_with_expression(&mut self) -> Result { + Ok(if self.dialect.supports_common_scalar_expressions() { + if let Some(cse) = self.maybe_parse(|parser| parser.parse_cse())? { + WithExpression::Cse(cse) + } else { + WithExpression::Cte(self.parse_cte()?) + } + } else { + WithExpression::Cte(self.parse_cte()?) + }) + } + + /// Parse a [`Cse`] in a `WITH` clause. + pub fn parse_cse(&mut self) -> Result { + let expr = self.parse_expr()?; + self.expect_keyword_is(Keyword::AS)?; + let ident = self.parse_identifier()?; + Ok(Cse { expr, ident }) + } + + /// Parse a [`Cte`] in a `WITH` clause. pub fn parse_cte(&mut self) -> Result { let name = self.parse_identifier()?; diff --git a/tests/sqlparser_clickhouse.rs b/tests/sqlparser_clickhouse.rs index bc1431f9c..deb49acc0 100644 --- a/tests/sqlparser_clickhouse.rs +++ b/tests/sqlparser_clickhouse.rs @@ -1729,6 +1729,65 @@ fn test_parse_not_null_in_column_options() { ); } +#[test] +fn parse_cse() { + clickhouse().verified_stmt("WITH x AS (SELECT 1) UPDATE t SET bar = (SELECT * FROM x)"); + + let with = concat!( + "WITH", + " toIntervalSecond(300) AS bucket_size,", + " toDateTime64(1735751460, 9) AS start_time,", + " toDateTime64(1735755060, 9) AS end_time ", + "SELECT", + " toStartOfInterval(EventTime, bucket_size) AS bucket,", + " count() AS count ", + "FROM logs", + ); + clickhouse().verified_query(with); + + let mixed = concat!( + "WITH", + " toDate(now()) AS today,", + " tbl (c) AS (SELECT toDate('2000-01-01')) ", + "SELECT", + " * ", + "FROM tbl ", + "WHERE c < today" + ); + clickhouse().verified_query(mixed); + + // valid + clickhouse() + .parse_sql_statements("WITH foo() AS bar SELECT 1") + .unwrap(); + + // ClickHouse allows these, but not sqlparser + clickhouse() + .parse_sql_statements("WITH foo, bar SELECT 1") + .expect_err("Expected: AS, found: ,"); + + clickhouse() + .parse_sql_statements("WITH foo(), bar SELECT 1") + .expect_err("Expected: identifier, found: )"); + + // invalid + clickhouse() + .parse_sql_statements("WITH foo bar SELECT 1") + .expect_err("Expected: "); + + clickhouse() + .parse_sql_statements("WITH foo() bar SELECT 1") + .expect_err("Expected: "); + + clickhouse() + .parse_sql_statements("WITH foo() bar() SELECT 1") + .expect_err("Expected: "); + + clickhouse() + .parse_sql_statements("WITH foo() AS bar() SELECT 1") + .expect_err("Expected: "); +} + fn clickhouse() -> TestedDialects { TestedDialects::new(vec![Box::new(ClickHouseDialect {})]) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index aa47c0f7a..78969bb4c 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -7537,7 +7537,7 @@ fn parse_ctes() { fn assert_ctes_in_select(expected: &[&str], sel: &Query) { for (i, exp) in expected.iter().enumerate() { - let Cte { alias, query, .. } = &sel.with.as_ref().unwrap().cte_tables[i]; + let Cte { alias, query, .. } = &sel.with.as_ref().unwrap().cte_tables[i].cte().unwrap(); assert_eq!(*exp, query.to_string()); assert_eq!( if i == 0 { @@ -7580,7 +7580,10 @@ fn parse_ctes() { // CTE in a CTE... let sql = &format!("WITH outer_cte AS ({with}) SELECT * FROM outer_cte"); let select = verified_query(sql); - assert_ctes_in_select(&cte_sqls, &only(&select.with.unwrap().cte_tables).query); + assert_ctes_in_select( + &cte_sqls, + &only(&select.with.unwrap().cte_tables).cte().unwrap().query, + ); } #[test] @@ -7598,6 +7601,8 @@ fn parse_cte_renamed_columns() { .cte_tables .first() .unwrap() + .cte() + .unwrap() .alias .columns ); @@ -7628,7 +7633,7 @@ fn parse_recursive_cte() { materialized: None, closing_paren_token: AttachedToken::empty(), }; - assert_eq!(with.cte_tables.first().unwrap(), &expected); + assert_eq!(with.cte_tables.first().unwrap().cte().unwrap(), &expected); } #[test] @@ -17105,7 +17110,7 @@ fn test_parse_semantic_view_table_factor() { } let ast_sql = r#"SELECT * FROM SEMANTIC_VIEW( - my_model + my_model DIMENSIONS DATE_PART('year', date_col), region_name METRICS orders.revenue, orders.count WHERE active = true