diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 756774353..fe93ee91f 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -145,6 +145,117 @@ where
     DisplaySeparated { slice, sep: ", " }
 }
 
+/// Displays a slice of spanned items separated by `sep`, reproducing between
+/// items the line breaks and column offsets recorded in their source spans.
+pub struct DisplaySeparatedWithNewlines<'a, T>
+where
+    T: fmt::Display + Spanned,
+{
+    slice: &'a [T],
+    sep: &'static str,
+    last_span: Span,
+}
+
+impl<T> fmt::Display for DisplaySeparatedWithNewlines<'_, T>
+where
+    T: fmt::Display + Spanned,
+{
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // Start from the caller-supplied span and track where the written output ends.
+        let mut last_span = self.last_span;
+        if let Some(first) = self.slice.first() {
+            write_span_gap_lines(f, &mut last_span, first.span())?;
+        }
+        let mut delim = "";
+        for t in self.slice {
+            write!(f, "{delim}")?;
+            // Account for the separator just written; saturate instead of overflowing.
+            let delim_width = u64::try_from(delim.len()).unwrap_or(u64::MAX);
+            last_span.end.column = last_span.end.column.saturating_add(delim_width);
+            let current_span = t.span();
+            write_span_gap(f, last_span, current_span)?;
+            write!(f, "{t}")?;
+            last_span = current_span;
+            delim = self.sep;
+        }
+        Ok(())
+    }
+}
+
+/// Writes the newlines, then the spaces, that separate the end of `last_span`
+/// from the start of `current_span`.
+pub fn write_span_gap(
+    f: &mut fmt::Formatter,
+    mut last_span: Span,
+    current_span: Span,
+) -> fmt::Result {
+    // One newline per line to skip; every newline resets the column to 1.
+    while last_span.end.line < current_span.start.line {
+        writeln!(f)?;
+        last_span.end.line += 1;
+        last_span.end.column = 1;
+    }
+    // Pad with spaces up to the column where the current item starts.
+    while last_span.end.column < current_span.start.column {
+        write!(f, " ")?;
+        last_span.end.column += 1;
+    }
+    Ok(())
+}
+
+/// Writes the newlines that separate the end of `last_span` from the start of
+/// `current_span`, or a single space when no newline is needed.
+/// `last_span.end` is advanced to reflect what was written.
+pub fn write_span_gap_lines(
+    f: &mut fmt::Formatter,
+    last_span: &mut Span,
+    current_span: Span,
+) -> fmt::Result {
+    let mut needs_space = true;
+    while last_span.end.line < current_span.start.line {
+        writeln!(f)?;
+        last_span.end.line += 1;
+        last_span.end.column = 1;
+        needs_space = false;
+    }
+    if needs_space {
+        write!(f, " ")?;
+        last_span.end.column += 1;
+    }
+    Ok(())
+}
+
+/// Wraps `slice` so that it displays with `sep` between items, restoring the gaps
+/// recorded in the item spans; `last_span` is where the preceding output ended.
+pub fn display_separated_with_newlines<'a, T>(
+    slice: &'a [T],
+    sep: &'static str,
+    last_span: Span,
+) -> DisplaySeparatedWithNewlines<'a, T>
+where
+    T: fmt::Display + Spanned,
+{
+    DisplaySeparatedWithNewlines {
+        slice,
+        sep,
+        last_span,
+    }
+}
+
+/// Comma-separated variant of [`display_separated_with_newlines`].
+pub fn display_comma_separated_with_newlines<T>(
+    slice: &[T],
+    last_span: Span,
+) -> DisplaySeparatedWithNewlines<'_, T>
+where
+    T: fmt::Display + Spanned,
+{
+    // Without span info no spacing is recorded, so put the space in the separator.
+    let has_span_info = slice.iter().any(|item| item.span() != Span::empty());
+    let sep = if has_span_info { "," } else { ", " };
+    display_separated_with_newlines(slice, sep, last_span)
+}
+
 /// An identifier, decomposed into its value or character data and the quote style.
 #[derive(Debug, Clone, PartialOrd, Ord)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -3763,21 +3874,49 @@ impl fmt::Display for Statement {
                 if let Some(or) = or {
                     write!(f, "{or} ")?;
                 }
+                // Track where the text emitted so far ends, so that the line breaks
+                // recorded in the spans are reproduced between clauses.
+                let mut last_span = table.span();
                 write!(f, "{table}")?;
                 if let Some(UpdateTableFromKind::BeforeSet(from)) = from {
-                    write!(f, " FROM {from}")?;
+                    let from_span = from.span();
+                    write_span_gap_lines(f, &mut last_span, from_span)?;
+                    last_span = from_span;
+                    write!(f, "FROM {from}")?;
                 }
                 if !assignments.is_empty() {
-                    write!(f, " SET {}", display_comma_separated(assignments))?;
+                    let assign_span = assignments.first().unwrap().span();
+                    write_span_gap_lines(f, &mut last_span, assign_span)?;
+                    // Account for the three columns taken by the `SET` keyword.
+                    last_span.end.column += 3;
+                    write!(
+                        f,
+                        "SET{}",
+                        display_comma_separated_with_newlines(assignments, last_span)
+                    )?;
+                    last_span = assignments.last().unwrap().span();
                 }
                 if let Some(UpdateTableFromKind::AfterSet(from)) = from {
-                    write!(f, " FROM {from}")?;
+                    write_span_gap_lines(f, &mut last_span, from.span())?;
+                    last_span = from.span();
+                    write!(f, "FROM {from}")?;
                 }
                 if let Some(selection) = selection {
-                    write!(f, " WHERE {selection}")?;
+                    write_span_gap_lines(f, &mut last_span, selection.span())?;
+                    last_span = selection.span();
+                    write!(f, "WHERE {selection}")?;
                 }
                 if let Some(returning) = returning {
-                    write!(f, " RETURNING {}", display_comma_separated(returning))?;
+                    // `returning` can be an empty list in a hand-built AST: fall back
+                    // to the current position instead of panicking.
+                    let returning_span = returning.first().map_or(last_span, |item| item.span());
+                    write_span_gap_lines(f, &mut last_span, returning_span)?;
+                    last_span.end = returning_span.start;
+                    write!(
+                        f,
+                        "RETURNING{}",
+                        display_comma_separated_with_newlines(returning, last_span)
+                    )?;
                 }
                 Ok(())
             }
@@ -5420,6 +5559,8 @@ impl fmt::Display for GrantObjects {
 pub struct Assignment {
     pub target: AssignmentTarget,
     pub value: Expr,
+    /// Source span of the whole assignment, as recorded by the parser.
+    pub span: Span,
 }
 
 impl fmt::Display for Assignment {
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index dad0c5379..5eaf1dfd8 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -1229,9 +1229,7 @@ impl Spanned for DoUpdate {
 
 impl Spanned for Assignment {
     fn span(&self) -> Span {
-        let Assignment { target, value } = self;
-
-        target.span().union(&value.span())
+        self.span
     }
 }
 
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 47d4d6f0d..56101df29 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -12052,10 +12052,18 @@ impl<'a> Parser<'a> {
 
     /// Parse a `var = expr` assignment, used in an UPDATE statement
     pub fn parse_assignment(&mut self) -> Result<Assignment, ParserError> {
+        let start = self.peek_token().span.start;
         let target = self.parse_assignment_target()?;
         self.expect_token(&Token::Eq)?;
         let value = self.parse_expr()?;
-        Ok(Assignment { target, value })
+        // Step back and re-read the last consumed token to learn where the value ends.
+        self.prev_token();
+        let end = self.next_token().span.end;
+        Ok(Assignment {
+            target,
+            value,
+            span: Span::new(start, end),
+        })
     }
 
     /// Parse the left-hand side of an assignment, used in an UPDATE statement
diff --git a/src/test_utils.rs b/src/test_utils.rs
index e76cdb87a..a0f65a4f3 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -135,6 +135,27 @@ impl TestedDialects {
         // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
     }
 
+    /// Parses a single SQL string into multiple statements, tokenizing with
+    /// locations so the parser sees token positions, and ensuring the result
+    /// is the same for all tested dialects.
+    pub fn parse_sql_statements_with_locations(
+        &self,
+        sql: &str,
+    ) -> Result<Vec<Statement>, ParserError> {
+        self.one_of_identical_results(|dialect| {
+            let mut tokenizer = Tokenizer::new(dialect, sql);
+            if let Some(options) = &self.options {
+                tokenizer = tokenizer.with_unescape(options.unescape);
+            }
+            let tokens = tokenizer.tokenize_with_location()?;
+            self.new_parser(dialect)
+                .with_tokens_with_locations(tokens)
+                .parse_statements()
+        })
+        // To fail the `ensure_multiple_dialects_are_tested` test:
+        // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
+    }
+
     /// Ensures that `sql` parses as a single [Statement] for all tested
     /// dialects.
     ///
@@ -152,7 +173,7 @@ impl TestedDialects {
     /// 2. re-serializing the result of parsing `sql` produces the same
     /// `canonical` sql string
     pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
-        let mut statements = self.parse_sql_statements(sql).expect(sql);
+        let mut statements = self.parse_sql_statements_with_locations(sql).expect(sql);
         assert_eq!(statements.len(), 1);
 
         if !canonical.is_empty() && sql != canonical {
@@ -167,6 +188,20 @@ impl TestedDialects {
         only_statement
     }
 
+    /// Like [`Self::one_statement_parses_to`], but parses with
+    /// [`Self::parse_sql_statements`] instead of
+    /// [`Self::parse_sql_statements_with_locations`], and only checks that the
+    /// single parsed statement re-serializes to `canonical` (when non-empty).
+    pub fn one_statement_parses_to_no_span(&self, sql: &str, canonical: &str) -> Statement {
+        let mut statements = self.parse_sql_statements(sql).expect(sql);
+        assert_eq!(statements.len(), 1);
+        let only_statement = statements.pop().unwrap();
+        if !canonical.is_empty() {
+            assert_eq!(canonical, only_statement.to_string());
+        }
+        only_statement
+    }
+
     /// Ensures that `sql` parses as an [`Expr`], and that
     /// re-serializing the parse result produces canonical
     pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
@@ -184,6 +219,12 @@ impl TestedDialects {
         self.one_statement_parses_to(sql, sql)
     }
 
+    /// Like [`Self::verified_stmt`], but checks the round-trip through
+    /// [`Self::one_statement_parses_to_no_span`].
+    pub fn verified_stmt_no_span(&self, sql: &str) -> Statement {
+        self.one_statement_parses_to_no_span(sql, sql)
+    }
+
     /// Ensures that `sql` parses as a single [Query], and that
     /// re-serializing the parse result produces the same `sql`
     /// string (is not modified after a serialization round-trip).
diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 9dfabc014..30cb447f0 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -1624,16 +1624,18 @@ fn parse_merge() { let update_action = MergeAction::Update { assignments: vec![ Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("a")])), value: Expr::Value(number("1")), }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("b")])), value: Expr::Value(number("2")), }, ], }; - match bigquery_and_generic().verified_stmt(sql) { + match bigquery_and_generic().verified_stmt_no_span(sql) { Statement::Merge { into, table, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 3c2e0899f..8d8881955 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -297,25 +297,30 @@ fn parse_update() { match verified_stmt(sql) { Statement::Update { table, - assignments, + mut assignments, selection, .. 
} => { assert_eq!(table.to_string(), "t".to_string()); + // remove the span from the assignments before comparison + assignments.iter_mut().for_each(|a| a.span = Span::empty()); assert_eq!( assignments, vec![ Assignment { target: AssignmentTarget::ColumnName(ObjectName(vec!["a".into()])), value: Expr::Value(number("1")), + span: Span::empty(), }, Assignment { target: AssignmentTarget::ColumnName(ObjectName(vec!["b".into()])), value: Expr::Value(number("2")), + span: Span::empty(), }, Assignment { target: AssignmentTarget::ColumnName(ObjectName(vec!["c".into()])), value: Expr::Value(number("3")), + span: Span::empty(), }, ] ); @@ -354,7 +359,7 @@ fn parse_update_set_from() { Box::new(MsSqlDialect {}), Box::new(SQLiteDialect {}), ]); - let stmt = dialects.verified_stmt(sql); + let stmt = dialects.verified_stmt_no_span(sql); assert_eq!( stmt, Statement::Update { @@ -363,6 +368,7 @@ fn parse_update_set_from() { joins: vec![], }, assignments: vec![Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])), value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) }], @@ -439,7 +445,7 @@ fn parse_update_set_from() { #[test] fn parse_update_with_table_alias() { let sql = "UPDATE users AS u SET u.username = 'new_user' WHERE u.username = 'old_user'"; - match verified_stmt(sql) { + match verified_stmt_no_span(sql) { Statement::Update { table, assignments, @@ -470,6 +476,7 @@ fn parse_update_with_table_alias() { ); assert_eq!( vec![Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![ Ident::new("u"), Ident::new("username") @@ -8529,7 +8536,10 @@ fn test_revoke() { fn parse_merge() { let sql = "MERGE INTO s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE"; let 
sql_no_into = "MERGE s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE"; - match (verified_stmt(sql), verified_stmt(sql_no_into)) { + match ( + verified_stmt_no_span(sql), + verified_stmt_no_span(sql_no_into), + ) { ( Statement::Merge { into, @@ -8698,6 +8708,7 @@ fn parse_merge() { action: MergeAction::Update { assignments: vec![ Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![ Ident::new("dest"), Ident::new("F") @@ -8708,6 +8719,7 @@ fn parse_merge() { ]), }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![ Ident::new("dest"), Ident::new("G") @@ -8992,6 +9004,10 @@ fn verified_stmt(query: &str) -> Statement { all_dialects().verified_stmt(query) } +fn verified_stmt_no_span(query: &str) -> Statement { + all_dialects().verified_stmt_no_span(query) +} + fn verified_query(query: &str) -> Query { all_dialects().verified_query(query) } diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 4a4e79611..3d46de5f1 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -1826,40 +1826,50 @@ fn parse_insert_with_on_duplicate_update() { })), source ); + let Some(OnInsert::DuplicateKeyUpdate(mut assignments)) = on else { + unreachable!("expected duplicate key update"); + }; + // remove the span from the assignments before comparison + assignments.iter_mut().for_each(|a| a.span = Span::empty()); assert_eq!( - Some(OnInsert::DuplicateKeyUpdate(vec![ + vec![ Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( "description".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("description"))]), }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( 
"perm_create".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_create"))]), }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( "perm_read".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_read"))]), }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( "perm_update".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_update"))]), }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new( "perm_delete".to_string() )])), value: call("VALUES", [Expr::Identifier(Ident::new("perm_delete"))]), }, - ])), - on + ], + assignments ); } _ => unreachable!(), @@ -1986,7 +1996,7 @@ fn parse_update_with_joins() { match mysql().verified_stmt(sql) { Statement::Update { table, - assignments, + mut assignments, from: _from, selection, returning, @@ -2039,8 +2049,11 @@ fn parse_update_with_joins() { }, table ); + // remove the span from the assignments before comparison + assignments.iter_mut().for_each(|a| a.span = Span::empty()); assert_eq!( vec![Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec![ Ident::new("o"), Ident::new("completed") diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index fd520d507..34beb0908 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1774,7 +1774,7 @@ fn parse_prepare() { #[test] fn parse_pg_on_conflict() { - let stmt = pg_and_generic().verified_stmt( + let stmt = pg_and_generic().verified_stmt_no_span( "INSERT INTO distributors (did, dname) \ VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ ON CONFLICT(did) \ @@ -1793,6 +1793,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { + span: Span::empty(), target: 
AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) },], @@ -1804,7 +1805,7 @@ fn parse_pg_on_conflict() { _ => unreachable!(), }; - let stmt = pg_and_generic().verified_stmt( + let stmt = pg_and_generic().verified_stmt_no_span( "INSERT INTO distributors (did, dname, area) \ VALUES (5, 'Gizmo Transglobal', 'Mars'), (6, 'Associated Computing, Inc', 'Venus') \ ON CONFLICT(did, area) \ @@ -1824,6 +1825,7 @@ fn parse_pg_on_conflict() { OnConflictAction::DoUpdate(DoUpdate { assignments: vec![ Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::CompoundIdentifier(vec![ "EXCLUDED".into(), @@ -1831,6 +1833,7 @@ fn parse_pg_on_conflict() { ]) }, Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec!["area".into()])), value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "area".into()]) }, @@ -1862,7 +1865,7 @@ fn parse_pg_on_conflict() { _ => unreachable!(), }; - let stmt = pg_and_generic().verified_stmt( + let stmt = pg_and_generic().verified_stmt_no_span( "INSERT INTO distributors (did, dname, dsize) \ VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ ON CONFLICT(did) \ @@ -1881,6 +1884,7 @@ fn parse_pg_on_conflict() { assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::Value(Value::Placeholder("$1".to_string())) },], @@ -1900,7 +1904,7 @@ fn parse_pg_on_conflict() { _ => unreachable!(), }; - let stmt = pg_and_generic().verified_stmt( + let stmt = pg_and_generic().verified_stmt_no_span( "INSERT INTO distributors (did, dname, dsize) \ VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ ON CONFLICT ON CONSTRAINT distributors_did_pkey \ @@ -1919,6 +1923,7 @@ fn parse_pg_on_conflict() { assert_eq!( 
OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { + span: Span::empty(), target: AssignmentTarget::ColumnName(ObjectName(vec!["dname".into()])), value: Expr::Value(Value::Placeholder("$1".to_string())) },], @@ -3728,16 +3733,8 @@ fn parse_delimited_identifiers() { #[test] fn parse_update_has_keyword() { - pg().one_statement_parses_to( - r#"UPDATE test SET name=$1, - value=$2, - where=$3, - create=$4, - is_default=$5, - classification=$6, - sort=$7 - WHERE id=$8"#, - r#"UPDATE test SET name = $1, value = $2, where = $3, create = $4, is_default = $5, classification = $6, sort = $7 WHERE id = $8"# + pg().verified_stmt( + r#"UPDATE test SET name = $1, value = $2, where = $3, create = $4, is_default = $5, classification = $6, sort = $7 WHERE id = $8"#, ); } diff --git a/tests/sqlparser_sqlite.rs b/tests/sqlparser_sqlite.rs index 0adf7f755..a067008b0 100644 --- a/tests/sqlparser_sqlite.rs +++ b/tests/sqlparser_sqlite.rs @@ -30,7 +30,7 @@ use sqlparser::ast::Value::Placeholder; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, SQLiteDialect}; use sqlparser::parser::{ParserError, ParserOptions}; -use sqlparser::tokenizer::Token; +use sqlparser::tokenizer::{Span, Token}; #[test] fn pragma_no_value() { @@ -464,7 +464,7 @@ fn parse_attach_database() { fn parse_update_tuple_row_values() { // See https://github.com/sqlparser-rs/sqlparser-rs/issues/1311 assert_eq!( - sqlite().verified_stmt("UPDATE x SET (a, b) = (1, 2)"), + sqlite().verified_stmt_no_span("UPDATE x SET (a, b) = (1, 2)"), Statement::Update { or: None, assignments: vec![Assignment { @@ -475,7 +475,8 @@ fn parse_update_tuple_row_values() { value: Expr::Tuple(vec![ Expr::Value(Value::Number("1".parse().unwrap(), false)), Expr::Value(Value::Number("2".parse().unwrap(), false)) - ]) + ]), + span: Span::empty(), }], selection: None, table: TableWithJoins { diff --git a/tests/test_formatting.rs b/tests/test_formatting.rs new file mode 100644 index 000000000..23657e395 --- /dev/null +++ 
b/tests/test_formatting.rs
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![warn(clippy::all)]
+//! Tests that displaying a parsed statement reproduces the line breaks and
+//! indentation of the original SQL when it was parsed with source locations.
+
+use sqlparser::test_utils::all_dialects;
+
+/// `SET` and `WHERE` each starting on their own line round-trip unchanged.
+#[test]
+fn format_update_tuple_row_values() {
+    all_dialects().verified_stmt(
+        "\
+UPDATE x
+SET (a, b) = (1, 2)
+WHERE c = 3\
+    ",
+    );
+}
+
+/// Assignments spread over several indented lines round-trip unchanged.
+#[test]
+fn format_update_multiple_sets_newlines() {
+    all_dialects().verified_stmt(
+        "\
+UPDATE x
+SET a = 1,
+    b = 2,
+    c = 3
+WHERE d = 4\
+    ",
+    );
+}
+
+/// `SET` on the `UPDATE` line followed by `WHERE` on a new line round-trips unchanged.
+#[test]
+fn format_update_newline_before_where() {
+    all_dialects().verified_stmt(
+        "\
+UPDATE x SET x = 1
+WHERE c = 3\
+    ",
+    );
+}