Skip to content

Commit d2b2866

Browse files
committed
Replace parallel condition/result vectors with single CaseWhen vector in Expr::Case
The primary motivation for this change is to fix the visitor traversal order for CASE expressions. In SQL, CASE expressions follow a specific syntactic order (e.g., `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END`), and AST visitors should process nodes in the same order as they appear in the source code. The previous implementation, using separate `conditions` and `results` vectors, would visit all conditions first and then all results (`a,1,3,2,4,5` for the example above), which didn't match the source order. The new `CaseWhen` structure ensures visitors process expressions in the correct order: `a,1,2,3,4,5`. A secondary benefit is making invalid states unrepresentable in the type system. The previous implementation using parallel vectors (`conditions` and `results`) made it possible to create invalid CASE expressions where the number of conditions didn't match the number of results. When this happened, the `Display` implementation would silently drop elements from the longer list, potentially masking bugs. The new `CaseWhen` struct couples each condition with its result, making it impossible to create such mismatched states. While this is a breaking change to the AST structure, sqlparser has a history of making such changes when they improve correctness. I don't expect significant downstream breakage, and the benefits of correct visitor ordering and type safety are substantial, so I think the trade-off is worthwhile.
1 parent b4b5576 commit d2b2866

File tree

5 files changed

+96
-51
lines changed

5 files changed

+96
-51
lines changed

src/ast/mod.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,22 @@ pub enum CeilFloorKind {
566566
Scale(Value),
567567
}
568568

569+
/// A WHEN clause in a CASE expression containing both
570+
/// the condition and its corresponding result
571+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
572+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
573+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
574+
pub struct CaseWhen {
575+
pub condition: Expr,
576+
pub result: Expr,
577+
}
578+
579+
impl fmt::Display for CaseWhen {
580+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
581+
write!(f, "WHEN {} THEN {}", self.condition, self.result)
582+
}
583+
}
584+
569585
/// An SQL expression of any type.
570586
///
571587
/// # Semantics / Type Checking
@@ -896,8 +912,7 @@ pub enum Expr {
896912
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
897913
Case {
898914
operand: Option<Box<Expr>>,
899-
conditions: Vec<Expr>,
900-
results: Vec<Expr>,
915+
conditions: Vec<CaseWhen>,
901916
else_result: Option<Box<Expr>>,
902917
},
903918
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
@@ -1572,17 +1587,15 @@ impl fmt::Display for Expr {
15721587
Expr::Case {
15731588
operand,
15741589
conditions,
1575-
results,
15761590
else_result,
15771591
} => {
15781592
write!(f, "CASE")?;
15791593
if let Some(operand) = operand {
15801594
write!(f, " {operand}")?;
15811595
}
1582-
for (c, r) in conditions.iter().zip(results) {
1583-
write!(f, " WHEN {c} THEN {r}")?;
1596+
for when in conditions {
1597+
write!(f, " {when}")?;
15841598
}
1585-
15861599
if let Some(else_result) = else_result {
15871600
write!(f, " ELSE {else_result}")?;
15881601
}

src/ast/spans.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,15 +1428,15 @@ impl Spanned for Expr {
14281428
Expr::Case {
14291429
operand,
14301430
conditions,
1431-
results,
14321431
else_result,
14331432
} => union_spans(
14341433
operand
14351434
.as_ref()
14361435
.map(|i| i.span())
14371436
.into_iter()
1438-
.chain(conditions.iter().map(|i| i.span()))
1439-
.chain(results.iter().map(|i| i.span()))
1437+
.chain(conditions.iter().flat_map(|case_when| {
1438+
[case_when.condition.span(), case_when.result.span()]
1439+
}))
14401440
.chain(else_result.as_ref().map(|i| i.span())),
14411441
),
14421442
Expr::Exists { subquery, .. } => subquery.span(),

src/parser/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,11 +1989,11 @@ impl<'a> Parser<'a> {
19891989
self.expect_keyword_is(Keyword::WHEN)?;
19901990
}
19911991
let mut conditions = vec![];
1992-
let mut results = vec![];
19931992
loop {
1994-
conditions.push(self.parse_expr()?);
1993+
let condition = self.parse_expr()?;
19951994
self.expect_keyword_is(Keyword::THEN)?;
1996-
results.push(self.parse_expr()?);
1995+
let result = self.parse_expr()?;
1996+
conditions.push(CaseWhen { condition, result });
19971997
if !self.parse_keyword(Keyword::WHEN) {
19981998
break;
19991999
}
@@ -2007,7 +2007,6 @@ impl<'a> Parser<'a> {
20072007
Ok(Expr::Case {
20082008
operand,
20092009
conditions,
2010-
results,
20112010
else_result,
20122011
})
20132012
}

tests/sqlparser_common.rs

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6257,22 +6257,26 @@ fn parse_searched_case_expr() {
62576257
&Case {
62586258
operand: None,
62596259
conditions: vec![
6260-
IsNull(Box::new(Identifier(Ident::new("bar")))),
6261-
BinaryOp {
6262-
left: Box::new(Identifier(Ident::new("bar"))),
6263-
op: Eq,
6264-
right: Box::new(Expr::Value(number("0"))),
6260+
CaseWhen {
6261+
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
6262+
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
62656263
},
6266-
BinaryOp {
6267-
left: Box::new(Identifier(Ident::new("bar"))),
6268-
op: GtEq,
6269-
right: Box::new(Expr::Value(number("0"))),
6264+
CaseWhen {
6265+
condition: BinaryOp {
6266+
left: Box::new(Identifier(Ident::new("bar"))),
6267+
op: Eq,
6268+
right: Box::new(Expr::Value(number("0"))),
6269+
},
6270+
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
6271+
},
6272+
CaseWhen {
6273+
condition: BinaryOp {
6274+
left: Box::new(Identifier(Ident::new("bar"))),
6275+
op: GtEq,
6276+
right: Box::new(Expr::Value(number("0"))),
6277+
},
6278+
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
62706279
},
6271-
],
6272-
results: vec![
6273-
Expr::Value(Value::SingleQuotedString("null".to_string())),
6274-
Expr::Value(Value::SingleQuotedString("=0".to_string())),
6275-
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
62766280
],
62776281
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
62786282
"<0".to_string()
@@ -6291,8 +6295,10 @@ fn parse_simple_case_expr() {
62916295
assert_eq!(
62926296
&Case {
62936297
operand: Some(Box::new(Identifier(Ident::new("foo")))),
6294-
conditions: vec![Expr::Value(number("1"))],
6295-
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
6298+
conditions: vec![CaseWhen {
6299+
condition: Expr::Value(number("1")),
6300+
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
6301+
}],
62966302
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
62976303
"N".to_string()
62986304
)))),
@@ -12992,3 +12998,28 @@ fn test_trailing_commas_in_from() {
1299212998
"SELECT 1, 2 FROM (SELECT * FROM t1), (SELECT * FROM t2)",
1299312999
);
1299413000
}
13001+
13002+
#[test]
13003+
#[cfg(feature = "visitor")]
13004+
fn test_visit_order() {
13005+
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
13006+
let stmt = verified_stmt(sql);
13007+
let mut visited = vec![];
13008+
sqlparser::ast::visit_expressions(&stmt, |expr| {
13009+
visited.push(expr.to_string());
13010+
core::ops::ControlFlow::<()>::Continue(())
13011+
});
13012+
13013+
assert_eq!(
13014+
visited,
13015+
[
13016+
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
13017+
"a",
13018+
"1",
13019+
"2",
13020+
"3",
13021+
"4",
13022+
"5"
13023+
]
13024+
);
13025+
}

tests/sqlparser_databricks.rs

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -108,29 +108,31 @@ fn test_databricks_lambdas() {
108108
body: Box::new(Expr::Case {
109109
operand: None,
110110
conditions: vec![
111-
Expr::BinaryOp {
112-
left: Box::new(Expr::Identifier(Ident::new("p1"))),
113-
op: BinaryOperator::Eq,
114-
right: Box::new(Expr::Identifier(Ident::new("p2")))
111+
CaseWhen {
112+
condition: Expr::BinaryOp {
113+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
114+
op: BinaryOperator::Eq,
115+
right: Box::new(Expr::Identifier(Ident::new("p2")))
116+
},
117+
result: Expr::Value(number("0"))
118+
},
119+
CaseWhen {
120+
condition: Expr::BinaryOp {
121+
left: Box::new(call(
122+
"reverse",
123+
[Expr::Identifier(Ident::new("p1"))]
124+
)),
125+
op: BinaryOperator::Lt,
126+
right: Box::new(call(
127+
"reverse",
128+
[Expr::Identifier(Ident::new("p2"))]
129+
)),
130+
},
131+
result: Expr::UnaryOp {
132+
op: UnaryOperator::Minus,
133+
expr: Box::new(Expr::Value(number("1")))
134+
}
115135
},
116-
Expr::BinaryOp {
117-
left: Box::new(call(
118-
"reverse",
119-
[Expr::Identifier(Ident::new("p1"))]
120-
)),
121-
op: BinaryOperator::Lt,
122-
right: Box::new(call(
123-
"reverse",
124-
[Expr::Identifier(Ident::new("p2"))]
125-
))
126-
}
127-
],
128-
results: vec![
129-
Expr::Value(number("0")),
130-
Expr::UnaryOp {
131-
op: UnaryOperator::Minus,
132-
expr: Box::new(Expr::Value(number("1")))
133-
}
134136
],
135137
else_result: Some(Box::new(Expr::Value(number("1"))))
136138
})

0 commit comments

Comments
 (0)