Skip to content

Commit 8a0ed5a

Browse files
committed
[SPARK-22668][SQL] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions()
## What changes were proposed in this pull request?

Passing global variables to the split method is dangerous, as any mutation of them is ignored and may lead to unexpected behavior.

To prevent this, one approach is to make sure no expression outputs global variables: localizing the lifetime of mutable states in expressions.

Another approach is, when calling `ctx.splitExpressions`, to make sure we don't use children's output as parameter names.

Approach 1 is actually hard to do, as we need to check all expressions and operators that support whole-stage codegen. Approach 2 is easier, as there are not many callers of `ctx.splitExpressions`. Besides, approach 2 is more flexible, as children's output may be other things that can't be used as a parameter name: a literal, an inlined statement (a + 1), etc.

close #19865
close #19938

## How was this patch tested?

existing tests

Author: Wenchen Fan <[email protected]>

Closes #20021 from cloud-fan/codegen.
1 parent 4c2efde commit 8a0ed5a

File tree

5 files changed

+43
-26
lines changed

5 files changed

+43
-26
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression {
602602

603603
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
604604
val evalChildren = children.map(_.genCode(ctx))
605-
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull")
605+
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
606606
val evals = evalChildren.map(eval =>
607607
s"""
608608
|${eval.code}
609-
|if (!${eval.isNull} && ($tmpIsNull ||
609+
|if (!${eval.isNull} && (${ev.isNull} ||
610610
| ${ctx.genGreater(dataType, ev.value, eval.value)})) {
611-
| $tmpIsNull = false;
611+
| ${ev.isNull} = false;
612612
| ${ev.value} = ${eval.value};
613613
|}
614614
""".stripMargin
@@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression {
628628
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
629629
ev.copy(code =
630630
s"""
631-
|$tmpIsNull = true;
631+
|${ev.isNull} = true;
632632
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
633633
|$codes
634-
|final boolean ${ev.isNull} = $tmpIsNull;
635634
""".stripMargin)
636635
}
637636
}
@@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression {
682681

683682
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
684683
val evalChildren = children.map(_.genCode(ctx))
685-
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull")
684+
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
686685
val evals = evalChildren.map(eval =>
687686
s"""
688687
|${eval.code}
689-
|if (!${eval.isNull} && ($tmpIsNull ||
688+
|if (!${eval.isNull} && (${ev.isNull} ||
690689
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
691-
| $tmpIsNull = false;
690+
| ${ev.isNull} = false;
692691
| ${ev.value} = ${eval.value};
693692
|}
694693
""".stripMargin
@@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
708707
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
709708
ev.copy(code =
710709
s"""
711-
|$tmpIsNull = true;
710+
|${ev.isNull} = true;
712711
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
713712
|$codes
714-
|final boolean ${ev.isNull} = $tmpIsNull;
715713
""".stripMargin)
716714
}
717715
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class CodegenContext {
128128
* `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
129129
* `Expression.genCode`.
130130
*/
131-
final var INPUT_ROW = "i"
131+
var INPUT_ROW = "i"
132132

133133
/**
134134
* Holding a list of generated columns as input of current operator, will be used by
@@ -146,22 +146,30 @@ class CodegenContext {
146146
* as a member variable
147147
*
148148
* They will be kept as member variables in generated classes like `SpecificProjection`.
149+
*
150+
* Exposed for tests only.
149151
*/
150-
val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
152+
private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
151153
mutable.ArrayBuffer.empty[(String, String)]
152154

153155
/**
154156
* The mapping between mutable state types and corresponding compacted arrays.
155157
* The keys are java type string. The values are [[MutableStateArrays]] which encapsulates
156158
* the compacted arrays for the mutable states with the same java type.
159+
*
160+
* Exposed for tests only.
157161
*/
158-
val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
162+
private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
159163
mutable.Map.empty[String, MutableStateArrays]
160164

161165
// An array holds the code that will initialize each state
162-
val mutableStateInitCode: mutable.ArrayBuffer[String] =
166+
// Exposed for tests only.
167+
private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] =
163168
mutable.ArrayBuffer.empty[String]
164169

170+
// Tracks the names of all the mutable states.
171+
private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty
172+
165173
/**
166174
* This class holds a set of names of mutableStateArrays that is used for compacting mutable
167175
* states for a certain type, and holds the next available slot of the current compacted array.
@@ -172,7 +180,11 @@ class CodegenContext {
172180

173181
private[this] var currentIndex = 0
174182

175-
private def createNewArray() = arrayNames.append(freshName("mutableStateArray"))
183+
private def createNewArray() = {
184+
val newArrayName = freshName("mutableStateArray")
185+
mutableStateNames += newArrayName
186+
arrayNames.append(newArrayName)
187+
}
176188

177189
def getCurrentIndex: Int = currentIndex
178190

@@ -241,6 +253,7 @@ class CodegenContext {
241253
val initCode = initFunc(varName)
242254
inlinedMutableStates += ((javaType, varName))
243255
mutableStateInitCode += initCode
256+
mutableStateNames += varName
244257
varName
245258
} else {
246259
val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays)
@@ -930,6 +943,15 @@ class CodegenContext {
930943
// inline execution if only one block
931944
blocks.head
932945
} else {
946+
if (Utils.isTesting) {
947+
// Passing global variables to the split method is dangerous, as any mutation of them is
948+
// ignored and may lead to unexpected behavior.
949+
arguments.foreach { case (_, name) =>
950+
assert(!mutableStateNames.contains(name),
951+
s"split function argument $name cannot be a global variable.")
952+
}
953+
}
954+
933955
val func = freshName(funcName)
934956
val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
935957
val functions = blocks.zipWithIndex.map { case (body, i) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ case class CaseWhen(
190190
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
191191
// We won't go on anymore on the computation.
192192
val resultState = ctx.freshName("caseWhenResultState")
193-
val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult")
193+
ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
194194

195195
// these blocks are meant to be inside a
196196
// do {
@@ -205,7 +205,7 @@ case class CaseWhen(
205205
|if (!${cond.isNull} && ${cond.value}) {
206206
| ${res.code}
207207
| $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
208-
| $tmpResult = ${res.value};
208+
| ${ev.value} = ${res.value};
209209
| continue;
210210
|}
211211
""".stripMargin
@@ -216,7 +216,7 @@ case class CaseWhen(
216216
s"""
217217
|${res.code}
218218
|$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
219-
|$tmpResult = ${res.value};
219+
|${ev.value} = ${res.value};
220220
""".stripMargin
221221
}
222222

@@ -264,13 +264,11 @@ case class CaseWhen(
264264
ev.copy(code =
265265
s"""
266266
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
267-
|$tmpResult = ${ctx.defaultValue(dataType)};
268267
|do {
269268
| $codes
270269
|} while (false);
271270
|// TRUE if any condition is met and the result is null, or no any condition is met.
272271
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
273-
|final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
274272
""".stripMargin)
275273
}
276274
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
7272
}
7373

7474
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
75-
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull")
75+
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
7676

7777
// all the evals are meant to be in a do { ... } while (false); loop
7878
val evals = children.map { e =>
7979
val eval = e.genCode(ctx)
8080
s"""
8181
|${eval.code}
8282
|if (!${eval.isNull}) {
83-
| $tmpIsNull = false;
83+
| ${ev.isNull} = false;
8484
| ${ev.value} = ${eval.value};
8585
| continue;
8686
|}
@@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
103103
foldFunctions = _.map { funcCall =>
104104
s"""
105105
|${ev.value} = $funcCall;
106-
|if (!$tmpIsNull) {
106+
|if (!${ev.isNull}) {
107107
| continue;
108108
|}
109109
""".stripMargin
@@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
112112

113113
ev.copy(code =
114114
s"""
115-
|$tmpIsNull = true;
115+
|${ev.isNull} = true;
116116
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
117117
|do {
118118
| $codes
119119
|} while (false);
120-
|final boolean ${ev.isNull} = $tmpIsNull;
121120
""".stripMargin)
122121
}
123122
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
285285
|${valueGen.code}
286286
|byte $tmpResult = $HAS_NULL;
287287
|if (!${valueGen.isNull}) {
288-
| $tmpResult = 0;
288+
| $tmpResult = $NOT_MATCHED;
289289
| $javaDataType $valueArg = ${valueGen.value};
290290
| do {
291291
| $codes

0 commit comments

Comments
 (0)