Commit 5678e68

[SPARK-27393][SQL] Show ReusedSubquery in the plan when the subquery is reused
## What changes were proposed in this pull request?

With this change, we can easily identify the plan difference when a subquery is reused.

When the reuse is enabled, the plan looks like

```
== Physical Plan ==
CollectLimit 1
+- *(1) Project [(Subquery subquery240 + ReusedSubquery Subquery subquery240) AS (scalarsubquery() + scalarsubquery())#253]
   :  :- Subquery subquery240
   :  :  +- *(2) HashAggregate(keys=[], functions=[avg(cast(key#13 as bigint))], output=[avg(key)#250])
   :  :     +- Exchange SinglePartition
   :  :        +- *(1) HashAggregate(keys=[], functions=[partial_avg(cast(key#13 as bigint))], output=[sum#256, count#257L])
   :  :           +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13]
   :  :              +- Scan[obj#12]
   :  +- ReusedSubquery Subquery subquery240
   +- *(1) SerializeFromObject
      +- Scan[obj#12]
```

When the reuse is disabled, the plan looks like

```
== Physical Plan ==
CollectLimit 1
+- *(1) Project [(Subquery subquery286 + Subquery subquery287) AS (scalarsubquery() + scalarsubquery())#299]
   :  :- Subquery subquery286
   :  :  +- *(2) HashAggregate(keys=[], functions=[avg(cast(key#13 as bigint))], output=[avg(key)#296])
   :  :     +- Exchange SinglePartition
   :  :        +- *(1) HashAggregate(keys=[], functions=[partial_avg(cast(key#13 as bigint))], output=[sum#302, count#303L])
   :  :           +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13]
   :  :              +- Scan[obj#12]
   :  +- Subquery subquery287
   :     +- *(2) HashAggregate(keys=[], functions=[avg(cast(key#13 as bigint))], output=[avg(key)#298])
   :        +- Exchange SinglePartition
   :           +- *(1) HashAggregate(keys=[], functions=[partial_avg(cast(key#13 as bigint))], output=[sum#306, count#307L])
   :              +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$TestData, true])).key AS key#13]
   :                 +- Scan[obj#12]
   +- *(1) SerializeFromObject
      +- Scan[obj#12]
```

## How was this patch tested?

Modified the existing test.

Closes #24258 from gatorsmile/followupSPARK-27279.

Authored-by: gatorsmile <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
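For reference, the plans above can be reproduced with a short snippet like the following (a sketch for `spark-shell`; it assumes `testData` is registered as a temporary view with an integer `key` column, as in Spark's `SQLTestData`):

```
import org.apache.spark.sql.internal.SQLConf

// Toggle subquery reuse; set to "false" to see two independent Subquery nodes.
spark.conf.set(SQLConf.SUBQUERY_REUSE_ENABLED.key, "true")

val df = spark.sql(
  """
    |SELECT (SELECT avg(key) FROM testData) + (SELECT avg(key) FROM testData)
    |FROM testData
    |LIMIT 1
  """.stripMargin)

// With reuse enabled, the second scalar subquery is printed as
// `ReusedSubquery Subquery ...`; with it disabled, both subqueries
// are planned and executed separately.
df.explain()
```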
1 parent 568db94 commit 5678e68

8 files changed: +98, -46 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 4 additions & 1 deletion
```
@@ -37,6 +37,9 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
   /** Updates the expression with a new plan. */
   def withNewPlan(plan: T): PlanExpression[T]

+  /** Defines how the canonicalization should work for this expression. */
+  def canonicalize(attrs: AttributeSeq): PlanExpression[T]
+
   protected def conditionString: String = children.mkString("[", " && ", "]")
 }

@@ -58,7 +61,7 @@ abstract class SubqueryExpression(
         children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
     case _ => false
   }
-  def canonicalize(attrs: AttributeSeq): SubqueryExpression = {
+  override def canonicalize(attrs: AttributeSeq): SubqueryExpression = {
     // Normalize the outer references in the subquery plan.
     val normalizedPlan = plan.transformAllExpressions {
       case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs))
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 1 addition & 1 deletion
```
@@ -279,7 +279,7 @@ object QueryPlan extends PredicateHelper {
   */
  def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = {
    e.transformUp {
-      case s: SubqueryExpression => s.canonicalize(input)
+      case s: PlanExpression[_] => s.canonicalize(input)
      case ar: AttributeReference =>
        val ordinal = input.indexOf(ar.exprId)
        if (ordinal == -1) {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala

Lines changed: 1 addition & 0 deletions
```
@@ -52,6 +52,7 @@ private[execution] object SparkPlanInfo {
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
     val children = plan match {
       case ReusedExchangeExec(_, child) => child :: Nil
+      case ReusedSubqueryExec(child) => child :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
```

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 38 additions & 7 deletions
```
@@ -671,21 +671,28 @@ object CoalesceExec {
 }

 /**
- * Physical plan for a subquery.
+ * Parent class for different types of subquery plans
  */
-case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
-
-  override lazy val metrics = Map(
-    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
-    "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"))
+abstract class BaseSubqueryExec extends SparkPlan {
+  def name: String
+  def child: SparkPlan

   override def output: Seq[Attribute] = child.output

   override def outputPartitioning: Partitioning = child.outputPartitioning

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}

-  override def doCanonicalize(): SparkPlan = child.canonicalized
+/**
+ * Physical plan for a subquery.
+ */
+case class SubqueryExec(name: String, child: SparkPlan)
+  extends BaseSubqueryExec with UnaryExecNode {
+
+  override lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"))

   @transient
   private lazy val relationFuture: Future[Array[InternalRow]] = {
@@ -709,6 +716,10 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
     }(SubqueryExec.executionContext)
   }

+  protected override def doCanonicalize(): SparkPlan = {
+    SubqueryExec("Subquery", child.canonicalized)
+  }
+
   protected override def doPrepare(): Unit = {
     relationFuture
   }
@@ -726,3 +737,23 @@ object SubqueryExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
 }
+
+/**
+ * A wrapper for reused [[BaseSubqueryExec]].
+ */
+case class ReusedSubqueryExec(child: BaseSubqueryExec)
+  extends BaseSubqueryExec with LeafExecNode {
+
+  override def name: String = child.name
+
+  override def output: Seq[Attribute] = child.output
+  override def doCanonicalize(): SparkPlan = child.canonicalized
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  protected override def doPrepare(): Unit = child.prepare()
+
+  protected override def doExecute(): RDD[InternalRow] = child.execute()
+
+  override def executeCollect(): Array[InternalRow] = child.executeCollect()
+}
```
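With this change, `SubqueryExec.doCanonicalize()` rebuilds the node with the fixed name `"Subquery"` on top of the canonicalized child, so subqueries that differ only in their auto-generated names (for example `subquery286` and `subquery287` in the plans above) still canonicalize to the same plan and can be matched by `sameResult`. A toy sketch of that idea, using simplified stand-in types rather than Spark's actual classes:

```
object CanonicalizeSketch extends App {
  // Stand-in for a physical subquery node: `name` is auto-generated per plan,
  // `child` stands for the (already canonicalized) child plan.
  case class Node(name: String, child: String) {
    // Mirror of SubqueryExec.doCanonicalize: replace the generated name with a
    // fixed placeholder so only the child matters for comparison.
    def canonicalized: Node = Node("Subquery", child)
    def sameResult(other: Node): Boolean = canonicalized == other.canonicalized
  }

  val a = Node("subquery286", "Aggregate(avg(key))")
  val b = Node("subquery287", "Aggregate(avg(key))")
  println(a.sameResult(b)) // true: only the generated names differ
}
```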

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 13 additions & 7 deletions
```
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, InSet, Literal, PlanExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
@@ -31,11 +31,16 @@ import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 /**
  * The base class for subquery that is used in SparkPlan.
  */
-abstract class ExecSubqueryExpression extends PlanExpression[SubqueryExec] {
+abstract class ExecSubqueryExpression extends PlanExpression[BaseSubqueryExec] {
   /**
    * Fill the expression with collected result from executed plan.
    */
   def updateResult(): Unit
+
+  override def canonicalize(attrs: AttributeSeq): ExecSubqueryExpression = {
+    withNewPlan(plan.canonicalized.asInstanceOf[BaseSubqueryExec])
+      .asInstanceOf[ExecSubqueryExpression]
+  }
 }

 object ExecSubqueryExpression {
@@ -56,15 +61,15 @@ object ExecSubqueryExpression {
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
  */
 case class ScalarSubquery(
-    plan: SubqueryExec,
+    plan: BaseSubqueryExec,
     exprId: ExprId)
   extends ExecSubqueryExpression {

   override def dataType: DataType = plan.schema.fields.head.dataType
   override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
   override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields)
-  override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query)
+  override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query)

   override def semanticEquals(other: Expression): Boolean = other match {
     case s: ScalarSubquery => plan.sameResult(s.plan)
@@ -129,13 +134,14 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
       return plan
     }
     // Build a hash map using schema of subqueries to avoid O(N*N) sameResult calls.
-    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
+    val subqueries = mutable.HashMap[StructType, ArrayBuffer[BaseSubqueryExec]]()
     plan transformAllExpressions {
       case sub: ExecSubqueryExpression =>
-        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
+        val sameSchema =
+          subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[BaseSubqueryExec]())
         val sameResult = sameSchema.find(_.sameResult(sub.plan))
         if (sameResult.isDefined) {
-          sub.withNewPlan(sameResult.get)
+          sub.withNewPlan(ReusedSubqueryExec(sameResult.get))
         } else {
           sameSchema += sub.plan
           sub
```
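The `ReuseSubquery` rule keeps its existing strategy of bucketing candidate plans by schema so that the relatively expensive `sameResult` check only runs within a bucket; the visible change in this diff is that a duplicate is now wrapped in `ReusedSubqueryExec` instead of pointing directly at the first plan. A self-contained sketch of that bucketing pattern (simplified `Plan`/`Schema` stand-ins, not Spark's types):

```
import scala.collection.mutable

object ReuseSketch extends App {
  case class Schema(fields: Seq[String])
  // `fingerprint` stands in for a sameResult-style semantic comparison.
  case class Plan(schema: Schema, fingerprint: String) {
    def sameResult(other: Plan): Boolean = fingerprint == other.fingerprint
  }

  sealed trait Planned
  case class Original(plan: Plan) extends Planned
  case class Reused(original: Plan) extends Planned // plays the role of ReusedSubqueryExec

  def reuse(plans: Seq[Plan]): Seq[Planned] = {
    // Bucket candidates by schema so sameResult only runs within a bucket,
    // avoiding O(N*N) comparisons across all subqueries in the plan.
    val buckets = mutable.HashMap[Schema, mutable.ArrayBuffer[Plan]]()
    plans.map { p =>
      val sameSchema = buckets.getOrElseUpdate(p.schema, mutable.ArrayBuffer[Plan]())
      sameSchema.find(_.sameResult(p)) match {
        case Some(first) => Reused(first) // duplicate: point back at the first occurrence
        case None =>
          sameSchema += p
          Original(p)
      }
    }
  }

  val s = Schema(Seq("avg(key)"))
  println(reuse(Seq(Plan(s, "avg"), Plan(s, "avg"), Plan(s, "sum(key)"))))
}
```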

sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala

Lines changed: 5 additions & 0 deletions
```
@@ -103,6 +103,11 @@ object SparkPlanGraph {
         // Point to the re-used subquery
         val node = exchanges(planInfo)
         edges += SparkPlanGraphEdge(node.id, parent.id)
+      case "ReusedSubquery" =>
+        // Re-used subquery might appear before the original subquery, so skip this node and let
+        // the previous `case` make sure the re-used and the original point to the same node.
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, subgraph, exchanges)
       case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
         // Point to the re-used exchange
         val node = exchanges(planInfo.children.head)
```

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 1 addition & 29 deletions
```
@@ -25,7 +25,6 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.execution.{aggregate, ScalarSubquery, SubqueryExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -113,33 +112,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }

-  test("Reuse Subquery") {
-    Seq(true, false).foreach { reuse =>
-      withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) {
-        val df = sql(
-          """
-            |SELECT (SELECT avg(key) FROM testData) + (SELECT avg(key) FROM testData)
-            |FROM testData
-            |LIMIT 1
-          """.stripMargin)
-
-        import scala.collection.mutable.ArrayBuffer
-        val subqueries = ArrayBuffer[SubqueryExec]()
-        df.queryExecution.executedPlan.transformAllExpressions {
-          case s @ ScalarSubquery(plan: SubqueryExec, _) =>
-            subqueries += plan
-            s
-        }
-
-        if (reuse) {
-          assert(subqueries.distinct.size == 1, "Subquery reusing not working correctly")
-        } else {
-          assert(subqueries.distinct.size == 2, "There should be 2 subqueries when not reusing")
-        }
-      }
-    }
-  }
-
   test("SPARK-6743: no columns from cache") {
     Seq(
       (83, 0, 38),
@@ -288,7 +260,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       val df = sql(sqlText)
       // First, check if we have GeneratedAggregate.
       val hasGeneratedAgg = df.queryExecution.sparkPlan
-        .collect { case _: aggregate.HashAggregateExec => true }
+        .collect { case _: HashAggregateExec => true }
         .nonEmpty
       if (!hasGeneratedAgg) {
         fail(
```

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 35 additions & 1 deletion
```
@@ -21,8 +21,9 @@ import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
-import org.apache.spark.sql.execution.{ExecSubqueryExpression, FileSourceScanExec, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{ExecSubqueryExpression, FileSourceScanExec, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.datasources.FileScanRDD
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext

 class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -1337,4 +1338,37 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
     }
   }
+
+  test("SPARK-27279: Reuse Subquery") {
+    Seq(true, false).foreach { reuse =>
+      withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) {
+        val df = sql(
+          """
+            |SELECT (SELECT avg(key) FROM testData) + (SELECT avg(key) FROM testData)
+            |FROM testData
+            |LIMIT 1
+          """.stripMargin)
+
+        var countSubqueryExec = 0
+        var countReuseSubqueryExec = 0
+        df.queryExecution.executedPlan.transformAllExpressions {
+          case s @ ScalarSubquery(_: SubqueryExec, _) =>
+            countSubqueryExec = countSubqueryExec + 1
+            s
+          case s @ ScalarSubquery(_: ReusedSubqueryExec, _) =>
+            countReuseSubqueryExec = countReuseSubqueryExec + 1
+            s
+        }
+
+        if (reuse) {
+          assert(countSubqueryExec == 1, "Subquery reusing not working correctly")
+          assert(countReuseSubqueryExec == 1, "Subquery reusing not working correctly")
+        } else {
+          assert(countSubqueryExec == 2, "expect 2 SubqueryExec when not reusing")
+          assert(countReuseSubqueryExec == 0,
+            "expect 0 ReusedSubqueryExec when not reusing")
+        }
+      }
+    }
+  }
 }
```
