Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 140 additions & 5 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,117 @@ where
DisplaySeparated { slice, sep: ", " }
}

pub struct DisplaySeparatedWithNewlines<'a, T>
where
T: fmt::Display + Spanned,
{
slice: &'a [T],
sep: &'static str,
last_span: Span,
}

impl<T> fmt::Display for DisplaySeparatedWithNewlines<'_, T>
where
T: fmt::Display + Spanned,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// Initialize the last span to track where we left off in our previous display logic.
// We suppose we are at the start of a line, so we take the first item's starting position
let mut last_span = self.last_span;
if let Some(first) = self.slice.first() {
let first_span = first.span();
write_span_gap_lines(f, &mut last_span, first_span)?;
}
let mut delim = "";
for t in self.slice {
write!(f, "{delim}")?;
last_span.end.column += u64::try_from(delim.len()).unwrap_or(1);

let current_span = t.span();
write_span_gap(f, last_span, current_span)?;
write!(f, "{t}")?;
last_span = current_span;
delim = self.sep;
}
Ok(())
}
}

/// Write newlines and spaces between two spans
pub fn write_span_gap(
f: &mut fmt::Formatter,
mut last_span: Span,
current_span: Span,
) -> fmt::Result {
// write all the newlines between the last item and the current item
while last_span.end.line < current_span.start.line {
writeln!(f)?;
last_span.end.line += 1;
last_span.end.column = 1;
}
// write spaces between the last item and the current item
while last_span.end.column < current_span.start.column {
write!(f, " ")?;
last_span.end.column += 1;
}
Ok(())
}

/// Write newlines between two spans. If the two spans are on the same line, write a single space
pub fn write_span_gap_lines(
f: &mut fmt::Formatter,
last_span: &mut Span,
current_span: Span,
) -> fmt::Result {
let mut needs_space = true;
while last_span.end.line < current_span.start.line {
writeln!(f)?;
last_span.end.line += 1;
last_span.end.column = 1;
needs_space = false;
}
if needs_space {
write!(f, " ")?;
last_span.end.column += 1;
}
Ok(())
}

pub fn display_separated_with_newlines<'a, T>(
slice: &'a [T],
sep: &'static str,
last_span: Span,
) -> DisplaySeparatedWithNewlines<'a, T>
where
T: fmt::Display + Spanned,
{
DisplaySeparatedWithNewlines {
slice,
sep,
last_span,
}
}

pub fn display_comma_separated_with_newlines<T>(
slice: &[T],
last_span: Span,
) -> DisplaySeparatedWithNewlines<'_, T>
where
T: fmt::Display + Spanned,
{
// if we don't have span info, just add a space between the items
let sep = if slice.iter().all(|s| s.span() == Span::empty()) {
", "
} else {
","
};
DisplaySeparatedWithNewlines {
slice,
sep,
last_span,
}
}

/// An identifier, decomposed into its value or character data and the quote style.
#[derive(Debug, Clone, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -3763,21 +3874,44 @@ impl fmt::Display for Statement {
if let Some(or) = or {
write!(f, "{or} ")?;
}
let mut last_span = table.span();
write!(f, "{table}")?;
if let Some(UpdateTableFromKind::BeforeSet(from)) = from {
write!(f, " FROM {from}")?;
let from_span = from.span();
write_span_gap_lines(f, &mut last_span, from_span)?;
last_span = from_span;
write!(f, "FROM {from}")?;
}
if !assignments.is_empty() {
write!(f, " SET {}", display_comma_separated(assignments))?;
let assign_span = assignments.first().unwrap().span();
write_span_gap_lines(f, &mut last_span, assign_span)?;
last_span.end.column += 3;
write!(
f,
"SET{}",
display_comma_separated_with_newlines(assignments, last_span)
)?;
last_span = assignments.last().unwrap().span();
}
if let Some(UpdateTableFromKind::AfterSet(from)) = from {
write!(f, " FROM {from}")?;
write_span_gap_lines(f, &mut last_span, from.span())?;
last_span = from.span();
write!(f, "FROM {from}")?;
}
if let Some(selection) = selection {
write!(f, " WHERE {selection}")?;
write_span_gap_lines(f, &mut last_span, selection.span())?;
last_span = selection.span();
write!(f, "WHERE {selection}")?;
}
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
let returning_span = returning.first().unwrap().span();
write_span_gap_lines(f, &mut last_span, returning_span)?;
last_span.end = returning_span.start;
write!(
f,
"RETURNING{}",
display_comma_separated_with_newlines(returning, last_span)
)?;
}
Ok(())
}
Expand Down Expand Up @@ -5420,6 +5554,7 @@ impl fmt::Display for GrantObjects {
pub struct Assignment {
pub target: AssignmentTarget,
pub value: Expr,
pub span: Span,
}

impl fmt::Display for Assignment {
Expand Down
4 changes: 1 addition & 3 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,7 @@ impl Spanned for DoUpdate {

impl Spanned for Assignment {
fn span(&self) -> Span {
let Assignment { target, value } = self;

target.span().union(&value.span())
self.span
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12052,10 +12052,17 @@ impl<'a> Parser<'a> {

/// Parse a `var = expr` assignment, used in an UPDATE statement
pub fn parse_assignment(&mut self) -> Result<Assignment, ParserError> {
let start = self.peek_token().span.start;
let target = self.parse_assignment_target()?;
self.expect_token(&Token::Eq)?;
let value = self.parse_expr()?;
Ok(Assignment { target, value })
self.prev_token();
let end = self.next_token().span.end;
Ok(Assignment {
target,
value,
span: Span::new(start, end),
})
}

/// Parse the left-hand side of an assignment, used in an UPDATE statement
Expand Down
38 changes: 37 additions & 1 deletion src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,26 @@ impl TestedDialects {
// Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
}

/// Parses a single SQL string into multiple statements, ensuring
/// the result is the same for all tested dialects.
pub fn parse_sql_statements_with_locations(
&self,
sql: &str,
) -> Result<Vec<Statement>, ParserError> {
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
if let Some(options) = &self.options {
tokenizer = tokenizer.with_unescape(options.unescape);
}
let tokens = tokenizer.tokenize_with_location()?;
self.new_parser(dialect)
.with_tokens_with_locations(tokens)
.parse_statements()
})
// To fail the `ensure_multiple_dialects_are_tested` test:
// Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
}

/// Ensures that `sql` parses as a single [Statement] for all tested
/// dialects.
///
Expand All @@ -152,7 +172,7 @@ impl TestedDialects {
/// 2. re-serializing the result of parsing `sql` produces the same
/// `canonical` sql string
pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
let mut statements = self.parse_sql_statements(sql).expect(sql);
let mut statements = self.parse_sql_statements_with_locations(sql).expect(sql);
assert_eq!(statements.len(), 1);

if !canonical.is_empty() && sql != canonical {
Expand All @@ -167,6 +187,17 @@ impl TestedDialects {
only_statement
}

/// Identical to `one_statement_parses_to`, but sets all locations to empty.
pub fn one_statement_parses_to_no_span(&self, sql: &str, canonical: &str) -> Statement {
let mut statements = self.parse_sql_statements(sql).expect(sql);
assert_eq!(statements.len(), 1);
let only_statement = statements.pop().unwrap();
if !canonical.is_empty() {
assert_eq!(canonical, only_statement.to_string())
}
only_statement
}

/// Ensures that `sql` parses as an [`Expr`], and that
/// re-serializing the parse result produces canonical
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
Expand All @@ -184,6 +215,11 @@ impl TestedDialects {
self.one_statement_parses_to(sql, sql)
}

/// Identical to `verified_stmt`, but sets all locations to empty.
pub fn verified_stmt_no_span(&self, sql: &str) -> Statement {
self.one_statement_parses_to_no_span(sql, sql)
}

/// Ensures that `sql` parses as a single [Query], and that
/// re-serializing the parse result produces the same `sql`
/// string (is not modified after a serialization round-trip).
Expand Down
4 changes: 3 additions & 1 deletion tests/sqlparser_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1624,16 +1624,18 @@ fn parse_merge() {
let update_action = MergeAction::Update {
assignments: vec![
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("a")])),
value: Expr::Value(number("1")),
},
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("b")])),
value: Expr::Value(number("2")),
},
],
};
match bigquery_and_generic().verified_stmt(sql) {
match bigquery_and_generic().verified_stmt_no_span(sql) {
Statement::Merge {
into,
table,
Expand Down
24 changes: 20 additions & 4 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,25 +297,30 @@ fn parse_update() {
match verified_stmt(sql) {
Statement::Update {
table,
assignments,
mut assignments,
selection,
..
} => {
assert_eq!(table.to_string(), "t".to_string());
// remove the span from the assignments before comparison
assignments.iter_mut().for_each(|a| a.span = Span::empty());
assert_eq!(
assignments,
vec![
Assignment {
target: AssignmentTarget::ColumnName(ObjectName(vec!["a".into()])),
value: Expr::Value(number("1")),
span: Span::empty(),
},
Assignment {
target: AssignmentTarget::ColumnName(ObjectName(vec!["b".into()])),
value: Expr::Value(number("2")),
span: Span::empty(),
},
Assignment {
target: AssignmentTarget::ColumnName(ObjectName(vec!["c".into()])),
value: Expr::Value(number("3")),
span: Span::empty(),
},
]
);
Expand Down Expand Up @@ -354,7 +359,7 @@ fn parse_update_set_from() {
Box::new(MsSqlDialect {}),
Box::new(SQLiteDialect {}),
]);
let stmt = dialects.verified_stmt(sql);
let stmt = dialects.verified_stmt_no_span(sql);
assert_eq!(
stmt,
Statement::Update {
Expand All @@ -363,6 +368,7 @@ fn parse_update_set_from() {
joins: vec![],
},
assignments: vec![Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])),
value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")])
}],
Expand Down Expand Up @@ -439,7 +445,7 @@ fn parse_update_set_from() {
#[test]
fn parse_update_with_table_alias() {
let sql = "UPDATE users AS u SET u.username = 'new_user' WHERE u.username = 'old_user'";
match verified_stmt(sql) {
match verified_stmt_no_span(sql) {
Statement::Update {
table,
assignments,
Expand Down Expand Up @@ -470,6 +476,7 @@ fn parse_update_with_table_alias() {
);
assert_eq!(
vec![Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![
Ident::new("u"),
Ident::new("username")
Expand Down Expand Up @@ -8529,7 +8536,10 @@ fn test_revoke() {
fn parse_merge() {
let sql = "MERGE INTO s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE";
let sql_no_into = "MERGE s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE";
match (verified_stmt(sql), verified_stmt(sql_no_into)) {
match (
verified_stmt_no_span(sql),
verified_stmt_no_span(sql_no_into),
) {
(
Statement::Merge {
into,
Expand Down Expand Up @@ -8698,6 +8708,7 @@ fn parse_merge() {
action: MergeAction::Update {
assignments: vec![
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![
Ident::new("dest"),
Ident::new("F")
Expand All @@ -8708,6 +8719,7 @@ fn parse_merge() {
]),
},
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![
Ident::new("dest"),
Ident::new("G")
Expand Down Expand Up @@ -8992,6 +9004,10 @@ fn verified_stmt(query: &str) -> Statement {
all_dialects().verified_stmt(query)
}

fn verified_stmt_no_span(query: &str) -> Statement {
all_dialects().verified_stmt_no_span(query)
}

fn verified_query(query: &str) -> Query {
all_dialects().verified_query(query)
}
Expand Down
Loading
Loading