Skip to content

Commit ea2fbf4

Browse files
kiszk authored and cloud-fan committed
[SPARK-22705][SQL] Case, Coalesce, and In use less global variables
## What changes were proposed in this pull request?

This PR accomplishes the following two items.

1. Reduce the number of global variables from two to one in the generated code of `Case` and `Coalesce`, and remove global variables from the generated code of `In`.
2. Make the lifetime of the global variable local within an operation.

Item 1 reduces the number of constant pool entries in a Java class. Item 2 ensures that a variable is not passed as an argument to a method split by `CodegenContext.splitExpressions()`, which is addressed by #19865.

## How was this patch tested?

Added new tests to `PredicateSuite`, `NullExpressionsSuite`, and `ConditionalExpressionSuite`.

Author: Kazuaki Ishizaki <[email protected]>

Closes #19901 from kiszk/SPARK-22705.
1 parent e103adf commit ea2fbf4

File tree

6 files changed

+91
-49
lines changed

6 files changed

+91
-49
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,18 @@ case class CaseWhen(
180180
}
181181

182182
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
183-
// This variable represents whether the first successful condition is met or not.
184-
// It is initialized to `false` and it is set to `true` when the first condition which
185-
// evaluates to `true` is met and therefore is not needed to go on anymore on the computation
186-
// of the following conditions.
187-
val conditionMet = ctx.freshName("caseWhenConditionMet")
188-
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
189-
ctx.addMutableState(ctx.javaType(dataType), ev.value)
183+
// This variable holds the state of the result:
184+
// -1 means the condition is not met yet and the result is unknown.
185+
val NOT_MATCHED = -1
186+
// 0 means the condition is met and result is not null.
187+
val HAS_NONNULL = 0
188+
// 1 means the condition is met and result is null.
189+
val HAS_NULL = 1
190+
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
191+
// We won't go on anymore on the computation.
192+
val resultState = ctx.freshName("caseWhenResultState")
193+
val tmpResult = ctx.freshName("caseWhenTmpResult")
194+
ctx.addMutableState(ctx.javaType(dataType), tmpResult)
190195

191196
// these blocks are meant to be inside a
192197
// do {
@@ -200,9 +205,8 @@ case class CaseWhen(
200205
|${cond.code}
201206
|if (!${cond.isNull} && ${cond.value}) {
202207
| ${res.code}
203-
| ${ev.isNull} = ${res.isNull};
204-
| ${ev.value} = ${res.value};
205-
| $conditionMet = true;
208+
| $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
209+
| $tmpResult = ${res.value};
206210
| continue;
207211
|}
208212
""".stripMargin
@@ -212,59 +216,63 @@ case class CaseWhen(
212216
val res = elseExpr.genCode(ctx)
213217
s"""
214218
|${res.code}
215-
|${ev.isNull} = ${res.isNull};
216-
|${ev.value} = ${res.value};
219+
|$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
220+
|$tmpResult = ${res.value};
217221
""".stripMargin
218222
}
219223

220224
val allConditions = cases ++ elseCode
221225

222226
// This generates code like:
223-
// conditionMet = caseWhen_1(i);
224-
// if(conditionMet) {
227+
// caseWhenResultState = caseWhen_1(i);
228+
// if(caseWhenResultState != -1) {
225229
// continue;
226230
// }
227-
// conditionMet = caseWhen_2(i);
228-
// if(conditionMet) {
231+
// caseWhenResultState = caseWhen_2(i);
232+
// if(caseWhenResultState != -1) {
229233
// continue;
230234
// }
231235
// ...
232236
// and the declared methods are:
233-
// private boolean caseWhen_1234() {
234-
// boolean conditionMet = false;
237+
// private byte caseWhen_1234() {
238+
// byte caseWhenResultState = -1;
235239
// do {
236240
// // here the evaluation of the conditions
237241
// } while (false);
238-
// return conditionMet;
242+
// return caseWhenResultState;
239243
// }
240244
val codes = ctx.splitExpressionsWithCurrentInputs(
241245
expressions = allConditions,
242246
funcName = "caseWhen",
243-
returnType = ctx.JAVA_BOOLEAN,
247+
returnType = ctx.JAVA_BYTE,
244248
makeSplitFunction = func =>
245249
s"""
246-
|${ctx.JAVA_BOOLEAN} $conditionMet = false;
250+
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
247251
|do {
248252
| $func
249253
|} while (false);
250-
|return $conditionMet;
254+
|return $resultState;
251255
""".stripMargin,
252256
foldFunctions = _.map { funcCall =>
253257
s"""
254-
|$conditionMet = $funcCall;
255-
|if ($conditionMet) {
258+
|$resultState = $funcCall;
259+
|if ($resultState != $NOT_MATCHED) {
256260
| continue;
257261
|}
258262
""".stripMargin
259263
}.mkString)
260264

261-
ev.copy(code = s"""
262-
${ev.isNull} = true;
263-
${ev.value} = ${ctx.defaultValue(dataType)};
264-
${ctx.JAVA_BOOLEAN} $conditionMet = false;
265-
do {
266-
$codes
267-
} while (false);""")
265+
ev.copy(code =
266+
s"""
267+
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
268+
|$tmpResult = ${ctx.defaultValue(dataType)};
269+
|do {
270+
| $codes
271+
|} while (false);
272+
|// TRUE if any condition is met and the result is null, or no any condition is met.
273+
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
274+
|final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
275+
""".stripMargin)
268276
}
269277
}
270278

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,35 +72,39 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
7272
}
7373

7474
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
75-
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
76-
ctx.addMutableState(ctx.javaType(dataType), ev.value)
75+
val tmpIsNull = ctx.freshName("coalesceTmpIsNull")
76+
ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
7777

7878
// all the evals are meant to be in a do { ... } while (false); loop
7979
val evals = children.map { e =>
8080
val eval = e.genCode(ctx)
8181
s"""
8282
|${eval.code}
8383
|if (!${eval.isNull}) {
84-
| ${ev.isNull} = false;
84+
| $tmpIsNull = false;
8585
| ${ev.value} = ${eval.value};
8686
| continue;
8787
|}
8888
""".stripMargin
8989
}
9090

91+
val resultType = ctx.javaType(dataType)
9192
val codes = ctx.splitExpressionsWithCurrentInputs(
9293
expressions = evals,
9394
funcName = "coalesce",
95+
returnType = resultType,
9496
makeSplitFunction = func =>
9597
s"""
98+
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
9699
|do {
97100
| $func
98101
|} while (false);
102+
|return ${ev.value};
99103
""".stripMargin,
100104
foldFunctions = _.map { funcCall =>
101105
s"""
102-
|$funcCall;
103-
|if (!${ev.isNull}) {
106+
|${ev.value} = $funcCall;
107+
|if (!$tmpIsNull) {
104108
| continue;
105109
|}
106110
""".stripMargin
@@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
109113

110114
ev.copy(code =
111115
s"""
112-
|${ev.isNull} = true;
113-
|${ev.value} = ${ctx.defaultValue(dataType)};
116+
|$tmpIsNull = true;
117+
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
114118
|do {
115119
| $codes
116120
|} while (false);
121+
|final boolean ${ev.isNull} = $tmpIsNull;
117122
""".stripMargin)
118123
}
119124
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,37 +237,44 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
237237
val javaDataType = ctx.javaType(value.dataType)
238238
val valueGen = value.genCode(ctx)
239239
val listGen = list.map(_.genCode(ctx))
240-
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
241-
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
240+
// inTmpResult has 3 possible values:
241+
// -1 means no matches found and there is at least one value in the list evaluated to null
242+
val HAS_NULL = -1
243+
// 0 means no matches found and all values in the list are not null
244+
val NOT_MATCHED = 0
245+
// 1 means one value in the list is matched
246+
val MATCHED = 1
247+
val tmpResult = ctx.freshName("inTmpResult")
242248
val valueArg = ctx.freshName("valueArg")
243249
// All the blocks are meant to be inside a do { ... } while (false); loop.
244250
// The evaluation of variables can be stopped when we find a matching value.
245251
val listCode = listGen.map(x =>
246252
s"""
247253
|${x.code}
248254
|if (${x.isNull}) {
249-
| ${ev.isNull} = true;
255+
| $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
250256
|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
251-
| ${ev.isNull} = false;
252-
| ${ev.value} = true;
257+
| $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
253258
| continue;
254259
|}
255260
""".stripMargin)
256261

257262
val codes = ctx.splitExpressionsWithCurrentInputs(
258263
expressions = listCode,
259264
funcName = "valueIn",
260-
extraArguments = (javaDataType, valueArg) :: Nil,
265+
extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
266+
returnType = ctx.JAVA_BYTE,
261267
makeSplitFunction = body =>
262268
s"""
263269
|do {
264270
| $body
265271
|} while (false);
272+
|return $tmpResult;
266273
""".stripMargin,
267274
foldFunctions = _.map { funcCall =>
268275
s"""
269-
|$funcCall;
270-
|if (${ev.value}) {
276+
|$tmpResult = $funcCall;
277+
|if ($tmpResult == $MATCHED) {
271278
| continue;
272279
|}
273280
""".stripMargin
@@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
276283
ev.copy(code =
277284
s"""
278285
|${valueGen.code}
279-
|${ev.value} = false;
280-
|${ev.isNull} = ${valueGen.isNull};
281-
|if (!${ev.isNull}) {
286+
|byte $tmpResult = $HAS_NULL;
287+
|if (!${valueGen.isNull}) {
288+
| $tmpResult = 0;
282289
| $javaDataType $valueArg = ${valueGen.value};
283290
| do {
284291
| $codes
285292
| } while (false);
286293
|}
294+
|final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL);
295+
|final boolean ${ev.value} = ($tmpResult == $MATCHED);
287296
""".stripMargin)
288297
}
289298

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2223
import org.apache.spark.sql.types._
2324

2425
class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
145146
IndexedSeq((Literal(12) === Literal(1), Literal(42)),
146147
(Literal(12) === Literal(42), Literal(1))))
147148
}
149+
150+
test("SPARK-22705: case when should use less global variables") {
151+
val ctx = new CodegenContext()
152+
CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx)
153+
assert(ctx.mutableStates.size == 1)
154+
}
148155
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
22+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2223
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
2324
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
2425
import org.apache.spark.sql.types._
@@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
155156
checkEvaluation(Coalesce(inputs), "x_1")
156157
}
157158

159+
test("SPARK-22705: Coalesce should use less global variables") {
160+
val ctx = new CodegenContext()
161+
Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx)
162+
assert(ctx.mutableStates.size == 1)
163+
}
164+
158165
test("AtLeastNNonNulls should not throw 64kb exception") {
159166
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
160167
checkEvaluation(AtLeastNNonNulls(1, inputs), true)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
246246
checkEvaluation(In(Literal(1.0D), sets), true)
247247
}
248248

249+
test("SPARK-22705: In should use less global variables") {
250+
val ctx = new CodegenContext()
251+
In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx)
252+
assert(ctx.mutableStates.isEmpty)
253+
}
254+
249255
test("INSET") {
250256
val hS = HashSet[Any]() + 1 + 2
251257
val nS = HashSet[Any]() + 1 + 2 + null

0 commit comments

Comments
 (0)