Skip to content

Commit 0605ad7

Browse files
cloud-fan authored and gatorsmile committed
[SPARK-22543][SQL] fix java 64kb compile error for deeply nested expressions
## What changes were proposed in this pull request?

A frequently reported issue in Spark is the Java 64KB compile error. It occurs because Spark generates a very big method, and it is usually caused by one of 3 reasons:

1. a deep expression tree, e.g. a very complex filter condition
2. many individual expressions, e.g. expressions can have many children, operators can have many expressions.
3. a deep query plan tree (with whole stage codegen)

This PR focuses on 1. There are already several patches (#15620 #18972 #18641) trying to fix this issue and some of them are already merged. However, this is an endless job as every non-leaf expression has this issue.

This PR proposes to fix this issue in `Expression.genCode`, to make sure the code for a single expression won't grow too big.

According to maropu's benchmark, no regression is found with TPCDS (thanks maropu!): https://docs.google.com/spreadsheets/d/1K3_7lX05-ZgxDXi9X_GleNnDjcnJIfoSlSCDZcL4gdg/edit?usp=sharing

## How was this patch tested?

existing test

Author: Wenchen Fan <[email protected]>
Author: Wenchen Fan <[email protected]>

Closes #19767 from cloud-fan/codegen.
1 parent 327d25f commit 0605ad7

File tree

7 files changed

+62
-163
lines changed

7 files changed

+62
-163
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,48 @@ abstract class Expression extends TreeNode[Expression] {
104104
}.getOrElse {
105105
val isNull = ctx.freshName("isNull")
106106
val value = ctx.freshName("value")
107-
val ve = doGenCode(ctx, ExprCode("", isNull, value))
108-
if (ve.code.nonEmpty) {
107+
val eval = doGenCode(ctx, ExprCode("", isNull, value))
108+
reduceCodeSize(ctx, eval)
109+
if (eval.code.nonEmpty) {
109110
// Add `this` in the comment.
110-
ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim)
111+
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
111112
} else {
112-
ve
113+
eval
113114
}
114115
}
115116
}
116117

118+
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
119+
// TODO: support whole stage codegen too
120+
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
121+
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
122+
val globalIsNull = ctx.freshName("globalIsNull")
123+
ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
124+
val localIsNull = eval.isNull
125+
eval.isNull = globalIsNull
126+
s"$globalIsNull = $localIsNull;"
127+
} else {
128+
""
129+
}
130+
131+
val javaType = ctx.javaType(dataType)
132+
val newValue = ctx.freshName("value")
133+
134+
val funcName = ctx.freshName(nodeName)
135+
val funcFullName = ctx.addNewFunction(funcName,
136+
s"""
137+
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
138+
| ${eval.code.trim}
139+
| $setIsNull
140+
| return ${eval.value};
141+
|}
142+
""".stripMargin)
143+
144+
eval.value = newValue
145+
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
146+
}
147+
}
148+
117149
/**
118150
* Returns Java source code that can be compiled to evaluate this expression.
119151
* The default behavior is to call the eval method of the expression. Concrete expression

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -930,36 +930,6 @@ class CodegenContext {
930930
}
931931
}
932932

933-
/**
934-
* Wrap the generated code of expression, which was created from a row object in INPUT_ROW,
935-
* by a function. ev.isNull and ev.value are passed by global variables
936-
*
937-
* @param ev the code to evaluate expressions.
938-
* @param dataType the data type of ev.value.
939-
* @param baseFuncName the split function name base.
940-
*/
941-
def createAndAddFunction(
942-
ev: ExprCode,
943-
dataType: DataType,
944-
baseFuncName: String): (String, String, String) = {
945-
val globalIsNull = freshName("isNull")
946-
addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;")
947-
val globalValue = freshName("value")
948-
addMutableState(javaType(dataType), globalValue,
949-
s"$globalValue = ${defaultValue(dataType)};")
950-
val funcName = freshName(baseFuncName)
951-
val funcBody =
952-
s"""
953-
|private void $funcName(InternalRow ${INPUT_ROW}) {
954-
| ${ev.code.trim}
955-
| $globalIsNull = ${ev.isNull};
956-
| $globalValue = ${ev.value};
957-
|}
958-
""".stripMargin
959-
val fullFuncName = addNewFunction(funcName, funcBody)
960-
(fullFuncName, globalIsNull, globalValue)
961-
}
962-
963933
/**
964934
* Perform a function which generates a sequence of ExprCodes with a given mapping between
965935
* expressions and common expressions, instead of using the mapping in current context.
@@ -1065,7 +1035,8 @@ class CodegenContext {
10651035
* elimination will be performed. Subexpression elimination assumes that the code for each
10661036
* expression will be combined in the `expressions` order.
10671037
*/
1068-
def generateExpressions(expressions: Seq[Expression],
1038+
def generateExpressions(
1039+
expressions: Seq[Expression],
10691040
doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
10701041
if (doSubexpressionElimination) subexpressionElimination(expressions)
10711042
expressions.map(e => e.genCode(this))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 15 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -64,52 +64,22 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
6464
val trueEval = trueValue.genCode(ctx)
6565
val falseEval = falseValue.genCode(ctx)
6666

67-
// place generated code of condition, true value and false value in separate methods if
68-
// their code combined is large
69-
val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
70-
val generatedCode = if (combinedLength > 1024 &&
71-
// Split these expressions only if they are created from a row object
72-
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {
73-
74-
val (condFuncName, condGlobalIsNull, condGlobalValue) =
75-
ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr")
76-
val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
77-
ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr")
78-
val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
79-
ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr")
67+
val code =
8068
s"""
81-
$condFuncName(${ctx.INPUT_ROW});
82-
boolean ${ev.isNull} = false;
83-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
84-
if (!$condGlobalIsNull && $condGlobalValue) {
85-
$trueFuncName(${ctx.INPUT_ROW});
86-
${ev.isNull} = $trueGlobalIsNull;
87-
${ev.value} = $trueGlobalValue;
88-
} else {
89-
$falseFuncName(${ctx.INPUT_ROW});
90-
${ev.isNull} = $falseGlobalIsNull;
91-
${ev.value} = $falseGlobalValue;
92-
}
93-
"""
94-
}
95-
else {
96-
s"""
97-
${condEval.code}
98-
boolean ${ev.isNull} = false;
99-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
100-
if (!${condEval.isNull} && ${condEval.value}) {
101-
${trueEval.code}
102-
${ev.isNull} = ${trueEval.isNull};
103-
${ev.value} = ${trueEval.value};
104-
} else {
105-
${falseEval.code}
106-
${ev.isNull} = ${falseEval.isNull};
107-
${ev.value} = ${falseEval.value};
108-
}
109-
"""
110-
}
111-
112-
ev.copy(code = generatedCode)
69+
|${condEval.code}
70+
|boolean ${ev.isNull} = false;
71+
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
72+
|if (!${condEval.isNull} && ${condEval.value}) {
73+
| ${trueEval.code}
74+
| ${ev.isNull} = ${trueEval.isNull};
75+
| ${ev.value} = ${trueEval.value};
76+
|} else {
77+
| ${falseEval.code}
78+
| ${ev.isNull} = ${falseEval.isNull};
79+
| ${ev.value} = ${falseEval.value};
80+
|}
81+
""".stripMargin
82+
ev.copy(code = code)
11383
}
11484

11585
override def toString: String = s"if ($predicate) $trueValue else $falseValue"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ case class Alias(child: Expression, name: String)(
140140

141141
/** Just a simple passthrough for code generation. */
142142
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
143-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
143+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
144+
throw new IllegalStateException("Alias.doGenCode should not be called.")
145+
}
144146

145147
override def dataType: DataType = child.dataType
146148
override def nullable: Boolean = child.nullable

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 2 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -378,46 +378,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
378378
val eval2 = right.genCode(ctx)
379379

380380
// The result should be `false`, if any of them is `false` whenever the other is null or not.
381-
382-
// place generated code of eval1 and eval2 in separate methods if their code combined is large
383-
val combinedLength = eval1.code.length + eval2.code.length
384-
if (combinedLength > 1024 &&
385-
// Split these expressions only if they are created from a row object
386-
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {
387-
388-
val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
389-
ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
390-
val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
391-
ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
392-
if (!left.nullable && !right.nullable) {
393-
val generatedCode = s"""
394-
$eval1FuncName(${ctx.INPUT_ROW});
395-
boolean ${ev.value} = false;
396-
if (${eval1GlobalValue}) {
397-
$eval2FuncName(${ctx.INPUT_ROW});
398-
${ev.value} = ${eval2GlobalValue};
399-
}
400-
"""
401-
ev.copy(code = generatedCode, isNull = "false")
402-
} else {
403-
val generatedCode = s"""
404-
$eval1FuncName(${ctx.INPUT_ROW});
405-
boolean ${ev.isNull} = false;
406-
boolean ${ev.value} = false;
407-
if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) {
408-
} else {
409-
$eval2FuncName(${ctx.INPUT_ROW});
410-
if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) {
411-
} else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
412-
${ev.value} = true;
413-
} else {
414-
${ev.isNull} = true;
415-
}
416-
}
417-
"""
418-
ev.copy(code = generatedCode)
419-
}
420-
} else if (!left.nullable && !right.nullable) {
381+
if (!left.nullable && !right.nullable) {
421382
ev.copy(code = s"""
422383
${eval1.code}
423384
boolean ${ev.value} = false;
@@ -480,46 +441,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
480441
val eval2 = right.genCode(ctx)
481442

482443
// The result should be `true`, if any of them is `true` whenever the other is null or not.
483-
484-
// place generated code of eval1 and eval2 in separate methods if their code combined is large
485-
val combinedLength = eval1.code.length + eval2.code.length
486-
if (combinedLength > 1024 &&
487-
// Split these expressions only if they are created from a row object
488-
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {
489-
490-
val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
491-
ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
492-
val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
493-
ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
494-
if (!left.nullable && !right.nullable) {
495-
val generatedCode = s"""
496-
$eval1FuncName(${ctx.INPUT_ROW});
497-
boolean ${ev.value} = true;
498-
if (!${eval1GlobalValue}) {
499-
$eval2FuncName(${ctx.INPUT_ROW});
500-
${ev.value} = ${eval2GlobalValue};
501-
}
502-
"""
503-
ev.copy(code = generatedCode, isNull = "false")
504-
} else {
505-
val generatedCode = s"""
506-
$eval1FuncName(${ctx.INPUT_ROW});
507-
boolean ${ev.isNull} = false;
508-
boolean ${ev.value} = true;
509-
if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) {
510-
} else {
511-
$eval2FuncName(${ctx.INPUT_ROW});
512-
if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) {
513-
} else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
514-
${ev.value} = false;
515-
} else {
516-
${ev.isNull} = true;
517-
}
518-
}
519-
"""
520-
ev.copy(code = generatedCode)
521-
}
522-
} else if (!left.nullable && !right.nullable) {
444+
if (!left.nullable && !right.nullable) {
523445
ev.isNull = "false"
524446
ev.copy(code = s"""
525447
${eval1.code}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
9797
assert(actual(0) == cases)
9898
}
9999

100-
test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") {
100+
test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
101101
var strExpr: Expression = Literal("abc")
102102
for (_ <- 1 to 150) {
103103
strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8")
@@ -342,7 +342,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
342342
projection(row)
343343
}
344344

345-
test("SPARK-21720: split large predications into blocks due to JVM code size limit") {
345+
test("SPARK-22543: split large predicates into blocks due to JVM code size limit") {
346346
val length = 600
347347

348348
val input = new GenericInternalRow(length)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ case class HashAggregateExec(
179179
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
180180
val initAgg = ctx.freshName("initAgg")
181181
ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
182+
// The generated function doesn't have input row in the code context.
183+
ctx.INPUT_ROW = null
182184

183185
// generate variables for aggregation buffer
184186
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])

0 commit comments

Comments
 (0)