Commit 8c68718

juliuszsompolski authored and cloud-fan committed
[SPARK-26159] Codegen for LocalTableScanExec and RDDScanExec
## What changes were proposed in this pull request?

Implement codegen for `LocalTableScanExec` and `ExistingRDDExec`. Refactor to share code between `LocalTableScanExec`, `ExistingRDDExec`, `InputAdapter`, and `RowDataSourceScanExec`.

The difference in `doProduce` between these four was that `ExistingRDDExec` and `RowDataSourceScanExec` triggered adding an `UnsafeProjection`, while `InputAdapter` and `LocalTableScanExec` did not. The new trait `InputRDDCodegen` therefore carries a flag, `createUnsafeProjection`, which each operator sets accordingly.

Note: `LocalTableScanExec` explicitly creates its input as `UnsafeRows`, so it clearly needs no `UnsafeProjection`. An `InputAdapter` may take input that is `InternalRows` but not `UnsafeRows`, yet it also needs no unsafe projection, because whatever operator becomes its parent would add one. That assumes any parent operator always results in some `UnsafeProjection` eventually being added, so that the output of the `WholeStageCodegen` unit is `UnsafeRows`. If these assumptions hold, `createUnsafeProjection` could simply be set to `(parent == null)`.

Note: `LocalTableScanExec` is not codegened when it is the only operator. It has optimized driver-only `executeCollect` and `executeTake` code paths that return `Command` results without starting Spark jobs, and those paths can no longer be used once the `LocalTableScanExec` is wrapped in whole-stage codegen.

## How was this patch tested?

Covered and used in existing tests.

Closes apache#23127 from juliuszsompolski/SPARK-26159.

Authored-by: Juliusz Sompolski <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 2d89d10 commit 8c68718
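
For a quick sense of the user-visible effect, a minimal sketch (assuming a local `SparkSession` named `spark`; the exact attribute ids vary): a DataFrame built from a pre-existing RDD is planned as a `Scan ExistingRDD` leaf, and after this change that scan participates in whole-stage codegen, which `explain()` marks with a `*(1)` prefix.

```scala
import spark.implicits._

// A DataFrame over a pre-existing RDD plans to an RDDScanExec
// ("Scan ExistingRDD") leaf node.
val df = spark.sparkContext
  .parallelize(Seq((2, "Alice"), (5, "Bob")))
  .toDF("age", "name")

// Before this change: "Scan ExistingRDD[age#0,name#1]"
// After this change:  "*(1) Scan ExistingRDD[age#0,name#1]"
df.explain()
```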

File tree

6 files changed (+86, -51 lines)

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -257,7 +257,7 @@ def explain(self, extended=False):
 
         >>> df.explain()
         == Physical Plan ==
-        Scan ExistingRDD[age#0,name#1]
+        *(1) Scan ExistingRDD[age#0,name#1]
 
         >>> df.explain(True)
         == Parsed Logical Plan ==
```

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 4 additions & 24 deletions

```diff
@@ -84,7 +84,7 @@ case class RowDataSourceScanExec(
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
     override val tableIdentifier: Option[TableIdentifier])
-  extends DataSourceScanExec {
+  extends DataSourceScanExec with InputRDDCodegen {
 
   def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput)
 
@@ -104,30 +104,10 @@ case class RowDataSourceScanExec(
     }
   }
 
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    rdd :: Nil
-  }
+  // Input can be InternalRow, has to be turned into UnsafeRows.
+  override protected val createUnsafeProjection: Boolean = true
 
-  override protected def doProduce(ctx: CodegenContext): String = {
-    val numOutputRows = metricTerm(ctx, "numOutputRows")
-    // PhysicalRDD always just has one input
-    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
-    val exprRows = output.zipWithIndex.map{ case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable)
-    }
-    val row = ctx.freshName("row")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columnsRowInput = exprRows.map(_.genCode(ctx))
-    s"""
-       |while ($input.hasNext()) {
-       |  InternalRow $row = (InternalRow) $input.next();
-       |  $numOutputRows.add(1);
-       |  ${consume(ctx, columnsRowInput).trim}
-       |  if (shouldStop()) return;
-       |}
-     """.stripMargin
-  }
+  override def inputRDD: RDD[InternalRow] = rdd
 
   override val metadata: Map[String, String] = {
     val markedFilters = for (filter <- filters) yield {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 6 additions & 1 deletion

```diff
@@ -175,7 +175,7 @@ case class RDDScanExec(
     rdd: RDD[InternalRow],
     name: String,
     override val outputPartitioning: Partitioning = UnknownPartitioning(0),
-    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {
+    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode with InputRDDCodegen {
 
   private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("")
 
@@ -199,4 +199,9 @@ case class RDDScanExec(
   override def simpleString: String = {
     s"$nodeName${truncatedString(output, "[", ",", "]")}"
   }
+
+  // Input can be InternalRow, has to be turned into UnsafeRows.
+  override protected val createUnsafeProjection: Boolean = true
+
+  override def inputRDD: RDD[InternalRow] = rdd
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala

Lines changed: 9 additions & 1 deletion

```diff
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  */
 case class LocalTableScanExec(
     output: Seq[Attribute],
-    @transient rows: Seq[InternalRow]) extends LeafExecNode {
+    @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -76,4 +76,12 @@ case class LocalTableScanExec(
     longMetric("numOutputRows").add(taken.size)
     taken
   }
+
+  // Input is already UnsafeRows.
+  override protected val createUnsafeProjection: Boolean = false
+
+  // Do not codegen when there is no parent - to support the fast driver-local collect/take paths.
+  override def supportCodegen: Boolean = (parent != null)
+
+  override def inputRDD: RDD[InternalRow] = rdd
 }
```
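
The `supportCodegen = (parent != null)` gate above keeps the driver-only fast path from the commit message intact. A sketch of the behavior it preserves (assuming a `SparkSession` named `spark`):

```scala
// Command results come back through a root-level LocalTableScanExec,
// whose executeCollect()/executeTake() return the locally held rows on
// the driver without starting a Spark job.
val tables = spark.sql("SHOW TABLES").collect()

// If that root node were wrapped in WholeStageCodegenExec, collect()
// would instead run through the generated iterator and lose the fast
// path; hence codegen only when the node has a parent.
```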

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 60 additions & 18 deletions

```diff
@@ -350,6 +350,15 @@ trait CodegenSupport extends SparkPlan {
    */
   def needStopCheck: Boolean = parent.needStopCheck
 
+  /**
+   * Helper default should stop check code.
+   */
+  def shouldStopCheckCode: String = if (needStopCheck) {
+    "if (shouldStop()) return;"
+  } else {
+    "// shouldStop check is eliminated"
+  }
+
   /**
    * A sequence of checks which evaluate to true if the downstream Limit operators have not received
    * enough records and reached the limit. If current node is a data producing node, it can leverage
@@ -406,14 +415,61 @@ trait BlockingOperatorWithCodegen extends CodegenSupport {
   override def limitNotReachedChecks: Seq[String] = Nil
 }
 
+/**
+ * Leaf codegen node reading from a single RDD.
+ */
+trait InputRDDCodegen extends CodegenSupport {
+
+  def inputRDD: RDD[InternalRow]
+
+  // If the input can be InternalRows, an UnsafeProjection needs to be created.
+  protected val createUnsafeProjection: Boolean
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    inputRDD :: Nil
+  }
+
+  override def doProduce(ctx: CodegenContext): String = {
+    // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
+      forceInline = true)
+    val row = ctx.freshName("row")
+
+    val outputVars = if (createUnsafeProjection) {
+      // creating the vars will make the parent consume add an unsafe projection.
+      ctx.INPUT_ROW = row
+      ctx.currentVars = null
+      output.zipWithIndex.map { case (a, i) =>
+        BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+      }
+    } else {
+      null
+    }
+
+    val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) {
+      val numOutputRows = metricTerm(ctx, "numOutputRows")
+      s"$numOutputRows.add(1);"
+    } else {
+      ""
+    }
+    s"""
+       | while ($limitNotReachedCond $input.hasNext()) {
+       |   InternalRow $row = (InternalRow) $input.next();
+       |   ${updateNumOutputRowsMetrics}
+       |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
+       |   ${shouldStopCheckCode}
+       | }
+     """.stripMargin
+  }
+}
+
 /**
  * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen.
  *
  * This is the leaf node of a tree with WholeStageCodegen that is used to generate code
  * that consumes an RDD iterator of InternalRow.
  */
-case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen {
 
   override def output: Seq[Attribute] = child.output
 
@@ -429,24 +485,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
     child.doExecuteBroadcast()
   }
 
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    child.execute() :: Nil
-  }
+  override def inputRDD: RDD[InternalRow] = child.execute()
 
-  override def doProduce(ctx: CodegenContext): String = {
-    // Right now, InputAdapter is only used when there is one input RDD.
-    // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen
-    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
-      forceInline = true)
-    val row = ctx.freshName("row")
-    s"""
-       | while ($limitNotReachedCond $input.hasNext()) {
-       |   InternalRow $row = (InternalRow) $input.next();
-       |   ${consume(ctx, null, row).trim}
-       |   if (shouldStop()) return;
-       | }
-     """.stripMargin
-  }
+  // InputAdapter does not need UnsafeProjection.
+  protected val createUnsafeProjection: Boolean = false
 
   override def generateTreeString(
       depth: Int,
```
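
To make the new trait's contract concrete, here is a hypothetical leaf operator (a sketch, not part of this patch; `MyRowScanExec` and `relationRDD` are invented names, and it assumes the `org.apache.spark.sql.execution` package so `InputRDDCodegen` is visible). The operator only supplies `inputRDD` and a `createUnsafeProjection` choice; the shared `doProduce` above generates the consuming loop.

```scala
package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

// Hypothetical leaf operator illustrating the InputRDDCodegen contract.
case class MyRowScanExec(
    output: Seq[Attribute],
    @transient relationRDD: RDD[InternalRow])
  extends LeafExecNode with InputRDDCodegen {

  // Rows may be arbitrary InternalRows, so have the parent's consume()
  // add an UnsafeProjection (the same choice RowDataSourceScanExec makes).
  override protected val createUnsafeProjection: Boolean = true

  // The single input RDD that the shared doProduce() loop iterates over.
  override def inputRDD: RDD[InternalRow] = relationRDD

  // Non-codegen fallback execution path.
  protected override def doExecute(): RDD[InternalRow] = relationRDD
}
```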

sql/core/src/test/resources/sql-tests/results/operators.sql.out

Lines changed: 6 additions & 6 deletions

```diff
@@ -201,7 +201,7 @@ struct<plan:string>
 -- !query 24 output
 == Physical Plan ==
 *Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
 
 
 -- !query 25
@@ -211,7 +211,7 @@ struct<plan:string>
 -- !query 25 output
 == Physical Plan ==
 *Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
 
 
 -- !query 26
@@ -221,7 +221,7 @@ struct<plan:string>
 -- !query 26 output
 == Physical Plan ==
 *Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
 
 
 -- !query 27
@@ -231,7 +231,7 @@ struct<plan:string>
 -- !query 27 output
 == Physical Plan ==
 *Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
 
 
 -- !query 28
@@ -241,7 +241,7 @@ struct<plan:string>
 -- !query 28 output
 == Physical Plan ==
 *Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
 
 
 -- !query 29
@@ -251,7 +251,7 @@ struct<plan:string>
 -- !query 29 output
 == Physical Plan ==
 *Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
 
 
 -- !query 30
```
