Skip to content

Commit 5cc2987

Browse files
peter-toth authored and hvanhovell committed
[SPARK-25767][SQL] Fix lazily evaluated stream of expressions in code generation
## What changes were proposed in this pull request?

Code generation is incorrect if the `outputVars` parameter of the `consume` method in `CodegenSupport` contains a lazily evaluated stream of expressions. This PR fixes the issue by forcing the evaluation of `inputVars` before generating the code for UnsafeRow.

## How was this patch tested?

Tested with the sample program provided in https://issues.apache.org/jira/browse/SPARK-25767

Closes apache#22789 from peter-toth/SPARK-25767.

Authored-by: Peter Toth <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
(cherry picked from commit 7fe5cff)
Signed-off-by: Herman van Hovell <[email protected]>
1 parent 22bec3c commit 5cc2987

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,10 @@ trait CodegenSupport extends SparkPlan {
146146
if (outputVars != null) {
147147
assert(outputVars.length == output.length)
148148
// outputVars will be used to generate the code for UnsafeRow, so we should copy them
149-
outputVars.map(_.copy())
149+
outputVars.map(_.copy()) match {
150+
case stream: Stream[ExprCode] => stream.force
151+
case other => other
152+
}
150153
} else {
151154
assert(row != null, "outputVars and row cannot both be null.")
152155
ctx.currentVars = null

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -319,4 +319,15 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
319319
assert(df.limit(1).collect() === Array(Row("bat", 8.0)))
320320
}
321321
}
322+
323+
test("SPARK-25767: Lazy evaluated stream of expressions handled correctly") {
324+
val a = Seq(1).toDF("key")
325+
val b = Seq((1, "a")).toDF("key", "value")
326+
val c = Seq(1).toDF("key")
327+
328+
val ab = a.join(b, Stream("key"), "left")
329+
val abc = ab.join(c, Seq("key"), "left")
330+
331+
checkAnswer(abc, Row(1, "a"))
332+
}
322333
}

0 commit comments

Comments
 (0)