Skip to content

Commit b70e483

Browse files
committed
[SPARK-22617][SQL] make splitExpressions extract current input of the context
## What changes were proposed in this pull request? Mostly when we call `CodegenContext.splitExpressions`, we want to split the code into methods and pass the current inputs of the codegen context to these methods so that the code in these methods can still be evaluated. This PR makes the expectation clear, while still keeping the advanced version of `splitExpressions` to customize the inputs passed to the generated methods. ## How was this patch tested? Existing tests. Author: Wenchen Fan <[email protected]> Closes #19827 from cloud-fan/codegen.
1 parent 1e07fff commit b70e483

File tree

11 files changed

+108
-86
lines changed

11 files changed

+108
-86
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression {
614614
}
615615
"""
616616
}
617-
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval))
617+
val codes = ctx.splitExpressions(evalChildren.map(updateEval))
618618
ev.copy(code = s"""
619619
${ev.isNull} = true;
620620
${ev.value} = ${ctx.defaultValue(dataType)};
@@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
680680
}
681681
"""
682682
}
683-
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval))
683+
val codes = ctx.splitExpressions(evalChildren.map(updateEval))
684684
ev.copy(code = s"""
685685
${ev.isNull} = true;
686686
${ev.value} = ${ctx.defaultValue(dataType)};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -781,15 +781,18 @@ class CodegenContext {
781781
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it
782782
* instead, because classes have a constant pool limit of 65,536 named values.
783783
*
784-
* @param row the variable name of row that is used by expressions
784+
* Note that we will extract the current inputs of this context and pass them to the generated
785+
* functions. The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole
786+
* stage codegen path. Whole stage codegen path is not supported yet.
787+
*
785788
* @param expressions the codes to evaluate expressions.
786789
*/
787-
def splitExpressions(row: String, expressions: Seq[String]): String = {
788-
if (row == null || currentVars != null) {
789-
// Cannot split these expressions because they are not created from a row object.
790+
def splitExpressions(expressions: Seq[String]): String = {
791+
// TODO: support whole stage codegen
792+
if (INPUT_ROW == null || currentVars != null) {
790793
return expressions.mkString("\n")
791794
}
792-
splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil)
795+
splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil)
793796
}
794797

795798
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
9191
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
9292
}
9393

94-
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
95-
val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
94+
val allProjections = ctx.splitExpressions(projectionCodes)
95+
val allUpdates = ctx.splitExpressions(updates)
9696

9797
val codeBody = s"""
9898
public java.lang.Object generate(Object[] references) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
4545
ctx: CodegenContext,
4646
input: String,
4747
schema: StructType): ExprCode = {
48-
val tmp = ctx.freshName("tmp")
48+
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
49+
val tmpInput = ctx.freshName("tmpInput")
4950
val output = ctx.freshName("safeRow")
5051
val values = ctx.freshName("values")
5152
// These expressions could be split into multiple functions
@@ -54,17 +55,21 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
5455
val rowClass = classOf[GenericInternalRow].getName
5556

5657
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
57-
val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
58+
val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
5859
s"""
59-
if (!$tmp.isNullAt($i)) {
60+
if (!$tmpInput.isNullAt($i)) {
6061
${converter.code}
6162
$values[$i] = ${converter.value};
6263
}
6364
"""
6465
}
65-
val allFields = ctx.splitExpressions(tmp, fieldWriters)
66+
val allFields = ctx.splitExpressions(
67+
expressions = fieldWriters,
68+
funcName = "writeFields",
69+
arguments = Seq("InternalRow" -> tmpInput)
70+
)
6671
val code = s"""
67-
final InternalRow $tmp = $input;
72+
final InternalRow $tmpInput = $input;
6873
$values = new Object[${schema.length}];
6974
$allFields
7075
final InternalRow $output = new $rowClass($values);
@@ -78,20 +83,22 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
7883
ctx: CodegenContext,
7984
input: String,
8085
elementType: DataType): ExprCode = {
81-
val tmp = ctx.freshName("tmp")
86+
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
87+
val tmpInput = ctx.freshName("tmpInput")
8288
val output = ctx.freshName("safeArray")
8389
val values = ctx.freshName("values")
8490
val numElements = ctx.freshName("numElements")
8591
val index = ctx.freshName("index")
8692
val arrayClass = classOf[GenericArrayData].getName
8793

88-
val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
94+
val elementConverter = convertToSafe(
95+
ctx, ctx.getValue(tmpInput, elementType, index), elementType)
8996
val code = s"""
90-
final ArrayData $tmp = $input;
91-
final int $numElements = $tmp.numElements();
97+
final ArrayData $tmpInput = $input;
98+
final int $numElements = $tmpInput.numElements();
9299
final Object[] $values = new Object[$numElements];
93100
for (int $index = 0; $index < $numElements; $index++) {
94-
if (!$tmp.isNullAt($index)) {
101+
if (!$tmpInput.isNullAt($index)) {
95102
${elementConverter.code}
96103
$values[$index] = ${elementConverter.value};
97104
}
@@ -107,14 +114,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
107114
input: String,
108115
keyType: DataType,
109116
valueType: DataType): ExprCode = {
110-
val tmp = ctx.freshName("tmp")
117+
val tmpInput = ctx.freshName("tmpInput")
111118
val output = ctx.freshName("safeMap")
112119
val mapClass = classOf[ArrayBasedMapData].getName
113120

114-
val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
115-
val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType)
121+
val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
122+
val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
116123
val code = s"""
117-
final MapData $tmp = $input;
124+
final MapData $tmpInput = $input;
118125
${keyConverter.code}
119126
${valueConverter.code}
120127
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
@@ -152,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
152159
}
153160
"""
154161
}
155-
val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
162+
val allExpressions = ctx.splitExpressions(expressionCodes)
156163

157164
val codeBody = s"""
158165
public java.lang.Object generate(Object[] references) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
3636
case NullType => true
3737
case t: AtomicType => true
3838
case _: CalendarIntervalType => true
39-
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
39+
case t: StructType => t.forall(field => canSupport(field.dataType))
4040
case t: ArrayType if canSupport(t.elementType) => true
4141
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
4242
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
@@ -49,25 +49,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4949
input: String,
5050
fieldTypes: Seq[DataType],
5151
bufferHolder: String): String = {
52+
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
53+
val tmpInput = ctx.freshName("tmpInput")
5254
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
53-
val javaType = ctx.javaType(dt)
54-
val isNullVar = ctx.freshName("isNull")
55-
val valueVar = ctx.freshName("value")
56-
val defaultValue = ctx.defaultValue(dt)
57-
val readValue = ctx.getValue(input, dt, i.toString)
58-
val code =
59-
s"""
60-
boolean $isNullVar = $input.isNullAt($i);
61-
$javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
62-
"""
63-
ExprCode(code, isNullVar, valueVar)
55+
ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
6456
}
6557

6658
s"""
67-
if ($input instanceof UnsafeRow) {
68-
${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
59+
final InternalRow $tmpInput = $input;
60+
if ($tmpInput instanceof UnsafeRow) {
61+
${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)}
6962
} else {
70-
${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
63+
${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)}
7164
}
7265
"""
7366
}
@@ -167,9 +160,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
167160
}
168161
}
169162

163+
val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) {
164+
// TODO: support whole stage codegen
165+
writeFields.mkString("\n")
166+
} else {
167+
assert(row != null, "the input row name cannot be null when generating code to write it.")
168+
ctx.splitExpressions(
169+
expressions = writeFields,
170+
funcName = "writeFields",
171+
arguments = Seq("InternalRow" -> row))
172+
}
173+
170174
s"""
171175
$resetWriter
172-
${ctx.splitExpressions(row, writeFields)}
176+
$writeFieldsCode
173177
""".trim
174178
}
175179

@@ -179,13 +183,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
179183
input: String,
180184
elementType: DataType,
181185
bufferHolder: String): String = {
186+
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
187+
val tmpInput = ctx.freshName("tmpInput")
182188
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
183189
val arrayWriter = ctx.freshName("arrayWriter")
184190
ctx.addMutableState(arrayWriterClass, arrayWriter,
185191
s"$arrayWriter = new $arrayWriterClass();")
186192
val numElements = ctx.freshName("numElements")
187193
val index = ctx.freshName("index")
188-
val element = ctx.freshName("element")
189194

190195
val et = elementType match {
191196
case udt: UserDefinedType[_] => udt.sqlType
@@ -201,6 +206,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
201206
}
202207

203208
val tmpCursor = ctx.freshName("tmpCursor")
209+
val element = ctx.getValue(tmpInput, et, index)
204210
val writeElement = et match {
205211
case t: StructType =>
206212
s"""
@@ -233,17 +239,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
233239

234240
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
235241
s"""
236-
if ($input instanceof UnsafeArrayData) {
237-
${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
242+
final ArrayData $tmpInput = $input;
243+
if ($tmpInput instanceof UnsafeArrayData) {
244+
${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)}
238245
} else {
239-
final int $numElements = $input.numElements();
246+
final int $numElements = $tmpInput.numElements();
240247
$arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
241248

242249
for (int $index = 0; $index < $numElements; $index++) {
243-
if ($input.isNullAt($index)) {
250+
if ($tmpInput.isNullAt($index)) {
244251
$arrayWriter.setNull$primitiveTypeName($index);
245252
} else {
246-
final $jt $element = ${ctx.getValue(input, et, index)};
247253
$writeElement
248254
}
249255
}
@@ -258,31 +264,28 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
258264
keyType: DataType,
259265
valueType: DataType,
260266
bufferHolder: String): String = {
261-
val keys = ctx.freshName("keys")
262-
val values = ctx.freshName("values")
267+
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
268+
val tmpInput = ctx.freshName("tmpInput")
263269
val tmpCursor = ctx.freshName("tmpCursor")
264270

265-
266271
// Writes out unsafe map according to the format described in `UnsafeMapData`.
267272
s"""
268-
if ($input instanceof UnsafeMapData) {
269-
${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
273+
final MapData $tmpInput = $input;
274+
if ($tmpInput instanceof UnsafeMapData) {
275+
${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)}
270276
} else {
271-
final ArrayData $keys = $input.keyArray();
272-
final ArrayData $values = $input.valueArray();
273-
274277
// preserve 8 bytes to write the key array numBytes later.
275278
$bufferHolder.grow(8);
276279
$bufferHolder.cursor += 8;
277280

278281
// Remember the current cursor so that we can write numBytes of key array later.
279282
final int $tmpCursor = $bufferHolder.cursor;
280283

281-
${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
284+
${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)}
282285
// Write the numBytes of key array into the first 8 bytes.
283286
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
284287

285-
${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
288+
${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)}
286289
}
287290
"""
288291
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
6363
val (preprocess, assigns, postprocess, arrayData) =
6464
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
6565
ev.copy(
66-
code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess,
66+
code = preprocess + ctx.splitExpressions(assigns) + postprocess,
6767
value = arrayData,
6868
isNull = "false")
6969
}
@@ -216,10 +216,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
216216
s"""
217217
final boolean ${ev.isNull} = false;
218218
$preprocessKeyData
219-
${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)}
219+
${ctx.splitExpressions(assignKeys)}
220220
$postprocessKeyData
221221
$preprocessValueData
222-
${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)}
222+
${ctx.splitExpressions(assignValues)}
223223
$postprocessValueData
224224
final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
225225
"""
@@ -351,24 +351,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
351351
val rowClass = classOf[GenericInternalRow].getName
352352
val values = ctx.freshName("values")
353353
ctx.addMutableState("Object[]", values, s"$values = null;")
354-
355-
ev.copy(code = s"""
356-
$values = new Object[${valExprs.size}];""" +
357-
ctx.splitExpressions(
358-
ctx.INPUT_ROW,
359-
valExprs.zipWithIndex.map { case (e, i) =>
360-
val eval = e.genCode(ctx)
361-
eval.code + s"""
354+
val valuesCode = ctx.splitExpressions(
355+
valExprs.zipWithIndex.map { case (e, i) =>
356+
val eval = e.genCode(ctx)
357+
s"""
358+
${eval.code}
362359
if (${eval.isNull}) {
363360
$values[$i] = null;
364361
} else {
365362
$values[$i] = ${eval.value};
366363
}"""
367-
}) +
364+
})
365+
366+
ev.copy(code =
368367
s"""
369-
final InternalRow ${ev.value} = new $rowClass($values);
370-
$values = null;
371-
""", isNull = "false")
368+
|$values = new Object[${valExprs.size}];
369+
|$valuesCode
370+
|final InternalRow ${ev.value} = new $rowClass($values);
371+
|$values = null;
372+
""".stripMargin, isNull = "false")
372373
}
373374

374375
override def prettyName: String = "named_struct"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
203203
ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];")
204204
val values = children.tail
205205
val dataTypes = values.take(numFields).map(_.dataType)
206-
val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
206+
val code = ctx.splitExpressions(Seq.tabulate(numRows) { row =>
207207
val fields = Seq.tabulate(numFields) { col =>
208208
val index = row * numFields + col
209209
if (index < values.length) values(index) else Literal(null, dataTypes(col))

0 commit comments

Comments
 (0)