Skip to content

Commit 5acb466

Browse files
authored
Parse case expr (#120)
* Add CASE expression to Parser. Adds CASE expression parsing. GitHub Issue: #119. Adds the required constructs to enable CASE expression parsing. The PR includes changes to partiql-ast and partiql-parser to handle SearchedCase and SimpleCase parsing. See the mentioned GitHub issue for more details. As a result of the changes in this PR, the conformance tests in the following file pass: https://github.com/partiql/partiql-tests/blob/main/partiql-test-data/pass/parser/primitives/case.ion
1 parent f1498da commit 5acb466

File tree

4 files changed

+157
-36
lines changed

4 files changed

+157
-36
lines changed

partiql-ast/src/ast.rs

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -222,38 +222,39 @@ pub struct Expr {
222222
/// Represents an AST Node of type T with BytePosition Location
223223
pub type AstBytePos<T> = AstNode<T, BytePosition>;
224224

225-
pub type LitAst = AstBytePos<Lit>;
226-
pub type VarRefAst = AstBytePos<VarRef>;
227-
pub type ParamAst = AstBytePos<VarRef>;
228-
pub type StructAst = AstBytePos<Struct>;
229225
pub type BagAst = AstBytePos<Bag>;
230-
pub type ListAst = AstBytePos<List>;
231-
pub type SexpAst = AstBytePos<Sexp>;
232-
pub type BinOpAst = AstBytePos<BinOp>;
233-
pub type UniOpAst = AstBytePos<UniOp>;
234-
pub type LikeAst = AstBytePos<Like>;
235226
pub type BetweenAst = AstBytePos<Between>;
236-
pub type InAst = AstBytePos<In>;
237-
pub type SimpleCaseAst = AstBytePos<SimpleCase>;
238-
pub type SearchCaseAst = AstBytePos<SearchCase>;
239-
pub type SetExprAst = AstBytePos<SetExpr>;
240-
pub type PathAst = AstBytePos<Path>;
241-
pub type CallAst = AstBytePos<Call>;
227+
pub type BinOpAst = AstBytePos<BinOp>;
242228
pub type CallAggAst = AstBytePos<CallAgg>;
243-
pub type SelectAst = AstBytePos<Select>;
244-
pub type ProjectionAst = AstBytePos<Projection>;
245-
pub type ProjectItemAst = AstBytePos<ProjectItem>;
229+
pub type CallAst = AstBytePos<Call>;
230+
pub type CaseAst = AstBytePos<Case>;
246231
pub type FromClauseAst = AstBytePos<FromClause>;
247232
pub type FromLetAst = AstBytePos<FromLet>;
233+
pub type GroupByExprAst = AstBytePos<GroupByExpr>;
234+
pub type GroupKeyAst = AstBytePos<GroupKey>;
235+
pub type InAst = AstBytePos<In>;
248236
pub type JoinAst = AstBytePos<Join>;
249237
pub type JoinSpecAst = AstBytePos<JoinSpec>;
250238
pub type LetAst = AstBytePos<Let>;
251-
pub type GroupByExprAst = AstBytePos<GroupByExpr>;
252-
pub type GroupKeyAst = AstBytePos<GroupKey>;
239+
pub type LikeAst = AstBytePos<Like>;
240+
pub type ListAst = AstBytePos<List>;
241+
pub type LitAst = AstBytePos<Lit>;
253242
pub type OrderByExprAst = AstBytePos<OrderByExpr>;
254-
pub type SortSpecAst = AstBytePos<SortSpec>;
243+
pub type ParamAst = AstBytePos<VarRef>;
244+
pub type PathAst = AstBytePos<Path>;
245+
pub type ProjectItemAst = AstBytePos<ProjectItem>;
246+
pub type ProjectionAst = AstBytePos<Projection>;
255247
pub type QueryAst = AstBytePos<Query>;
256248
pub type QuerySetAst = AstBytePos<QuerySet>;
249+
pub type SearchedCaseAst = AstBytePos<SearchedCase>;
250+
pub type SelectAst = AstBytePos<Select>;
251+
pub type SetExprAst = AstBytePos<SetExpr>;
252+
pub type SexpAst = AstBytePos<Sexp>;
253+
pub type SimpleCaseAst = AstBytePos<SimpleCase>;
254+
pub type SortSpecAst = AstBytePos<SortSpec>;
255+
pub type StructAst = AstBytePos<Struct>;
256+
pub type UniOpAst = AstBytePos<UniOp>;
257+
pub type VarRefAst = AstBytePos<VarRef>;
257258

258259
#[derive(Clone, Debug, PartialEq)]
259260
pub struct Query {
@@ -302,10 +303,7 @@ pub enum ExprKind {
302303
Like(LikeAst),
303304
Between(BetweenAst),
304305
In(InAst),
305-
/// CASE <expr> [ WHEN <expr> THEN <expr> ]... [ ELSE <expr> ] END
306-
SimpleCase(SimpleCaseAst),
307-
/// CASE [ WHEN <expr> THEN <expr> ]... [ ELSE <expr> ] END
308-
SearchedCase(SearchCaseAst),
306+
Case(CaseAst),
309307
/// Constructors
310308
Struct(StructAst),
311309
Bag(BagAst),
@@ -439,6 +437,19 @@ pub struct In {
439437
pub operands: Vec<Box<Expr>>,
440438
}
441439

440+
#[derive(Clone, Debug, PartialEq)]
441+
pub struct Case {
442+
pub kind: CaseKind,
443+
}
444+
445+
#[derive(Clone, Debug, PartialEq)]
446+
pub enum CaseKind {
447+
/// CASE <expr> [ WHEN <expr> THEN <expr> ]... [ ELSE <expr> ] END
448+
SimpleCase(SimpleCase),
449+
/// CASE [ WHEN <expr> THEN <expr> ]... [ ELSE <expr> ] END
450+
SearchedCase(SearchedCase),
451+
}
452+
442453
#[derive(Clone, Debug, PartialEq)]
443454
pub struct SimpleCase {
444455
pub expr: Box<Expr>,
@@ -447,7 +458,7 @@ pub struct SimpleCase {
447458
}
448459

449460
#[derive(Clone, Debug, PartialEq)]
450-
pub struct SearchCase {
461+
pub struct SearchedCase {
451462
pub cases: Vec<ExprPair>,
452463
pub default: Option<Box<Expr>>,
453464
}

partiql-parser/src/lexer.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,12 +510,18 @@ pub enum Token<'input> {
510510
Between,
511511
#[regex("(?i:By)")]
512512
By,
513+
#[regex("(?i:Case)")]
514+
Case,
513515
#[regex("(?i:Cross)")]
514516
Cross,
515517
#[regex("(?i:Desc)")]
516518
Desc,
517519
#[regex("(?i:Distinct)")]
518520
Distinct,
521+
#[regex("(?i:Else)")]
522+
Else,
523+
#[regex("(?i:End)")]
524+
End,
519525
#[regex("(?i:Escape)")]
520526
Escape,
521527
#[regex("(?i:Except)")]
@@ -582,6 +588,8 @@ pub enum Token<'input> {
582588
Right,
583589
#[regex("(?i:Select)")]
584590
Select,
591+
#[regex("(?i:Then)")]
592+
Then,
585593
#[regex("(?i:True)")]
586594
True,
587595
#[regex("(?i:Union)")]
@@ -594,6 +602,8 @@ pub enum Token<'input> {
594602
Value,
595603
#[regex("(?i:Values)")]
596604
Values,
605+
#[regex("(?i:When)")]
606+
When,
597607
#[regex("(?i:Where)")]
598608
Where,
599609
#[regex("(?i:With)")]
@@ -651,9 +661,12 @@ impl<'input> fmt::Display for Token<'input> {
651661
| Token::At
652662
| Token::Between
653663
| Token::By
664+
| Token::Case
654665
| Token::Cross
655666
| Token::Desc
656667
| Token::Distinct
668+
| Token::Else
669+
| Token::End
657670
| Token::Escape
658671
| Token::Except
659672
| Token::False
@@ -687,12 +700,14 @@ impl<'input> fmt::Display for Token<'input> {
687700
| Token::Preserve
688701
| Token::Right
689702
| Token::Select
703+
| Token::Then
690704
| Token::True
691705
| Token::Union
692706
| Token::Unpivot
693707
| Token::Using
694708
| Token::Value
695709
| Token::Values
710+
| Token::When
696711
| Token::Where
697712
| Token::With => {
698713
write!(f, "{}", format!("{:?}", self).to_uppercase())
@@ -720,7 +735,7 @@ mod tests {
720735
"WiTH Where Value uSiNg Unpivot UNION True Select right Preserve pivoT Outer Order Or \
721736
On Offset Nulls Null Not Natural Missing Limit Like Left Lateral Last Join \
722737
Intersect Is Inner In Having Group From Full First False Except Escape Desc \
723-
Cross By Between At As And Asc All Values";
738+
Cross By Between At As And Asc All Values Case When Then Else End";
724739
let symbols = symbols.split(' ').chain(primitives.split(' '));
725740
let keywords = keywords.split(' ');
726741

@@ -740,7 +755,8 @@ mod tests {
740755
"LIMIT", "/", "LIKE", "^", "LEFT", ".", "LATERAL", "||", "LAST", ":", "JOIN",
741756
"--", "INTERSECT", "/**/", "IS", "<ident:IDENT>", "INNER", "<atident:@IDENT>", "IN",
742757
"HAVING", "GROUP", "FROM", "FULL", "FIRST", "FALSE", "EXCEPT", "ESCAPE", "DESC",
743-
"CROSS", "BY", "BETWEEN", "AT", "AS", "AND", "ASC", "ALL", "VALUES"
758+
"CROSS", "BY", "BETWEEN", "AT", "AS", "AND", "ASC", "ALL", "VALUES", "CASE", "WHEN",
759+
"THEN", "ELSE", "END",
744760
];
745761
let displayed = toks
746762
.into_iter()

partiql-parser/src/parse/mod.rs

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,14 +382,36 @@ mod tests {
382382
"#;
383383
parse!(q);
384384
}
385-
}
386-
387-
mod values {
388-
use super::*;
389385

390386
#[test]
391-
fn values() {
392-
parse!("VALUES ('A', `5e0`), ('B', 3.0), ('X', 9.0)");
387+
fn select_with_case() {
388+
parse!(r#"SELECT a WHERE CASE WHEN x <> 0 THEN y/x > 1.5 ELSE false END"#);
389+
parse!(
390+
r#"SELECT a,
391+
CASE WHEN a=1 THEN 'one'
392+
WHEN a=2 THEN 'two'
393+
ELSE 'other'
394+
END
395+
FROM test"#
396+
);
397+
398+
parse!(
399+
r#"SELECT VALUE
400+
{
401+
'locationType': R.LocationType,
402+
'Location': (
403+
CASE WHEN id IS NOT NULL THEN
404+
(SELECT VALUE (CASE WHEN R.LocationType = 'z' THEN n ELSE d END)
405+
FROM R.Scope AS scope WHERE scope.name = id)
406+
ELSE
407+
(SELECT VALUE (CASE WHEN R.LocationType = 'z' THEN n ELSE d END)
408+
FROM R.Scope AS scope WHERE scope.name = someZone)
409+
END
410+
),
411+
'marketType' : MarketInfo.marketType,
412+
}
413+
FROM UNPIVOT R.returnValueMap.success AS "list" AT symb"#
414+
);
393415
}
394416
}
395417

@@ -430,6 +452,23 @@ mod tests {
430452
}
431453
}
432454

455+
mod case_expr {
456+
use super::*;
457+
458+
#[test]
459+
fn searched_case() {
460+
parse!(r#"CASE WHEN TRUE THEN 2 END"#);
461+
parse!(r#"CASE WHEN id IS 1 THEN 2 WHEN titanId IS 2 THEN 3 ELSE 1 END"#);
462+
parse!(r#"CASE hello WHEN id IS NOT NULL THEN (SELECT * FROM data) ELSE 1 END"#);
463+
}
464+
465+
#[test]
466+
#[should_panic]
467+
fn searched_case_failure() {
468+
parse!(r#"CASE hello WHEN id IS NOT NULL THEN SELECT * FROM data ELSE 1 END"#);
469+
}
470+
}
471+
433472
mod errors {
434473
use super::*;
435474
use crate::result::{LexicalError, UnexpectedToken, UnexpectedTokenData};

partiql-parser/src/parse/partiql.lalrpop

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ Projection: ast::ProjectItemAst = {
210210
FromClause: ast::FromClauseAst = {
211211
<lo:@L> "FROM" <mut froms:(<TableReference> "," "LATERAL"?)*> <last:TableReference> <hi:@R> => {
212212
let total: Location<BytePosition> = Location::from(lo.into()..hi.into());
213-
213+
214214
// use `reduce` to process the comma-separated `TableReference`s
215215
// as left-associative `CROSS JOIN`s
216216
froms.push(last);
@@ -739,7 +739,57 @@ ExprPrecedence02: ast::Expr = {
739739
}
740740

741741
#[inline]
742-
ExprPrecedence01: ast::Expr = {<ExprTerm>,}
742+
ExprPrecedence01: ast::Expr = {
743+
<casexpr:CaseExpr> => ast::Expr {
744+
kind: ast::ExprKind::Case(casexpr)
745+
},
746+
<ExprTerm>,
747+
};
748+
749+
// ------------------------------------------------------------------------------ //
750+
// //
751+
// Case Expression //
752+
// //
753+
// ------------------------------------------------------------------------------ //
754+
// Implements parsing for CASE expressions.
755+
//
756+
// Searched Case Example:
757+
// CASE WHEN titanId IS 1 THEN 2 WHEN titanId IS 2 THEN 3 ELSE 1 END
758+
//
759+
// Simple Case Example:
760+
// CASE hello WHEN titanId IS NOT NULL THEN (SELECT * FROM data) ELSE 1 END
761+
//
762+
// The following is not allowed:
763+
// CASE hello WHEN titanId IS NOT NULL THEN SELECT * FROM data ELSE 1 END
764+
//
765+
// This is because, as per the SQL-92 standard, THEN <result> ultimately leads to
766+
// the following:
767+
// <subquery> ::= <left paren> <query expression> <right paren>
768+
769+
CaseExpr: ast::CaseAst = {
770+
<lo:@L> "CASE" <expr:ExprQuery?> <cases:ExprPairWhenThen+> <elsexpr:ElseClause?> "END" <hi:@R> => {
771+
match expr {
772+
None => ast::Case {
773+
kind: ast::CaseKind::SearchedCase(
774+
ast::SearchedCase { cases, default: elsexpr }
775+
)
776+
}.ast(lo..hi),
777+
Some(expr) => ast::Case {
778+
kind: ast::CaseKind::SimpleCase(
779+
ast::SimpleCase { expr, cases, default: elsexpr }
780+
)
781+
}.ast(lo..hi)
782+
}
783+
}
784+
};
785+
786+
ElseClause: Box<ast::Expr> = {
787+
"ELSE" <e:ExprQuery> => Box::new(*e)
788+
};
789+
790+
ExprPairWhenThen: ast::ExprPair = {
791+
<lo:@L> "WHEN" <first:ExprQuery> "THEN" <second:ExprQuery> <hi:@R> => ast::ExprPair { first, second },
792+
};
743793

744794
pub ExprTerm: ast::Expr = {
745795
"(" <q:Query> ")" => *q,
@@ -1027,9 +1077,12 @@ extern {
10271077
"AT" => lexer::Token::At,
10281078
"BETWEEN" => lexer::Token::Between,
10291079
"BY" => lexer::Token::By,
1080+
"CASE" => lexer::Token::Case,
10301081
"CROSS" => lexer::Token::Cross,
10311082
"DESC" => lexer::Token::Desc,
10321083
"DISTINCT" => lexer::Token::Distinct,
1084+
"ELSE" => lexer::Token::Else,
1085+
"END" => lexer::Token::End,
10331086
"ESCAPE" => lexer::Token::Escape,
10341087
"EXCEPT" => lexer::Token::Except,
10351088
"FALSE" => lexer::Token::False,
@@ -1063,12 +1116,14 @@ extern {
10631116
"PRESERVE" => lexer::Token::Preserve,
10641117
"RIGHT" => lexer::Token::Right,
10651118
"SELECT" => lexer::Token::Select,
1119+
"THEN" => lexer::Token::Then,
10661120
"TRUE" => lexer::Token::True,
10671121
"UNION" => lexer::Token::Union,
10681122
"UNPIVOT" => lexer::Token::Unpivot,
10691123
"USING" => lexer::Token::Using,
10701124
"VALUE" => lexer::Token::Value,
10711125
"VALUES" => lexer::Token::Values,
1126+
"WHEN" => lexer::Token::When,
10721127
"WHERE" => lexer::Token::Where,
10731128
"WITH" => lexer::Token::With,
10741129
}

0 commit comments

Comments
 (0)