
Commit 8a4a29b

Merge pull request apache-spark-on-k8s#443 from palantir/juang/cherry-pick-ae-02
[AE2.3-02][SPARK-23128] Add QueryStage and the framework for adaptive execution (auto-setting the number of reducers)
2 parents: 8bfddaf + f477a24

18 files changed: +750 −410 lines

core/src/main/scala/org/apache/spark/MapOutputStatistics.scala

Lines changed: 1 addition & 0 deletions

@@ -25,3 +25,4 @@ package org.apache.spark
  * (may be inexact due to use of compressed map statuses)
  */
 private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long])
+  extends Serializable

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 9 deletions

@@ -280,14 +280,19 @@ object SQLConf {
 
   val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS =
     buildConf("spark.sql.adaptive.minNumPostShufflePartitions")
-      .internal()
-      .doc("The advisory minimal number of post-shuffle partitions provided to " +
-        "ExchangeCoordinator. This setting is used in our test to make sure we " +
-        "have enough parallelism to expose issues that will not be exposed with a " +
-        "single partition. When the value is a non-positive value, this setting will " +
-        "not be provided to ExchangeCoordinator.")
+      .doc("The advisory minimum number of post-shuffle partitions used in adaptive execution.")
+      .intConf
+      .checkValue(numPartitions => numPartitions > 0, "The minimum shuffle partition number " +
+        "must be a positive integer.")
+      .createWithDefault(1)
+
+  val SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS =
+    buildConf("spark.sql.adaptive.maxNumPostShufflePartitions")
+      .doc("The advisory maximum number of post-shuffle partitions used in adaptive execution.")
       .intConf
-      .createWithDefault(-1)
+      .checkValue(numPartitions => numPartitions > 0, "The maximum shuffle partition number " +
+        "must be a positive integer.")
+      .createWithDefault(500)
 
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
@@ -1710,8 +1715,9 @@ class SQLConf extends Serializable with Logging {
 
   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
 
-  def minNumPostShufflePartitions: Int =
-    getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+  def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+
+  def maxNumPostShufflePartitions: Int = getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS)
 
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)
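Usage sketch (not part of the diff): both new entries are ordinary SQL confs, so they can be set at runtime on an existing SparkSession (assumed here to be named `spark`); per the checkValue guards above, each value must be a positive integer.

// Bound adaptive execution's post-shuffle partition count to [20, 200].
spark.conf.set("spark.sql.adaptive.minNumPostShufflePartitions", "20")
spark.conf.set("spark.sql.adaptive.maxNumPostShufflePartitions", "200")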

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 15 additions & 1 deletion

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.adaptive.PlanQueryStage
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
@@ -84,7 +85,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
    * row format conversions as needed.
    */
   protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
-    preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
+      adaptivePreparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    } else {
+      preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    }
   }
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
@@ -95,6 +100,15 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
 
+  protected def adaptivePreparations: Seq[Rule[SparkPlan]] = Seq(
+    PlanSubqueries(sparkSession),
+    EnsureRequirements(sparkSession.sessionState.conf),
+    ReuseSubquery(sparkSession.sessionState.conf),
+    // PlanQueryStage needs to be the last rule because it divides the plan into multiple
+    // sub-trees by inserting leaf-node QueryStageInputs. Transforming the plan after applying
+    // this rule will only transform nodes within a single sub-tree.
+    PlanQueryStage(sparkSession.sessionState.conf))
+
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: AnalysisException => e.toString }
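To exercise the new branch in prepareForExecution, adaptive execution must be switched on. A minimal sketch (not part of the diff), assuming ADAPTIVE_EXECUTION_ENABLED maps to the key "spark.sql.adaptive.enabled" and a build of this branch on the classpath:

import org.apache.spark.sql.SparkSession

// With the flag on, prepareForExecution applies adaptivePreparations
// (ending in PlanQueryStage) instead of the default preparations.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("adaptive-execution-sketch")
  .config("spark.sql.adaptive.enabled", "true")
  .getOrCreate()

// Any query with a shuffle works; PlanQueryStage wraps the final plan,
// so the executed plan's root should be a ResultQueryStage.
val df = spark.range(0, 1000).groupBy("id").count()
println(df.queryExecution.executedPlan)
df.collect()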

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.adaptive.QueryStageInput
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
 
@@ -51,6 +52,7 @@ private[execution] object SparkPlanInfo {
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
     val children = plan match {
       case ReusedExchangeExec(_, child) => child :: Nil
+      case i: QueryStageInput => i.childStage :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
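The new match arm mirrors the existing ReusedExchangeExec case: a QueryStageInput is a leaf in its own sub-tree, so the UI's plan-tree extraction descends into the stage it references instead of stopping there. A self-contained sketch of that traversal pattern, with hypothetical node types (not Spark's):

// Hypothetical tree types illustrating the special-cased child extraction.
sealed trait Node { def children: Seq[Node] }
case class Leaf(name: String) extends Node { def children: Seq[Node] = Nil }
case class Branch(name: String, children: Seq[Node]) extends Node
// A wrapper that is a leaf in its own tree but references another tree,
// playing the role of QueryStageInput.
case class StageRef(stage: Node) extends Node { def children: Seq[Node] = Nil }

// Descend through the wrapper so the displayed tree includes the referenced stage.
def displayChildren(node: Node): Seq[Node] = node match {
  case StageRef(stage) => stage :: Nil  // analogous to i.childStage :: Nil above
  case other => other.children
}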
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Divides the Spark plan into multiple QueryStages. For each Exchange in the plan, it adds a
+ * QueryStage and a QueryStageInput. If exchange reuse is enabled, it finds duplicated exchanges
+ * and uses the same QueryStage for all the references.
+ */
+case class PlanQueryStage(conf: SQLConf) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+
+    val newPlan = if (!conf.exchangeReuseEnabled) {
+      plan.transformUp {
+        case e: ShuffleExchangeExec =>
+          ShuffleQueryStageInput(ShuffleQueryStage(e), e.output)
+        case e: BroadcastExchangeExec =>
+          BroadcastQueryStageInput(BroadcastQueryStage(e), e.output)
+      }
+    } else {
+      // Build a hash map keyed on exchange schema to avoid O(N*N) sameResult calls.
+      val stages = mutable.HashMap[StructType, ArrayBuffer[QueryStage]]()
+
+      plan.transformUp {
+        case exchange: Exchange =>
+          val sameSchema = stages.getOrElseUpdate(exchange.schema, ArrayBuffer[QueryStage]())
+          val samePlan = sameSchema.find { s =>
+            exchange.sameResult(s.child)
+          }
+          if (samePlan.isDefined) {
+            // Keep the output of this exchange; the following plans require it to resolve
+            // attributes.
+            exchange match {
+              case e: ShuffleExchangeExec => ShuffleQueryStageInput(
+                samePlan.get.asInstanceOf[ShuffleQueryStage], exchange.output)
+              case e: BroadcastExchangeExec => BroadcastQueryStageInput(
+                samePlan.get.asInstanceOf[BroadcastQueryStage], exchange.output)
+            }
+          } else {
+            val queryStageInput = exchange match {
+              case e: ShuffleExchangeExec =>
+                ShuffleQueryStageInput(ShuffleQueryStage(e), e.output)
+              case e: BroadcastExchangeExec =>
+                BroadcastQueryStageInput(BroadcastQueryStage(e), e.output)
+            }
+            sameSchema += queryStageInput.childStage
+            queryStageInput
+          }
+      }
+    }
+    ResultQueryStage(newPlan)
+  }
+}
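The reuse branch avoids O(N*N) sameResult comparisons by bucketing exchanges on their schema (a cheap hash key) and running the expensive semantic check only within a bucket. A standalone sketch of the same pattern, with a hypothetical Item type standing in for SparkPlan:

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

// `schema` plays the role of Exchange.schema (hashable key); `sameResult`
// plays the role of SparkPlan.sameResult (expensive semantic equality).
case class Item(schema: String, payload: String) {
  def sameResult(other: Item): Boolean = payload == other.payload
}

// Map every item to its first equivalent occurrence, comparing only inside
// the bucket that shares its key.
def dedup(items: Seq[Item]): Seq[Item] = {
  val buckets = mutable.HashMap[String, ArrayBuffer[Item]]()
  items.map { item =>
    val sameKey = buckets.getOrElseUpdate(item.schema, ArrayBuffer[Item]())
    sameKey.find(_.sameResult(item)) match {
      case Some(existing) => existing  // reuse the registered equivalent
      case None =>
        sameKey += item                // first occurrence: register and keep
        item
    }
  }
}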
