
Commit 7022190

cloud-fan authored and gatorsmile committed
[SPARK-22596][SQL] set ctx.currentVars in CodegenSupport.consume
## What changes were proposed in this pull request?

`ctx.currentVars` means the input variables for the current operator, which is already decided in `CodegenSupport`, so we can set it there instead of in `doConsume`. This PR also adds more comments to help people understand the codegen framework.

After this PR, we have a principle for setting `ctx.currentVars` and `ctx.INPUT_ROW`:

1. For the non-whole-stage-codegen path, never set them (a few special cases, like generating ordering, are permitted).
2. For the whole-stage-codegen `produce` path, mostly we don't need to set them, but blocking operators may need to set them for expressions that produce data from the data source, sort buffer, aggregate buffer, etc.
3. For the whole-stage-codegen `consume` path, mostly we don't need to set them, because `currentVars` is automatically set to the child's input variables and `INPUT_ROW` is mostly not used. A few plans need to tweak them as they may have different inputs, or they use the input row.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes #19803 from cloud-fan/codegen.
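Not part of the patch, but useful context: with a Spark build on the classpath, you can print the whole-stage-generated Java that these rules govern via the `debug` package. A minimal sketch (the object name and query are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._

object InspectCodegen extends App {
  val spark = SparkSession.builder().master("local[1]").appName("inspect-codegen").getOrCreate()
  import spark.implicits._

  // Prints the whole-stage-generated Java for a simple scan -> filter -> project
  // pipeline, whose doConsume implementations this patch simplifies.
  spark.range(10).filter($"id" > 5).select(($"id" + 1).as("x")).debugCodegen()

  spark.stop()
}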
1 parent a1877f4 commit 7022190

File tree

8 files changed: +59 -50 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 13 additions & 10 deletions

@@ -59,21 +59,24 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   }

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val javaType = ctx.javaType(dataType)
-    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
       val oev = ctx.currentVars(ordinal)
       ev.isNull = oev.isNull
       ev.value = oev.value
-      val code = oev.code
-      oev.code = ""
-      ev.copy(code = code)
-    } else if (nullable) {
-      ev.copy(code = s"""
-        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
-        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
+      ev.copy(code = oev.code)
     } else {
-      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
+      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
+      val javaType = ctx.javaType(dataType)
+      val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
+      if (nullable) {
+        ev.copy(code =
+          s"""
+             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
+             |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+           """.stripMargin)
+      } else {
+        ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
+      }
     }
   }
 }
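The two branches above can be exercised in isolation. A minimal sketch, assuming a Spark build of this era on the classpath (object and variable names are illustrative):

import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.IntegerType

object BoundReferenceSketch extends App {
  // Row path: currentVars is null, so genCode emits an accessor on INPUT_ROW.
  val rowCtx = new CodegenContext
  rowCtx.INPUT_ROW = "i"
  rowCtx.currentVars = null
  println(BoundReference(0, IntegerType, nullable = false).genCode(rowCtx).code)
  // expected shape: int value_N = i.getInt(0);

  // Variable path: currentVars wins, so genCode just forwards the existing variable.
  val varCtx = new CodegenContext
  varCtx.currentVars = Seq(ExprCode(code = "", isNull = "false", value = "inputVar"))
  println(BoundReference(0, IntegerType, nullable = false).genCode(varCtx).value)
  // prints: inputVar, with no new code emitted
}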

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 11 additions & 3 deletions

@@ -133,6 +133,17 @@ class CodegenContext {
     term
   }

+  /**
+   * Holding the variable name of the input row of the current operator, will be used by
+   * `BoundReference` to generate code.
+   *
+   * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW`
+   * to generate code. If you want to make sure the generated code use `INPUT_ROW`, you need to set
+   * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
+   * `Expression.genCode`.
+   */
+  final var INPUT_ROW = "i"
+
   /**
    * Holding a list of generated columns as input of current operator, will be used by
    * BoundReference to generate code.

@@ -386,9 +397,6 @@ class CodegenContext {
   final val JAVA_FLOAT = "float"
   final val JAVA_DOUBLE = "double"

-  /** The variable name of the input row in generated code. */
-  final var INPUT_ROW = "i"
-
   /**
    * The map from a variable name to it's next ID.
    */
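The per-column fallback described in the new doc comment can be exercised directly. A minimal sketch under the same classpath assumption as above (names illustrative): nulling one slot of `currentVars` sends that column back to `INPUT_ROW` while the other keeps its generated variable.

import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.IntegerType

object MixedInputSketch extends App {
  val ctx = new CodegenContext
  ctx.INPUT_ROW = "i"
  // Column 0 already has a generated variable; column 1 is deliberately null.
  ctx.currentVars = Seq(ExprCode("", "false", "col0"), null)

  println(BoundReference(0, IntegerType, nullable = false).genCode(ctx).value)
  // prints: col0 (reuses the existing variable)
  println(BoundReference(1, IntegerType, nullable = false).genCode(ctx).code)
  // expected shape: int value_N = i.getInt(1);
}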

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 8 additions & 6 deletions

@@ -123,7 +123,7 @@ case class RowDataSourceScanExec(
       |while ($input.hasNext()) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
-      |  ${consume(ctx, columnsRowInput, null).trim}
+      |  ${consume(ctx, columnsRowInput).trim}
       |  if (shouldStop()) return;
       |}
     """.stripMargin

@@ -355,19 +355,21 @@ case class FileSourceScanExec(
     // PhysicalRDD always just has one input
     val input = ctx.freshName("input")
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-    val exprRows = output.zipWithIndex.map{ case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable)
-    }
     val row = ctx.freshName("row")
+
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    val columnsRowInput = exprRows.map(_.genCode(ctx))
+    // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
+    // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
+    val outputVars = output.zipWithIndex.map{ case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+    }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
       |while ($input.hasNext()) {
       |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
-      |  ${consume(ctx, columnsRowInput, inputRow).trim}
+      |  ${consume(ctx, outputVars, inputRow).trim}
       |  if (shouldStop()) return;
       |}
     """.stripMargin

sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala

Lines changed: 0 additions & 3 deletions

@@ -133,9 +133,6 @@ case class ExpandExec(
      * size explosion.
      */

-    // Set input variables
-    ctx.currentVars = input
-
     // Tracks whether a column has the same output for all rows.
     // Size of sameOutput array should equal N.
     // If sameOutput(i) is true, then the i-th column has the same value for all output rows given

sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala

Lines changed: 0 additions & 2 deletions

@@ -135,8 +135,6 @@ case class GenerateExec(
   override def needCopyResult: Boolean = true

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    ctx.currentVars = input
-
     // Add input rows to the values when we are joining
     val values = if (join) {
       input

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 20 additions & 7 deletions

@@ -108,20 +108,22 @@ trait CodegenSupport extends SparkPlan {

   /**
    * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`.
+   *
+   * Note that `outputVars` and `row` can't both be null.
    */
   final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
     val inputVars =
-      if (row != null) {
+      if (outputVars != null) {
+        assert(outputVars.length == output.length)
+        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
+        outputVars.map(_.copy())
+      } else {
+        assert(row != null, "outputVars and row cannot both be null.")
         ctx.currentVars = null
         ctx.INPUT_ROW = row
         output.zipWithIndex.map { case (attr, i) =>
           BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
         }
-      } else {
-        assert(outputVars != null)
-        assert(outputVars.length == output.length)
-        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
-        outputVars.map(_.copy())
       }

     val rowVar = if (row != null) {

@@ -147,6 +149,11 @@ trait CodegenSupport extends SparkPlan {
       }
     }

+    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
+    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
+    // generate code of `rowVar` manually.
+    ctx.currentVars = inputVars
+    ctx.INPUT_ROW = null
     ctx.freshNamePrefix = parent.variablePrefix
     val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
     s"""

@@ -193,7 +200,8 @@ trait CodegenSupport extends SparkPlan {
   def usedInputs: AttributeSet = references

   /**
-   * Generate the Java source code to process the rows from child SparkPlan.
+   * Generate the Java source code to process the rows from child SparkPlan. This should only be
+   * called from `consume`.
    *
    * This should be override by subclass to support codegen.
    *

@@ -207,6 +215,11 @@ trait CodegenSupport extends SparkPlan {
    * }
    *
    * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
+   *       When consuming as a listing of variables, the code to produce the input is already
+   *       generated and `CodegenContext.currentVars` is already set. When consuming as UnsafeRow,
+   *       implementations need to put `row.code` in the generated code and set
+   *       `CodegenContext.INPUT_ROW` manually. Some plans may need more tweaks as they have
+   *       different inputs(join build side, aggregate buffer, etc.), or other special cases.
    */
   def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     throw new UnsupportedOperationException
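To make the UnsafeRow side of this contract concrete, here is a deliberately simplified, hypothetical sketch (not code from the patch; `doConsumeAsRow` and all names are invented for illustration) of what a row-consuming implementation has to do now that `consume` nulls `INPUT_ROW`:

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}

object RowConsumeSketch extends App {
  // A stand-in for a row-consuming doConsume body: emit row.code so the UnsafeRow
  // actually gets materialized, and repoint INPUT_ROW at it for any BoundReferences
  // generated afterwards.
  def doConsumeAsRow(ctx: CodegenContext, row: ExprCode): String = {
    ctx.INPUT_ROW = row.value  // consume() nulled INPUT_ROW; restore it for row access
    ctx.currentVars = null     // force BoundReference onto the row path
    s"""
       |${row.code}
       |// ... operator-specific logic reading ${row.value} ...
     """.stripMargin
  }

  val ctx = new CodegenContext
  println(doConsumeAsRow(ctx, ExprCode("UnsafeRow r = proj.apply(i);", "false", "r")))
}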

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 5 deletions

@@ -56,9 +56,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x =>
-      ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
-    ctx.currentVars = input
+    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)

@@ -152,8 +150,6 @@ case class FilterExec(condition: Expression, child: SparkPlan)
      """.stripMargin
    }

-    ctx.currentVars = input
-
     // To generate the predicates we will follow this algorithm.
     // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
     // as necessary. For each of both attributes, if there is an IsNotNull predicate we will
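After this change, `ProjectExec.doConsume` can bind and generate expressions directly because `consume` has already populated `ctx.currentVars`. A minimal sketch of that interaction (same classpath assumption as the earlier sketches; `child_a` is an illustrative variable name):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.IntegerType

object BindAndGenSketch extends App {
  val a = AttributeReference("a", IntegerType, nullable = false)()
  // Bind against the child's output, as the rewritten ProjectExec does.
  val bound = BindReferences.bindReference[Expression](Add(a, Literal(1)), Seq(a))

  val ctx = new CodegenContext
  // Simulate what consume() now guarantees before doConsume runs.
  ctx.currentVars = Seq(ExprCode("", "false", "child_a"))
  ctx.INPUT_ROW = null

  // The generated code computes child_a + 1 with no row access.
  println(bound.genCode(ctx).code)
}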

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 6 additions & 14 deletions

@@ -81,11 +81,8 @@ case class DeserializeToObjectExec(
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val bound = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(deserializer, child.output))
-    ctx.currentVars = input
-    val resultVars = bound.genCode(ctx) :: Nil
-    consume(ctx, resultVars)
+    val resultObj = BindReferences.bindReference(deserializer, child.output).genCode(ctx)
+    consume(ctx, resultObj :: Nil)
   }

   override protected def doExecute(): RDD[InternalRow] = {

@@ -118,11 +115,9 @@ case class SerializeFromObjectExec(
   }

   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val bound = serializer.map { expr =>
-      ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+    val resultVars = serializer.map { expr =>
+      BindReferences.bindReference[Expression](expr, child.output).genCode(ctx)
     }
-    ctx.currentVars = input
-    val resultVars = bound.map(_.genCode(ctx))
     consume(ctx, resultVars)
   }

@@ -224,12 +219,9 @@ case class MapElementsExec(
     val funcObj = Literal.create(func, ObjectType(funcClass))
     val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)

-    val bound = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(callFunc, child.output))
-    ctx.currentVars = input
-    val resultVars = bound.genCode(ctx) :: Nil
+    val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)

-    consume(ctx, resultVars)
+    consume(ctx, result :: Nil)
   }

   override protected def doExecute(): RDD[InternalRow] = {
