Skip to content

Commit 087879a

Browse files
mgaido91 authored and cloud-fan committed
[SPARK-22520][SQL] Support code generation for large CaseWhen
## What changes were proposed in this pull request?

Code generation is disabled for CaseWhen when the number of branches is higher than `spark.sql.codegen.maxCaseBranches` (which defaults to 20). This was done to prevent the well-known 64KB method limit exception. This PR proposes to support code generation in those cases as well (without causing exceptions, of course). As a side effect, we can get rid of the `spark.sql.codegen.maxCaseBranches` configuration.

## How was this patch tested?

existing UTs

Author: Marco Gaido <[email protected]>
Author: Marco Gaido <[email protected]>

Closes #19752 from mgaido91/SPARK-22520.
1 parent 1ff4a77 commit 087879a

File tree

10 files changed

+122
-253
lines changed

10 files changed

+122
-253
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ class EquivalentExpressions {
8787
def childrenToRecurse: Seq[Expression] = expr match {
8888
case _: CodegenFallback => Nil
8989
case i: If => i.predicate :: Nil
90-
// `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
91-
case c: CaseWhenCodegen => c.children.head :: Nil
90+
case c: CaseWhen => c.children.head :: Nil
9291
case c: Coalesce => c.children.head :: Nil
9392
case other => other.children
9493
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 102 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
8888
}
8989

9090
/**
91-
* Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
91+
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
92+
* When a = true, returns b; when c = true, returns d; else returns e.
9293
*
9394
* @param branches seq of (branch condition, branch value)
9495
* @param elseValue optional value for the else branch
9596
*/
96-
abstract class CaseWhenBase(
97+
// scalastyle:off line.size.limit
98+
@ExpressionDescription(
99+
usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
100+
arguments = """
101+
Arguments:
102+
* expr1, expr3 - the branch condition expressions should all be boolean type.
103+
* expr2, expr4, expr5 - the branch value expressions and else value expression should all be
104+
same type or coercible to a common type.
105+
""",
106+
examples = """
107+
Examples:
108+
> SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
109+
1
110+
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
111+
2
112+
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END;
113+
NULL
114+
""")
115+
// scalastyle:on line.size.limit
116+
case class CaseWhen(
97117
branches: Seq[(Expression, Expression)],
98-
elseValue: Option[Expression])
118+
elseValue: Option[Expression] = None)
99119
extends Expression with Serializable {
100120

101121
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -158,111 +178,103 @@ abstract class CaseWhenBase(
158178
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
159179
"CASE" + cases + elseCase + " END"
160180
}
161-
}
162-
163-
164-
/**
165-
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
166-
* When a = true, returns b; when c = true, returns d; else returns e.
167-
*
168-
* @param branches seq of (branch condition, branch value)
169-
* @param elseValue optional value for the else branch
170-
*/
171-
// scalastyle:off line.size.limit
172-
@ExpressionDescription(
173-
usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
174-
arguments = """
175-
Arguments:
176-
* expr1, expr3 - the branch condition expressions should all be boolean type.
177-
* expr2, expr4, expr5 - the branch value expressions and else value expression should all be
178-
same type or coercible to a common type.
179-
""",
180-
examples = """
181-
Examples:
182-
> SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
183-
1
184-
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
185-
2
186-
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
187-
NULL
188-
""")
189-
// scalastyle:on line.size.limit
190-
case class CaseWhen(
191-
val branches: Seq[(Expression, Expression)],
192-
val elseValue: Option[Expression] = None)
193-
extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
194-
195-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
196-
super[CodegenFallback].doGenCode(ctx, ev)
197-
}
198-
199-
def toCodegen(): CaseWhenCodegen = {
200-
CaseWhenCodegen(branches, elseValue)
201-
}
202-
}
203-
204-
/**
205-
* CaseWhen expression used when code generation condition is satisfied.
206-
* OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
207-
*
208-
* @param branches seq of (branch condition, branch value)
209-
* @param elseValue optional value for the else branch
210-
*/
211-
case class CaseWhenCodegen(
212-
val branches: Seq[(Expression, Expression)],
213-
val elseValue: Option[Expression] = None)
214-
extends CaseWhenBase(branches, elseValue) with Serializable {
215181

216182
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
217-
// Generate code that looks like:
218-
//
219-
// condA = ...
220-
// if (condA) {
221-
// valueA
222-
// } else {
223-
// condB = ...
224-
// if (condB) {
225-
// valueB
226-
// } else {
227-
// condC = ...
228-
// if (condC) {
229-
// valueC
230-
// } else {
231-
// elseValue
232-
// }
233-
// }
234-
// }
183+
// This variable records whether a branch condition has already evaluated to `true`.
184+
// It is initialized to `false` and is set to `true` by the first branch whose condition
185+
// evaluates to `true`; once it is set, the remaining conditions do not need to be
186+
// evaluated.
187+
val conditionMet = ctx.freshName("caseWhenConditionMet")
188+
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
189+
ctx.addMutableState(ctx.javaType(dataType), ev.value)
190+
191+
// these blocks are meant to be inside a
192+
// do {
193+
// ...
194+
// } while (false);
195+
// loop
235196
val cases = branches.map { case (condExpr, valueExpr) =>
236197
val cond = condExpr.genCode(ctx)
237198
val res = valueExpr.genCode(ctx)
238199
s"""
239-
${cond.code}
240-
if (!${cond.isNull} && ${cond.value}) {
241-
${res.code}
242-
${ev.isNull} = ${res.isNull};
243-
${ev.value} = ${res.value};
200+
if(!$conditionMet) {
201+
${cond.code}
202+
if (!${cond.isNull} && ${cond.value}) {
203+
${res.code}
204+
${ev.isNull} = ${res.isNull};
205+
${ev.value} = ${res.value};
206+
$conditionMet = true;
207+
continue;
208+
}
244209
}
245210
"""
246211
}
247212

248-
var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
249-
250-
elseValue.foreach { elseExpr =>
213+
val elseCode = elseValue.map { elseExpr =>
251214
val res = elseExpr.genCode(ctx)
252-
generatedCode +=
253-
s"""
215+
s"""
216+
if(!$conditionMet) {
254217
${res.code}
255218
${ev.isNull} = ${res.isNull};
256219
${ev.value} = ${res.value};
257-
"""
220+
}
221+
"""
258222
}
259223

260-
generatedCode += "}\n" * cases.size
224+
val allConditions = cases ++ elseCode
225+
226+
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
227+
allConditions.mkString("\n")
228+
} else {
229+
// This generates code like:
230+
// conditionMet = caseWhen_1(i);
231+
// if(conditionMet) {
232+
// continue;
233+
// }
234+
// conditionMet = caseWhen_2(i);
235+
// if(conditionMet) {
236+
// continue;
237+
// }
238+
// ...
239+
// and the declared methods are:
240+
// private boolean caseWhen_1(InternalRow i) {
241+
// boolean conditionMet = false;
242+
// do {
243+
// // here the evaluation of the conditions
244+
// } while (false);
245+
// return conditionMet;
246+
// }
247+
ctx.splitExpressions(allConditions, "caseWhen",
248+
("InternalRow", ctx.INPUT_ROW) :: Nil,
249+
returnType = ctx.JAVA_BOOLEAN,
250+
makeSplitFunction = {
251+
func =>
252+
s"""
253+
${ctx.JAVA_BOOLEAN} $conditionMet = false;
254+
do {
255+
$func
256+
} while (false);
257+
return $conditionMet;
258+
"""
259+
},
260+
foldFunctions = { funcCalls =>
261+
funcCalls.map { funcCall =>
262+
s"""
263+
$conditionMet = $funcCall;
264+
if ($conditionMet) {
265+
continue;
266+
}"""
267+
}.mkString
268+
})
269+
}
261270

262271
ev.copy(code = s"""
263-
boolean ${ev.isNull} = true;
264-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
265-
$generatedCode""")
272+
${ev.isNull} = true;
273+
${ev.value} = ${ctx.defaultValue(dataType)};
274+
${ctx.JAVA_BOOLEAN} $conditionMet = false;
275+
do {
276+
$code
277+
} while (false);""")
266278
}
267279
}
268280

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
138138
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
139139
Batch("Check Cartesian Products", Once,
140140
CheckCartesianProducts) ::
141-
Batch("OptimizeCodegen", Once,
142-
OptimizeCodegen) ::
143141
Batch("RewriteSubquery", Once,
144142
RewritePredicateSubquery,
145143
CollapseProject) :: Nil

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -552,21 +552,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
552552
}
553553

554554

555-
/**
556-
* Optimizes expressions by replacing according to CodeGen configuration.
557-
*/
558-
object OptimizeCodegen extends Rule[LogicalPlan] {
559-
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
560-
case e: CaseWhen if canCodegen(e) => e.toCodegen()
561-
}
562-
563-
private def canCodegen(e: CaseWhen): Boolean = {
564-
val numBranches = e.branches.size + e.elseValue.size
565-
numBranches <= SQLConf.get.maxCaseBranchesForCodegen
566-
}
567-
}
568-
569-
570555
/**
571556
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
572557
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -599,12 +599,6 @@ object SQLConf {
599599
.booleanConf
600600
.createWithDefault(true)
601601

602-
val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches")
603-
.internal()
604-
.doc("The maximum number of switches supported with codegen.")
605-
.intConf
606-
.createWithDefault(20)
607-
608602
val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines")
609603
.internal()
610604
.doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.")
@@ -1140,8 +1134,6 @@ class SQLConf extends Serializable with Logging {
11401134

11411135
def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
11421136

1143-
def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
1144-
11451137
def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
11461138

11471139
def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
7777
}
7878

7979
test("SPARK-13242: case-when expression with large number of branches (or cases)") {
80-
val cases = 50
80+
val cases = 500
8181
val clauses = 20
8282

8383
// Generate an individual case
@@ -88,13 +88,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
8888
(condition, Literal(n))
8989
}
9090

91-
val expression = CaseWhen((1 to cases).map(generateCase(_)))
91+
val expression = CaseWhen((1 to cases).map(generateCase))
9292

9393
val plan = GenerateMutableProjection.generate(Seq(expression))
94-
val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
94+
val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"$clauses:$cases")))
9595
val actual = plan(input).toSeq(Seq(expression.dataType))
9696

97-
assert(actual(0) == cases)
97+
assert(actual.head == cases)
9898
}
9999

100100
test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {

0 commit comments

Comments
 (0)