Skip to content

Commit 31d057a

Browse files
committed
preserve sql formatting through a parse + display roundtrip (partial implementation)
this implements (a tiny portion of) apache#1634 pros: really useful when passing formatted queries to a real database, in order for database error message locations to match the original user's source locations cons: if we want to do it well, we need to track source locations better, and this adds a complexity to the Display imlementations
1 parent 94ea206 commit 31d057a

File tree

10 files changed

+303
-36
lines changed

10 files changed

+303
-36
lines changed

src/ast/mod.rs

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,117 @@ where
145145
DisplaySeparated { slice, sep: ", " }
146146
}
147147

148+
pub struct DisplaySeparatedWithNewlines<'a, T>
149+
where
150+
T: fmt::Display + Spanned,
151+
{
152+
slice: &'a [T],
153+
sep: &'static str,
154+
last_span: Span,
155+
}
156+
157+
impl<T> fmt::Display for DisplaySeparatedWithNewlines<'_, T>
158+
where
159+
T: fmt::Display + Spanned,
160+
{
161+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
162+
// Initialize the last span to track where we left off in our previous display logic.
163+
// We suppose we are at the start of a line, so we take the first item's starting position
164+
let mut last_span = self.last_span;
165+
if let Some(first) = self.slice.first() {
166+
let first_span = first.span();
167+
write_span_gap_lines(f, &mut last_span, first_span)?;
168+
}
169+
let mut delim = "";
170+
for t in self.slice {
171+
write!(f, "{delim}")?;
172+
last_span.end.column += u64::try_from(delim.len()).unwrap_or(1);
173+
174+
let current_span = t.span();
175+
write_span_gap(f, last_span, current_span)?;
176+
write!(f, "{t}")?;
177+
last_span = current_span;
178+
delim = self.sep;
179+
}
180+
Ok(())
181+
}
182+
}
183+
184+
/// Write newlines and spaces between two spans
185+
pub fn write_span_gap(
186+
f: &mut fmt::Formatter,
187+
mut last_span: Span,
188+
current_span: Span,
189+
) -> fmt::Result {
190+
// write all the newlines between the last item and the current item
191+
while last_span.end.line < current_span.start.line {
192+
writeln!(f)?;
193+
last_span.end.line += 1;
194+
last_span.end.column = 1;
195+
}
196+
// write spaces between the last item and the current item
197+
while last_span.end.column < current_span.start.column {
198+
write!(f, " ")?;
199+
last_span.end.column += 1;
200+
}
201+
Ok(())
202+
}
203+
204+
/// Write newlines between two spans. If the two spans are on the same line, write a single space
205+
pub fn write_span_gap_lines(
206+
f: &mut fmt::Formatter,
207+
last_span: &mut Span,
208+
current_span: Span,
209+
) -> fmt::Result {
210+
let mut needs_space = true;
211+
while last_span.end.line < current_span.start.line {
212+
writeln!(f)?;
213+
last_span.end.line += 1;
214+
last_span.end.column = 1;
215+
needs_space = false;
216+
}
217+
if needs_space {
218+
write!(f, " ")?;
219+
last_span.end.column += 1;
220+
}
221+
Ok(())
222+
}
223+
224+
pub fn display_separated_with_newlines<'a, T>(
225+
slice: &'a [T],
226+
sep: &'static str,
227+
last_span: Span,
228+
) -> DisplaySeparatedWithNewlines<'a, T>
229+
where
230+
T: fmt::Display + Spanned,
231+
{
232+
DisplaySeparatedWithNewlines {
233+
slice,
234+
sep,
235+
last_span,
236+
}
237+
}
238+
239+
pub fn display_comma_separated_with_newlines<T>(
240+
slice: &[T],
241+
last_span: Span,
242+
) -> DisplaySeparatedWithNewlines<'_, T>
243+
where
244+
T: fmt::Display + Spanned,
245+
{
246+
// if we don't have span info, just add a space between the items
247+
let sep = if slice.iter().all(|s| s.span() == Span::empty()) {
248+
", "
249+
} else {
250+
","
251+
};
252+
DisplaySeparatedWithNewlines {
253+
slice,
254+
sep,
255+
last_span,
256+
}
257+
}
258+
148259
/// An identifier, decomposed into its value or character data and the quote style.
149260
#[derive(Debug, Clone, PartialOrd, Ord)]
150261
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -3763,21 +3874,44 @@ impl fmt::Display for Statement {
37633874
if let Some(or) = or {
37643875
write!(f, "{or} ")?;
37653876
}
3877+
let mut last_span = table.span();
37663878
write!(f, "{table}")?;
37673879
if let Some(UpdateTableFromKind::BeforeSet(from)) = from {
3768-
write!(f, " FROM {from}")?;
3880+
let from_span = from.span();
3881+
write_span_gap_lines(f, &mut last_span, from_span)?;
3882+
last_span = from_span;
3883+
write!(f, "FROM {from}")?;
37693884
}
37703885
if !assignments.is_empty() {
3771-
write!(f, " SET {}", display_comma_separated(assignments))?;
3886+
let assign_span = assignments.first().unwrap().span();
3887+
write_span_gap_lines(f, &mut last_span, assign_span)?;
3888+
last_span.end.column += 3;
3889+
write!(
3890+
f,
3891+
"SET{}",
3892+
display_comma_separated_with_newlines(assignments, last_span)
3893+
)?;
3894+
last_span = assignments.last().unwrap().span();
37723895
}
37733896
if let Some(UpdateTableFromKind::AfterSet(from)) = from {
3774-
write!(f, " FROM {from}")?;
3897+
write_span_gap_lines(f, &mut last_span, from.span())?;
3898+
last_span = from.span();
3899+
write!(f, "FROM {from}")?;
37753900
}
37763901
if let Some(selection) = selection {
3777-
write!(f, " WHERE {selection}")?;
3902+
write_span_gap_lines(f, &mut last_span, selection.span())?;
3903+
last_span = selection.span();
3904+
write!(f, "WHERE {selection}")?;
37783905
}
37793906
if let Some(returning) = returning {
3780-
write!(f, " RETURNING {}", display_comma_separated(returning))?;
3907+
let returning_span = returning.first().unwrap().span();
3908+
write_span_gap_lines(f, &mut last_span, returning_span)?;
3909+
last_span.end = returning_span.start;
3910+
write!(
3911+
f,
3912+
"RETURNING{}",
3913+
display_comma_separated_with_newlines(returning, last_span)
3914+
)?;
37813915
}
37823916
Ok(())
37833917
}
@@ -5420,6 +5554,7 @@ impl fmt::Display for GrantObjects {
54205554
pub struct Assignment {
54215555
pub target: AssignmentTarget,
54225556
pub value: Expr,
5557+
pub span: Span,
54235558
}
54245559

54255560
impl fmt::Display for Assignment {

src/ast/spans.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,9 +1229,7 @@ impl Spanned for DoUpdate {
12291229

12301230
impl Spanned for Assignment {
12311231
fn span(&self) -> Span {
1232-
let Assignment { target, value } = self;
1233-
1234-
target.span().union(&value.span())
1232+
self.span
12351233
}
12361234
}
12371235

src/parser/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12052,10 +12052,17 @@ impl<'a> Parser<'a> {
1205212052

1205312053
/// Parse a `var = expr` assignment, used in an UPDATE statement
1205412054
pub fn parse_assignment(&mut self) -> Result<Assignment, ParserError> {
12055+
let start = self.peek_token().span.start;
1205512056
let target = self.parse_assignment_target()?;
1205612057
self.expect_token(&Token::Eq)?;
1205712058
let value = self.parse_expr()?;
12058-
Ok(Assignment { target, value })
12059+
self.prev_token();
12060+
let end = self.next_token().span.end;
12061+
Ok(Assignment {
12062+
target,
12063+
value,
12064+
span: Span::new(start, end),
12065+
})
1205912066
}
1206012067

1206112068
/// Parse the left-hand side of an assignment, used in an UPDATE statement

src/test_utils.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,26 @@ impl TestedDialects {
135135
// Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
136136
}
137137

138+
/// Parses a single SQL string into multiple statements, ensuring
139+
/// the result is the same for all tested dialects.
140+
pub fn parse_sql_statements_with_locations(
141+
&self,
142+
sql: &str,
143+
) -> Result<Vec<Statement>, ParserError> {
144+
self.one_of_identical_results(|dialect| {
145+
let mut tokenizer = Tokenizer::new(dialect, sql);
146+
if let Some(options) = &self.options {
147+
tokenizer = tokenizer.with_unescape(options.unescape);
148+
}
149+
let tokens = tokenizer.tokenize_with_location()?;
150+
self.new_parser(dialect)
151+
.with_tokens_with_locations(tokens)
152+
.parse_statements()
153+
})
154+
// To fail the `ensure_multiple_dialects_are_tested` test:
155+
// Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
156+
}
157+
138158
/// Ensures that `sql` parses as a single [Statement] for all tested
139159
/// dialects.
140160
///
@@ -152,7 +172,7 @@ impl TestedDialects {
152172
/// 2. re-serializing the result of parsing `sql` produces the same
153173
/// `canonical` sql string
154174
pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
155-
let mut statements = self.parse_sql_statements(sql).expect(sql);
175+
let mut statements = self.parse_sql_statements_with_locations(sql).expect(sql);
156176
assert_eq!(statements.len(), 1);
157177

158178
if !canonical.is_empty() && sql != canonical {
@@ -167,6 +187,17 @@ impl TestedDialects {
167187
only_statement
168188
}
169189

190+
/// Identical to `one_statement_parses_to`, but sets all locations to empty.
191+
pub fn one_statement_parses_to_no_span(&self, sql: &str, canonical: &str) -> Statement {
192+
let mut statements = self.parse_sql_statements(sql).expect(sql);
193+
assert_eq!(statements.len(), 1);
194+
let only_statement = statements.pop().unwrap();
195+
if !canonical.is_empty() {
196+
assert_eq!(canonical, only_statement.to_string())
197+
}
198+
only_statement
199+
}
200+
170201
/// Ensures that `sql` parses as an [`Expr`], and that
171202
/// re-serializing the parse result produces canonical
172203
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
@@ -184,6 +215,11 @@ impl TestedDialects {
184215
self.one_statement_parses_to(sql, sql)
185216
}
186217

218+
/// Identical to `verified_stmt`, but sets all locations to empty.
219+
pub fn verified_stmt_no_span(&self, sql: &str) -> Statement {
220+
self.one_statement_parses_to_no_span(sql, sql)
221+
}
222+
187223
/// Ensures that `sql` parses as a single [Query], and that
188224
/// re-serializing the parse result produces the same `sql`
189225
/// string (is not modified after a serialization round-trip).

tests/sqlparser_bigquery.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1624,16 +1624,18 @@ fn parse_merge() {
16241624
let update_action = MergeAction::Update {
16251625
assignments: vec![
16261626
Assignment {
1627+
span: Span::empty(),
16271628
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("a")])),
16281629
value: Expr::Value(number("1")),
16291630
},
16301631
Assignment {
1632+
span: Span::empty(),
16311633
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("b")])),
16321634
value: Expr::Value(number("2")),
16331635
},
16341636
],
16351637
};
1636-
match bigquery_and_generic().verified_stmt(sql) {
1638+
match bigquery_and_generic().verified_stmt_no_span(sql) {
16371639
Statement::Merge {
16381640
into,
16391641
table,

tests/sqlparser_common.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,25 +297,30 @@ fn parse_update() {
297297
match verified_stmt(sql) {
298298
Statement::Update {
299299
table,
300-
assignments,
300+
mut assignments,
301301
selection,
302302
..
303303
} => {
304304
assert_eq!(table.to_string(), "t".to_string());
305+
// remove the span from the assignments before comparison
306+
assignments.iter_mut().for_each(|a| a.span = Span::empty());
305307
assert_eq!(
306308
assignments,
307309
vec![
308310
Assignment {
309311
target: AssignmentTarget::ColumnName(ObjectName(vec!["a".into()])),
310312
value: Expr::Value(number("1")),
313+
span: Span::empty(),
311314
},
312315
Assignment {
313316
target: AssignmentTarget::ColumnName(ObjectName(vec!["b".into()])),
314317
value: Expr::Value(number("2")),
318+
span: Span::empty(),
315319
},
316320
Assignment {
317321
target: AssignmentTarget::ColumnName(ObjectName(vec!["c".into()])),
318322
value: Expr::Value(number("3")),
323+
span: Span::empty(),
319324
},
320325
]
321326
);
@@ -354,7 +359,7 @@ fn parse_update_set_from() {
354359
Box::new(MsSqlDialect {}),
355360
Box::new(SQLiteDialect {}),
356361
]);
357-
let stmt = dialects.verified_stmt(sql);
362+
let stmt = dialects.verified_stmt_no_span(sql);
358363
assert_eq!(
359364
stmt,
360365
Statement::Update {
@@ -363,6 +368,7 @@ fn parse_update_set_from() {
363368
joins: vec![],
364369
},
365370
assignments: vec![Assignment {
371+
span: Span::empty(),
366372
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])),
367373
value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")])
368374
}],
@@ -439,7 +445,7 @@ fn parse_update_set_from() {
439445
#[test]
440446
fn parse_update_with_table_alias() {
441447
let sql = "UPDATE users AS u SET u.username = 'new_user' WHERE u.username = 'old_user'";
442-
match verified_stmt(sql) {
448+
match verified_stmt_no_span(sql) {
443449
Statement::Update {
444450
table,
445451
assignments,
@@ -470,6 +476,7 @@ fn parse_update_with_table_alias() {
470476
);
471477
assert_eq!(
472478
vec![Assignment {
479+
span: Span::empty(),
473480
target: AssignmentTarget::ColumnName(ObjectName(vec![
474481
Ident::new("u"),
475482
Ident::new("username")
@@ -8529,7 +8536,10 @@ fn test_revoke() {
85298536
fn parse_merge() {
85308537
let sql = "MERGE INTO s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE";
85318538
let sql_no_into = "MERGE s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE";
8532-
match (verified_stmt(sql), verified_stmt(sql_no_into)) {
8539+
match (
8540+
verified_stmt_no_span(sql),
8541+
verified_stmt_no_span(sql_no_into),
8542+
) {
85338543
(
85348544
Statement::Merge {
85358545
into,
@@ -8698,6 +8708,7 @@ fn parse_merge() {
86988708
action: MergeAction::Update {
86998709
assignments: vec![
87008710
Assignment {
8711+
span: Span::empty(),
87018712
target: AssignmentTarget::ColumnName(ObjectName(vec![
87028713
Ident::new("dest"),
87038714
Ident::new("F")
@@ -8708,6 +8719,7 @@ fn parse_merge() {
87088719
]),
87098720
},
87108721
Assignment {
8722+
span: Span::empty(),
87118723
target: AssignmentTarget::ColumnName(ObjectName(vec![
87128724
Ident::new("dest"),
87138725
Ident::new("G")
@@ -8992,6 +9004,10 @@ fn verified_stmt(query: &str) -> Statement {
89929004
all_dialects().verified_stmt(query)
89939005
}
89949006

9007+
fn verified_stmt_no_span(query: &str) -> Statement {
9008+
all_dialects().verified_stmt_no_span(query)
9009+
}
9010+
89959011
fn verified_query(query: &str) -> Query {
89969012
all_dialects().verified_query(query)
89979013
}

0 commit comments

Comments
 (0)