
Commit 780586a

cloud-fan authored and gatorsmile committed
[SPARK-17701][SQL] Refactor RowDataSourceScanExec so its sameResult call does not compare strings
## What changes were proposed in this pull request?

Currently, `RowDataSourceScanExec` and `FileSourceScanExec` rely on a "metadata" string map to implement equality comparison, since the RDDs they depend on cannot be directly compared. This has resulted in a number of correctness bugs around exchange reuse, e.g. SPARK-17673 and SPARK-16818. To make these comparisons less brittle, we should refactor these classes to compare constructor parameters directly instead of relying on the metadata map.

This PR refactors `RowDataSourceScanExec`; `FileSourceScanExec` will be fixed in a follow-up PR.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes apache#18600 from cloud-fan/minor.
1 parent d2d2a5d commit 780586a
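
To make the motivation above concrete, here is a minimal, self-contained Scala sketch. It uses hypothetical stand-in classes (not Spark's actual plan nodes) to contrast equality based on a rendered metadata string map with equality based on canonicalized constructor parameters:

```scala
// Hypothetical stand-ins for Catalyst attributes and data source filters.
case class Filter(name: String)
case class Attribute(name: String, exprId: Long)

// Old style: equality hinges on a rendered string map. Anything the map omits
// (or renders identically) cannot participate in the comparison.
case class OldScan(metadata: Map[String, String]) {
  def sameResult(other: OldScan): Boolean = metadata == other.metadata
}

// New style: equality compares the structured constructor parameters after
// normalizing the parts that should not matter (here, expression ids).
case class NewScan(
    fullOutput: Seq[Attribute],
    requiredColumnsIndex: Seq[Int],
    filters: Set[Filter],
    handledFilters: Set[Filter]) {
  private def canonicalized: NewScan =
    copy(fullOutput = fullOutput.map(_.copy(exprId = -1)))
  def sameResult(other: NewScan): Boolean = canonicalized == other.canonicalized
}

object EqualityDemo extends App {
  // Two scans whose rendered metadata happens to coincide compare equal,
  // even though nothing ties the map to what the scans actually read.
  val a = OldScan(Map("PushedFilters" -> "[]", "ReadSchema" -> "struct<id:int>"))
  val b = OldScan(Map("PushedFilters" -> "[]", "ReadSchema" -> "struct<id:int>"))
  println(a.sameResult(b)) // true

  // Structural comparison: only expression ids differ, and they are normalized away.
  val out = Seq(Attribute("id", 1L), Attribute("name", 2L))
  val c = NewScan(out, Seq(0), Set.empty, Set.empty)
  val d = NewScan(out.map(_.copy(exprId = 99L)), Seq(0), Set.empty, Set.empty)
  println(c.sameResult(d)) // true
}
```

In the string-map scheme, anything the map under-specifies silently drops out of the comparison, which is how stale or incomplete metadata allowed incorrect exchange reuse; comparing structured fields after normalization closes that gap.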

File tree

5 files changed: 56 additions & 80 deletions


sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 32 additions & 33 deletions
@@ -33,21 +33,23 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.BaseRelation
+import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
 trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   val relation: BaseRelation
-  val metastoreTableIdentifier: Option[TableIdentifier]
+  val tableIdentifier: Option[TableIdentifier]
 
   protected val nodeNamePrefix: String = ""
 
   override val nodeName: String = {
-    s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}"
+    s"Scan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}"
   }
 
+  // Metadata that describes more details of this scan.
+  protected def metadata: Map[String, String]
+
   override def simpleString: String = {
     val metadataEntries = metadata.toSeq.sorted.map {
       case (key, value) =>
@@ -73,34 +75,25 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
 
 /** Physical plan node for scanning data from a relation. */
 case class RowDataSourceScanExec(
-    output: Seq[Attribute],
+    fullOutput: Seq[Attribute],
+    requiredColumnsIndex: Seq[Int],
+    filters: Set[Filter],
+    handledFilters: Set[Filter],
     rdd: RDD[InternalRow],
     @transient relation: BaseRelation,
-    override val outputPartitioning: Partitioning,
-    override val metadata: Map[String, String],
-    override val metastoreTableIdentifier: Option[TableIdentifier])
+    override val tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec {
 
+  def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput)
+
   override lazy val metrics =
     Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  val outputUnsafeRows = relation match {
-    case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
-      !SparkSession.getActiveSession.get.sessionState.conf.getConf(
-        SQLConf.PARQUET_VECTORIZED_READER_ENABLED)
-    case _: HadoopFsRelation => true
-    case _ => false
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
-    val unsafeRow = if (outputUnsafeRows) {
-      rdd
-    } else {
-      rdd.mapPartitionsWithIndexInternal { (index, iter) =>
-        val proj = UnsafeProjection.create(schema)
-        proj.initialize(index)
-        iter.map(proj)
-      }
+    val unsafeRow = rdd.mapPartitionsWithIndexInternal { (index, iter) =>
+      val proj = UnsafeProjection.create(schema)
+      proj.initialize(index)
+      iter.map(proj)
     }
 
     val numOutputRows = longMetric("numOutputRows")
@@ -126,24 +119,31 @@ case class RowDataSourceScanExec(
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columnsRowInput = exprRows.map(_.genCode(ctx))
-    val inputRow = if (outputUnsafeRows) row else null
     s"""
        |while ($input.hasNext()) {
        |  InternalRow $row = (InternalRow) $input.next();
       |  $numOutputRows.add(1);
-       |  ${consume(ctx, columnsRowInput, inputRow).trim}
+       |  ${consume(ctx, columnsRowInput, null).trim}
        |  if (shouldStop()) return;
        |}
      """.stripMargin
   }
 
-  // Only care about `relation` and `metadata` when canonicalizing.
+  override val metadata: Map[String, String] = {
+    val markedFilters = for (filter <- filters) yield {
+      if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
+    }
+    Map(
+      "ReadSchema" -> output.toStructType.catalogString,
+      "PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
+  }
+
+  // Don't care about `rdd` and `tableIdentifier` when canonicalizing.
   override lazy val canonicalized: SparkPlan =
     copy(
-      output.map(QueryPlan.normalizeExprId(_, output)),
+      fullOutput.map(QueryPlan.normalizeExprId(_, fullOutput)),
       rdd = null,
-      outputPartitioning = null,
-      metastoreTableIdentifier = None)
+      tableIdentifier = None)
 }
 
 /**
@@ -154,15 +154,15 @@ case class RowDataSourceScanExec(
  * @param requiredSchema Required schema of the underlying relation, excluding partition columns.
  * @param partitionFilters Predicates to use for partition pruning.
  * @param dataFilters Filters on non-partition columns.
- * @param metastoreTableIdentifier identifier for the table in the metastore.
+ * @param tableIdentifier identifier for the table in the metastore.
  */
 case class FileSourceScanExec(
     @transient relation: HadoopFsRelation,
     output: Seq[Attribute],
     requiredSchema: StructType,
     partitionFilters: Seq[Expression],
     dataFilters: Seq[Expression],
-    override val metastoreTableIdentifier: Option[TableIdentifier])
+    override val tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec with ColumnarBatchScan {
 
   val supportsBatch: Boolean = relation.fileFormat.supportBatch(
@@ -261,7 +261,6 @@ case class FileSourceScanExec(
   private val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
   logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")
 
-  // These metadata values make scan plans uniquely identifiable for equality checking.
   override val metadata: Map[String, String] = {
     def seqToString(seq: Seq[Any]) = seq.mkString("[", ", ", "]")
     val location = relation.location
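
The core of the `RowDataSourceScanExec` change above is that `output` and the displayed metadata are now derived from structured fields (`fullOutput`, `requiredColumnsIndex`, `filters`, `handledFilters`) rather than stored as a pre-rendered string map. A rough sketch of that derivation, using simplified stand-in types rather than the real Catalyst classes:

```scala
// Simplified stand-ins for the Catalyst attribute and data source filter types.
case class Attribute(name: String)
case class Filter(repr: String) { override def toString: String = repr }

object OutputDerivationDemo extends App {
  val fullOutput = Seq(Attribute("id"), Attribute("name"), Attribute("age"))
  val requiredColumnsIndex = Seq(0, 2) // only `id` and `age` are requested
  val filters: Set[Filter] = Set(Filter("IsNotNull(age)"), Filter("GreaterThan(age,21)"))
  val handledFilters: Set[Filter] = Set(Filter("GreaterThan(age,21)"))

  // output is computed from index positions into fullOutput, mirroring
  // `def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput)` in the diff.
  val output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput)

  // The metadata map becomes display-only; filters handled by the source get a '*'.
  val markedFilters = for (f <- filters) yield {
    if (handledFilters.contains(f)) s"*$f" else s"$f"
  }
  val metadata = Map("PushedFilters" -> markedFilters.mkString("[", ", ", "]"))

  println(output)   // List(Attribute(id), Attribute(age))
  println(metadata) // e.g. Map(PushedFilters -> [IsNotNull(age), *GreaterThan(age,21)])
}
```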

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 0 additions & 5 deletions
@@ -71,11 +71,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     super.makeCopy(newArgs)
   }
 
-  /**
-   * @return Metadata that describes more details of this SparkPlan.
-   */
-  def metadata: Map[String, String] = Map.empty
-
   /**
    * @return All metrics containing metrics of this SparkPlan.
    */

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala

Lines changed: 1 addition & 3 deletions
@@ -31,7 +31,6 @@ class SparkPlanInfo(
     val nodeName: String,
     val simpleString: String,
     val children: Seq[SparkPlanInfo],
-    val metadata: Map[String, String],
     val metrics: Seq[SQLMetricInfo]) {
 
   override def hashCode(): Int = {
@@ -58,7 +57,6 @@ private[execution] object SparkPlanInfo {
       new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
     }
 
-    new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
-      plan.metadata, metrics)
+    new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan), metrics)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 21 additions & 36 deletions
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.concurrent.Callable
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
@@ -288,10 +286,11 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
     case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
       RowDataSourceScanExec(
         l.output,
+        l.output.indices,
+        Set.empty,
+        Set.empty,
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
-        UnknownPartitioning(0),
-        Map.empty,
         None) :: Nil
 
     case _ => Nil
@@ -354,36 +353,10 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
     val (unhandledPredicates, pushedFilters, handledFilters) =
       selectFilters(relation.relation, candidatePredicates)
 
-    // A set of column attributes that are only referenced by pushed down filters. We can eliminate
-    // them from requested columns.
-    val handledSet = {
-      val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
-      val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
-      AttributeSet(handledPredicates.flatMap(_.references)) --
-        (projectSet ++ unhandledSet).map(relation.attributeMap)
-    }
-
     // Combines all Catalyst filter `Expression`s that are either not convertible to data source
     // `Filter`s or cannot be handled by `relation`.
     val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And)
 
-    // These metadata values make scan plans uniquely identifiable for equality checking.
-    // TODO(SPARK-17701) using strings for equality checking is brittle
-    val metadata: Map[String, String] = {
-      val pairs = ArrayBuffer.empty[(String, String)]
-
-      // Mark filters which are handled by the underlying DataSource with an Astrisk
-      if (pushedFilters.nonEmpty) {
-        val markedFilters = for (filter <- pushedFilters) yield {
-          if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
-        }
-        pairs += ("PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
-      }
-      pairs += ("ReadSchema" ->
-        StructType.fromAttributes(projects.map(_.toAttribute)).catalogString)
-      pairs.toMap
-    }
-
     if (projects.map(_.toAttribute) == projects &&
       projectSet.size == projects.size &&
       filterSet.subsetOf(projectSet)) {
@@ -395,24 +368,36 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
         .asInstanceOf[Seq[Attribute]]
         // Match original case of attributes.
         .map(relation.attributeMap)
-        // Don't request columns that are only referenced by pushed filters.
-        .filterNot(handledSet.contains)
 
       val scan = RowDataSourceScanExec(
-        projects.map(_.toAttribute),
+        relation.output,
+        requestedColumns.map(relation.output.indexOf),
+        pushedFilters.toSet,
+        handledFilters,
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
-        relation.relation, UnknownPartitioning(0), metadata,
+        relation.relation,
         relation.catalogTable.map(_.identifier))
       filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
     } else {
+      // A set of column attributes that are only referenced by pushed down filters. We can
+      // eliminate them from requested columns.
+      val handledSet = {
+        val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
+        val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
+        AttributeSet(handledPredicates.flatMap(_.references)) --
+          (projectSet ++ unhandledSet).map(relation.attributeMap)
+      }
      // Don't request columns that are only referenced by pushed filters.
      val requestedColumns =
        (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq
 
      val scan = RowDataSourceScanExec(
-        requestedColumns,
+        relation.output,
+        requestedColumns.map(relation.output.indexOf),
+        pushedFilters.toSet,
+        handledFilters,
        scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
-        relation.relation, UnknownPartitioning(0), metadata,
+        relation.relation,
        relation.catalogTable.map(_.identifier))
      execution.ProjectExec(
        projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
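
On the strategy side, the scan now receives the relation's full output plus the positions of the requested columns, instead of a pre-pruned attribute list and a pre-rendered metadata map. A small sketch of the index computation, with plain strings standing in for Catalyst attributes:

```scala
object ColumnIndexDemo extends App {
  val relationOutput = Seq("id", "name", "age", "city")
  val requestedColumns = Seq("city", "id")

  // Mirrors `requestedColumns.map(relation.output.indexOf)` in the strategy:
  // the scan keeps index positions, so two scans over the same relation and
  // columns compare equal structurally, with no string rendering involved.
  val requiredColumnsIndex = requestedColumns.map(relationOutput.indexOf)

  println(requiredColumnsIndex)                     // List(3, 0)
  println(requiredColumnsIndex.map(relationOutput)) // List(city, id)
}
```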

sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala

Lines changed: 2 additions & 3 deletions
@@ -113,7 +113,7 @@ object SparkPlanGraph {
     }
     val node = new SparkPlanGraphNode(
       nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
-      planInfo.simpleString, planInfo.metadata, metrics)
+      planInfo.simpleString, metrics)
     if (subgraph == null) {
       nodes += node
     } else {
@@ -143,7 +143,6 @@ private[ui] class SparkPlanGraphNode(
     val id: Long,
     val name: String,
     val desc: String,
-    val metadata: Map[String, String],
     val metrics: Seq[SQLPlanMetric]) {
 
   def makeDotNode(metricsValue: Map[Long, String]): String = {
@@ -177,7 +176,7 @@ private[ui] class SparkPlanGraphCluster(
     desc: String,
     val nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
     metrics: Seq[SQLPlanMetric])
-  extends SparkPlanGraphNode(id, name, desc, Map.empty, metrics) {
+  extends SparkPlanGraphNode(id, name, desc, metrics) {
 
   override def makeDotNode(metricsValue: Map[Long, String]): String = {
     val duration = metrics.filter(_.name.startsWith(WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
