
Commit b3c7f3f

carsonwang authored and Justin Uang committed
Add QueryStage and the framework for adaptive execution
1 parent b50649d commit b3c7f3f

File tree

16 files changed: +594 -410 lines changed

core/src/main/scala/org/apache/spark/MapOutputStatistics.scala

Lines changed: 1 addition & 0 deletions
@@ -25,3 +25,4 @@ package org.apache.spark
  * (may be inexact due to use of compressed map statuses)
  */
 private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long])
+  extends Serializable

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 9 deletions
@@ -280,14 +280,19 @@ object SQLConf {
 
   val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS =
     buildConf("spark.sql.adaptive.minNumPostShufflePartitions")
-      .internal()
-      .doc("The advisory minimal number of post-shuffle partitions provided to " +
-        "ExchangeCoordinator. This setting is used in our test to make sure we " +
-        "have enough parallelism to expose issues that will not be exposed with a " +
-        "single partition. When the value is a non-positive value, this setting will " +
-        "not be provided to ExchangeCoordinator.")
+      .doc("The advisory minimum number of post-shuffle partitions used in adaptive execution.")
+      .intConf
+      .checkValue(numPartitions => numPartitions > 0, "The minimum shuffle partition number " +
+        "must be a positive integer.")
+      .createWithDefault(1)
+
+  val SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS =
+    buildConf("spark.sql.adaptive.maxNumPostShufflePartitions")
+      .doc("The advisory maximum number of post-shuffle partitions used in adaptive execution.")
       .intConf
-      .createWithDefault(-1)
+      .checkValue(numPartitions => numPartitions > 0, "The maximum shuffle partition number " +
+        "must be a positive integer.")
+      .createWithDefault(500)
 
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
@@ -1698,8 +1703,9 @@ class SQLConf extends Serializable with Logging {
 
   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
 
-  def minNumPostShufflePartitions: Int =
-    getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+  def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+
+  def maxNumPostShufflePartitions: Int = getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS)
 
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)
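
As a usage note (not part of this diff): the two new keys bound the partition count that adaptive execution may pick after a shuffle. A minimal sketch of setting them on a session, assuming a local SparkSession and purely illustrative values:

// Hypothetical usage sketch: bound the post-shuffle partition search space.
// The config keys come from this commit; the chosen values are illustrative.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("adaptive-demo")
  .master("local[*]")
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.sql.adaptive.minNumPostShufflePartitions", "4")
  .config("spark.sql.adaptive.maxNumPostShufflePartitions", "200")
  .getOrCreate()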

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 15 additions & 1 deletion
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.adaptive.PlanQueryStage
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
@@ -84,7 +85,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
    * row format conversions as needed.
    */
   protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
-    preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
+      adaptivePreparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    } else {
+      preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    }
   }
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
@@ -95,6 +100,15 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
 
+  protected def adaptivePreparations: Seq[Rule[SparkPlan]] = Seq(
+    PlanSubqueries(sparkSession),
+    EnsureRequirements(sparkSession.sessionState.conf),
+    ReuseSubquery(sparkSession.sessionState.conf),
+    // PlanQueryStage needs to be the last rule because it divides the plan into multiple
+    // sub-trees by inserting leaf node QueryStageInput. Transforming the plan after applying
+    // this rule will only transform nodes in a sub-tree.
+    PlanQueryStage(sparkSession.sessionState.conf))
+
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: AnalysisException => e.toString }
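
The fold in prepareForExecution threads the plan through each rule in order, so the output of one rule is the input of the next; that ordering is why PlanQueryStage must come last. A self-contained sketch of the pattern with stand-in types (Rule and Plan here are toys, not Spark's classes):

// Each rule rewrites the whole plan; rules compose left to right via foldLeft.
trait Rule[T] { def apply(plan: T): T }
case class Plan(desc: String) // stand-in for SparkPlan

object AddExchange extends Rule[Plan] {
  def apply(p: Plan): Plan = Plan(s"Exchange(${p.desc})")
}
object WrapStage extends Rule[Plan] {
  def apply(p: Plan): Plan = Plan(s"QueryStage(${p.desc})")
}

val rules: Seq[Rule[Plan]] = Seq(AddExchange, WrapStage)
val prepared = rules.foldLeft(Plan("Scan")) { case (sp, rule) => rule.apply(sp) }
// prepared == Plan("QueryStage(Exchange(Scan))")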

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.adaptive.QueryStageInput
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
 
@@ -51,6 +52,7 @@ private[execution] object SparkPlanInfo {
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
     val children = plan match {
       case ReusedExchangeExec(_, child) => child :: Nil
+      case i: QueryStageInput => i.childStage :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
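
The extra match arm exists because QueryStageInput is a leaf in the main plan tree; without it the UI tree would stop at the input node and never show the stage's sub-plan. A toy sketch of that unwrapping idea (stand-in types, not Spark's):

// Wrappers report no children, so the traversal must surface the hidden plan.
sealed trait Node { def children: Seq[Node] }
case class Leaf(name: String) extends Node { def children: Seq[Node] = Nil }
case class StageInput(hidden: Node) extends Node { def children: Seq[Node] = Nil }

def visibleChildren(n: Node): Seq[Node] = n match {
  case s: StageInput => s.hidden :: Nil // surface the wrapped stage for display
  case other => other.children
}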
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Divides the Spark plan into multiple QueryStages. For each Exchange in the plan, it adds a
+ * QueryStage and a QueryStageInput. If exchange reuse is enabled, it finds duplicated exchanges
+ * and uses the same QueryStage for all the references.
+ */
+case class PlanQueryStage(conf: SQLConf) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+
+    val newPlan = if (!conf.exchangeReuseEnabled) {
+      plan.transformUp {
+        case e: ShuffleExchangeExec =>
+          ShuffleQueryStageInput(ShuffleQueryStage(e), e.output)
+        case e: BroadcastExchangeExec =>
+          BroadcastQueryStageInput(BroadcastQueryStage(e), e.output)
+      }
+    } else {
+      // Build a hash map using the schema of exchanges to avoid O(N*N) sameResult calls.
+      val stages = mutable.HashMap[StructType, ArrayBuffer[QueryStage]]()
+
+      plan.transformUp {
+        case exchange: Exchange =>
+          val sameSchema = stages.getOrElseUpdate(exchange.schema, ArrayBuffer[QueryStage]())
+          val samePlan = sameSchema.find { s =>
+            exchange.sameResult(s.child)
+          }
+          if (samePlan.isDefined) {
+            // Keep the output of this exchange; the following plans require it to resolve
+            // attributes.
+            exchange match {
+              case e: ShuffleExchangeExec => ShuffleQueryStageInput(
+                samePlan.get.asInstanceOf[ShuffleQueryStage], exchange.output)
+              case e: BroadcastExchangeExec => BroadcastQueryStageInput(
+                samePlan.get.asInstanceOf[BroadcastQueryStage], exchange.output)
+            }
+          } else {
+            val queryStageInput = exchange match {
+              case e: ShuffleExchangeExec =>
+                ShuffleQueryStageInput(ShuffleQueryStage(e), e.output)
+              case e: BroadcastExchangeExec =>
+                BroadcastQueryStageInput(BroadcastQueryStage(e), e.output)
+            }
+            sameSchema += queryStageInput.childStage
+            queryStageInput
+          }
+      }
+    }
+    ResultQueryStage(newPlan)
+  }
+}
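
The reuse branch avoids an O(N*N) pass of sameResult comparisons by bucketing exchanges on their schema first and only running the expensive check within a bucket. A self-contained sketch of that lookup, with String standing in for StructType and plain equality standing in for sameResult:

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

// Group candidates by a cheap hashable key, then do the costly comparison
// only inside the matching bucket; reuse the stored stage when one matches.
case class Stage(schema: String, plan: String)
val stages = mutable.HashMap[String, ArrayBuffer[Stage]]()

def stageFor(schema: String, plan: String): Stage = {
  val bucket = stages.getOrElseUpdate(schema, ArrayBuffer[Stage]())
  bucket.find(_.plan == plan).getOrElse {
    val fresh = Stage(schema, plan)
    bucket += fresh // remember for later reuse
    fresh
  }
}

// Two identical "exchanges" resolve to the same stage instance.
assert(stageFor("int,string", "scan t1") eq stageFor("int,string", "scan t1"))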
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.MapOutputStatistics
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * In adaptive execution mode, an execution plan is divided into multiple QueryStages. Each
+ * QueryStage is a sub-tree that runs in a single stage.
+ */
+abstract class QueryStage extends UnaryExecNode {
+
+  var child: SparkPlan
+
+  // Ignore this wrapper for canonicalizing.
+  override def doCanonicalize(): SparkPlan = child.canonicalized
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  /**
+   * Execute child stages and wait until all stages are completed. Use a thread pool to avoid
+   * blocking on one child stage.
+   */
+  def executeChildStages(): Unit = {
+    // Handle broadcast stages
+    val broadcastQueryStages: Seq[BroadcastQueryStage] = child.collect {
+      case bqs: BroadcastQueryStageInput => bqs.childStage
+    }
+    val broadcastFutures = broadcastQueryStages.map { queryStage =>
+      Future { queryStage.prepareBroadcast() }(QueryStage.executionContext)
+    }
+
+    // Submit shuffle stages
+    val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    val shuffleQueryStages: Seq[ShuffleQueryStage] = child.collect {
+      case sqs: ShuffleQueryStageInput => sqs.childStage
+    }
+    val shuffleStageFutures = shuffleQueryStages.map { queryStage =>
+      Future {
+        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+          queryStage.execute()
+        }
+      }(QueryStage.executionContext)
+    }
+
+    ThreadUtils.awaitResult(
+      Future.sequence(broadcastFutures)(implicitly, QueryStage.executionContext), Duration.Inf)
+    ThreadUtils.awaitResult(
+      Future.sequence(shuffleStageFutures)(implicitly, QueryStage.executionContext), Duration.Inf)
+  }
+
+  /**
+   * Before executing the plan in this query stage, we execute all child stages, optimize the plan
+   * in this stage and determine the reducer number based on the child stages' statistics. Finally
+   * we do codegen for this query stage and update the UI with the new plan.
+   */
+  def prepareExecuteStage(): Unit = {
+    // 1. Execute child stages
+    executeChildStages()
+    // It is possible to optimize this stage's plan here based on the child stages' statistics.
+
+    // 2. Determine reducer number
+    val queryStageInputs: Seq[ShuffleQueryStageInput] = child.collect {
+      case input: ShuffleQueryStageInput => input
+    }
+    val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics)
+      .filter(_ != null).toArray
+    if (childMapOutputStatistics.length > 0) {
+      val exchangeCoordinator = new ExchangeCoordinator(
+        conf.targetPostShuffleInputSize,
+        conf.minNumPostShufflePartitions)
+
+      val partitionStartIndices =
+        exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics)
+      child = child.transform {
+        case ShuffleQueryStageInput(childStage, output, _) =>
+          ShuffleQueryStageInput(childStage, output, Some(partitionStartIndices))
+      }
+    }
+
+    // 3. Codegen and update the UI
+    child = CollapseCodegenStages(sqlContext.conf).apply(child)
+    val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    if (executionId != null && executionId.nonEmpty) {
+      val queryExecution = SQLExecution.getQueryExecution(executionId.toLong)
+      sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate(
+        executionId.toLong,
+        queryExecution.toString,
+        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)))
+    }
+  }
+
+  // Caches the created ShuffleRowRDD so we can reuse it.
+  private var cachedRDD: RDD[InternalRow] = null
+
+  def executeStage(): RDD[InternalRow] = child.execute()
+
+  /**
+   * A QueryStage can be reused like Exchange. It is possible that multiple threads try to submit
+   * the same QueryStage. Use synchronized to make sure it is executed only once.
+   */
+  override def doExecute(): RDD[InternalRow] = synchronized {
+    if (cachedRDD == null) {
+      prepareExecuteStage()
+      cachedRDD = executeStage()
+    }
+    cachedRDD
+  }
+
+  override def executeCollect(): Array[InternalRow] = {
+    prepareExecuteStage()
+    child.executeCollect()
+  }
+
+  override def executeToIterator(): Iterator[InternalRow] = {
+    prepareExecuteStage()
+    child.executeToIterator()
+  }
+
+  override def executeTake(n: Int): Array[InternalRow] = {
+    prepareExecuteStage()
+    child.executeTake(n)
+  }
+
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      verbose: Boolean,
+      prefix: String = "",
+      addSuffix: Boolean = false): StringBuilder = {
+    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+  }
+}
+
+/**
+ * The last QueryStage of an execution plan.
+ */
+case class ResultQueryStage(var child: SparkPlan) extends QueryStage
+
+/**
+ * A shuffle QueryStage whose child is a ShuffleExchangeExec.
+ */
+case class ShuffleQueryStage(var child: SparkPlan) extends QueryStage {
+
+  protected var _mapOutputStatistics: MapOutputStatistics = null
+
+  def mapOutputStatistics: MapOutputStatistics = _mapOutputStatistics
+
+  override def executeStage(): RDD[InternalRow] = {
+    child match {
+      case e: ShuffleExchangeExec =>
+        val result = e.eagerExecute()
+        _mapOutputStatistics = e.mapOutputStatistics
+        result
+      case _ => throw new IllegalArgumentException(
+        "The child of ShuffleQueryStage must be a ShuffleExchangeExec.")
+    }
+  }
+}
+
+/**
+ * A broadcast QueryStage whose child is a BroadcastExchangeExec.
+ */
+case class BroadcastQueryStage(var child: SparkPlan) extends QueryStage {
+  override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    child.executeBroadcast()
+  }
+
+  private var prepared = false
+
+  def prepareBroadcast(): Unit = synchronized {
+    if (!prepared) {
+      executeChildStages()
+      child = CollapseCodegenStages(sqlContext.conf).apply(child)
+      // After child stages are completed, prepare() triggers the broadcast.
+      prepare()
+      prepared = true
+    }
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException(
+      "BroadcastExchange does not support the execute() code path.")
+  }
+}
+
+object QueryStage {
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("adaptive-query-stage"))
+}
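
executeChildStages fans each child stage out on a shared pool and then blocks until all futures finish, so independent stages run concurrently. A self-contained sketch of that fan-out/fan-in pattern, using a plain cached thread pool in place of ThreadUtils.newDaemonCachedThreadPool("adaptive-query-stage"):

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Submit every child "stage" on the pool, then await them all before the
// parent stage proceeds -- mirroring the awaitResult calls above.
implicit val stagePool: ExecutionContext =
  ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

def runStage(id: Int): Int = { Thread.sleep(50); id } // stand-in stage work

val futures = (1 to 3).map(id => Future { runStage(id) })
val results = Await.result(Future.sequence(futures), Duration.Inf)
// results == Vector(1, 2, 3): all child stages completed before this point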
