Skip to content

Commit 5533c81

Browse files
dusantism-dbMaxGekk
authored andcommitted
[SPARK-48355][SQL] Support for CASE statement
### What changes were proposed in this pull request? Add support for [case statements](https://docs.google.com/document/d/1cpSuR3KxRuTSJ4ZMQ73FJ4_-hjouNNU2zfI4vri6yhs/edit#heading=h.ofijhkunigv) to sql scripting. There are 2 types of case statement - simple and searched (EXAMPLES BELOW). Proposed changes are: - Add `caseStatement` grammar rule to SqlBaseParser.g4 - Add visit case statement methods to `AstBuilder` - Add `SearchedCaseStatement` and `SearchedCaseStatementExec` classes, to enable them to be run in sql scripts. The reason only searched case nodes are added is that, in the current implementation, a simple case is parsed into a searched case, by creating internal `EqualTo` expressions to compare the main case expression to the expressions in the when clauses. This approach is similar to the existing case **expressions**, which are parsed in the same way. The problem with this approach is that the main expression is unnecessarily evaluated N times, where N is the number of when clauses, which can be quite inefficient, for example if the expression is a complex query. Optimally, the main expression would be evaluated once, and then compared to the other expressions. I'm open to suggestions as to what the best approach to achieve this would be. Simple case compares one expression (case variable) to others, until an equal one is found. Else clause is optional. ``` BEGIN CASE 1 WHEN 1 THEN SELECT 1; WHEN 2 THEN SELECT 2; ELSE SELECT 3; END CASE; END ``` Searched case evaluates boolean expressions. Else clause is optional. ``` BEGIN CASE WHEN 1 = 1 THEN SELECT 1; WHEN 2 IN (1,2,3) THEN SELECT 2; ELSE SELECT 3; END CASE; END ``` ### Why are the changes needed? Case statements are currently not implemented in sql scripting. ### Does this PR introduce _any_ user-facing change? Yes, users will now be able to use case statements in their sql scripts. ### How was this patch tested? Tests for both simple and searched case statements are added to SqlScriptingParserSuite, SqlScriptingExecutionNodeSuite and SqlScriptingInterpreterSuite. ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#47672 from dusantism-db/sql-scripting-case-statement. Authored-by: Dušan Tišma <dusan.tisma@databricks.com> Signed-off-by: Max Gekk <max.gekk@gmail.com>
1 parent aa54ed1 commit 5533c81

File tree

8 files changed

+920
-4
lines changed

8 files changed

+920
-4
lines changed

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ compoundStatement
6464
| setStatementWithOptionalVarKeyword
6565
| beginEndCompoundBlock
6666
| ifElseStatement
67+
| caseStatement
6768
| whileStatement
6869
| repeatStatement
6970
| leaveStatement
@@ -98,6 +99,13 @@ iterateStatement
9899
: ITERATE multipartIdentifier
99100
;
100101

102+
caseStatement
103+
: CASE (WHEN conditions+=booleanExpression THEN conditionalBodies+=compoundBody)+
104+
(ELSE elseBody=compoundBody)? END CASE #searchedCaseStatement
105+
| CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN conditionalBodies+=compoundBody)+
106+
(ELSE elseBody=compoundBody)? END CASE #simpleCaseStatement
107+
;
108+
101109
singleStatement
102110
: (statement|setResetStatement) SEMICOLON* EOF
103111
;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,52 @@ class AstBuilder extends DataTypeAstBuilder
261261
WhileStatement(condition, body, Some(labelText))
262262
}
263263

264+
override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext): CaseStatement = {
265+
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
266+
SingleStatement(
267+
Project(
268+
Seq(Alias(expression(boolExpr), "condition")()),
269+
OneRowRelation()))
270+
})
271+
val conditionalBodies =
272+
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
273+
274+
if (conditions.length != conditionalBodies.length) {
275+
throw SparkException.internalError(
276+
s"Mismatched number of conditions ${conditions.length} and condition bodies" +
277+
s" ${conditionalBodies.length} in case statement")
278+
}
279+
280+
CaseStatement(
281+
conditions = conditions,
282+
conditionalBodies = conditionalBodies,
283+
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
284+
}
285+
286+
override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext): CaseStatement = {
287+
// uses EqualTo to compare the case variable(the main case expression)
288+
// to the WHEN clause expressions
289+
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
290+
SingleStatement(
291+
Project(
292+
Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()),
293+
OneRowRelation()))
294+
})
295+
val conditionalBodies =
296+
ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
297+
298+
if (conditions.length != conditionalBodies.length) {
299+
throw SparkException.internalError(
300+
s"Mismatched number of conditions ${conditions.length} and condition bodies" +
301+
s" ${conditionalBodies.length} in case statement")
302+
}
303+
304+
CaseStatement(
305+
conditions = conditions,
306+
conditionalBodies = conditionalBodies,
307+
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
308+
}
309+
264310
override def visitRepeatStatement(ctx: RepeatStatementContext): RepeatStatement = {
265311
val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
266312
val boolExpr = ctx.booleanExpression()
@@ -292,7 +338,7 @@ class AstBuilder extends DataTypeAstBuilder
292338
case c: RepeatStatementContext
293339
if Option(c.beginLabel()).isDefined &&
294340
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
295-
=> true
341+
=> true
296342
case _ => false
297343
}
298344
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,17 @@ case class LeaveStatement(label: String) extends CompoundPlanStatement
124124
* @param label Label of the loop to iterate.
125125
*/
126126
case class IterateStatement(label: String) extends CompoundPlanStatement
127+
128+
/**
129+
* Logical operator for CASE statement.
130+
* @param conditions Collection of conditions which correspond to WHEN clauses.
131+
* @param conditionalBodies Collection of bodies that have a corresponding condition,
132+
* in WHEN branches.
133+
* @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
134+
*/
135+
case class CaseStatement(
136+
conditions: Seq[SingleStatement],
137+
conditionalBodies: Seq[CompoundBody],
138+
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
139+
assert(conditions.length == conditionalBodies.length)
140+
}

0 commit comments

Comments
 (0)