
Commit a23eb35

Author: Robert Kruszewski (committed)

Merge branch 'palantir-master' into rk/more-merge

2 parents 1f9772d + a51fa9c, commit a23eb35

File tree: 23 files changed, +843 -414 lines


FORK.md

Lines changed: 4 additions & 1 deletion
@@ -18,6 +18,9 @@
 * core: Broadcast, CoarseGrainedExecutorBackend, CoarseGrainedSchedulerBackend, Executor, MemoryStore, SparkContext, TorrentBroadcast
 * kubernetes: ExecutorPodsAllocator, ExecutorPodsLifecycleManager, ExecutorPodsPollingSnapshotSource, ExecutorPodsSnapshot, ExecutorPodsWatchSnapshotSource, KubernetesClusterSchedulerBackend
 * yarn: YarnClusterSchedulerBackend, YarnSchedulerBackend
+
+* [SPARK-26626](https://issues.apache.org/jira/browse/SPARK-26626) - Limited the maximum size of repeatedly substituted aliases
+
 # Added
 
 * Gradle plugin to easily create custom docker images for use with k8s
@@ -27,4 +30,4 @@
 * [SPARK-25908](https://issues.apache.org/jira/browse/SPARK-25908) - Removal of `monotonicall_increasing_id`, `toDegree`, `toRadians`, `approxCountDistinct`, `unionAll`
 * [SPARK-25862](https://issues.apache.org/jira/browse/SPARK-25862) - Removal of `unboundedPreceding`, `unboundedFollowing`, `currentRow`
 * [SPARK-26127](https://issues.apache.org/jira/browse/SPARK-26127) - Removal of deprecated setters from tree regression and classification models
-* [SPARK-25867](https://issues.apache.org/jira/browse/SPARK-25867) - Removal of KMeans computeCost
+* [SPARK-25867](https://issues.apache.org/jira/browse/SPARK-25867) - Removal of KMeans computeCost
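For context on the SPARK-26626 entry above: the problematic shape is a stack of projections that each reference an alias more than once, so unbounded alias substitution roughly doubles the optimized expression tree per layer. A minimal sketch of that shape, mirroring the new CollapseProjectSuite test later in this commit (the demo object, local master, and sample data are illustrative assumptions):

import org.apache.spark.sql.SparkSession

object RepeatedAliasBlowupDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("alias-blowup").getOrCreate()
    import spark.implicits._

    var df = Seq((1, 2)).toDF("a", "b")
    // 30 layers of (a + b, a - b): each layer re-references both aliases, so naive inlining
    // grows the expression exponentially; the new limit keeps oversized aliases un-collapsed.
    for (_ <- 1 to 30) {
      df = df.select(($"a" + $"b").as("a"), ($"a" - $"b").as("b"))
    }
    df.explain(true)
    spark.stop()
  }
}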

core/src/main/scala/org/apache/spark/MapOutputStatistics.scala

Lines changed: 1 addition & 0 deletions
@@ -25,3 +25,4 @@ package org.apache.spark
  * (may be inexact due to use of compressed map statuses)
  */
 private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long])
+  extends Serializable

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 24 additions & 1 deletion
@@ -649,7 +649,8 @@ object CollapseProject extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, p2: Project) =>
-      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
+      if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
+          hasOversizedRepeatedAliases(p1.projectList, p2.projectList)) {
        p1
       } else {
         p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
@@ -682,6 +683,28 @@ object CollapseProject extends Rule[LogicalPlan] {
     }.exists(!_.deterministic))
   }
 
+  private def hasOversizedRepeatedAliases(
+      upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
+    val aliases = collectAliases(lower)
+
+    // Count how many times each alias is used in the upper Project.
+    // If an alias is only used once, we can safely substitute it without increasing the overall
+    // tree size
+    val referenceCounts = AttributeMap(
+      upper
+        .flatMap(_.collect { case a: Attribute => a })
+        .groupBy(identity)
+        .mapValues(_.size).toSeq
+    )
+
+    // Check for any aliases that are used more than once, and are larger than the configured
+    // maximum size
+    aliases.exists({ case (attribute, expression) =>
+      referenceCounts.getOrElse(attribute, 0) > 1 &&
+        expression.treeSize > SQLConf.get.maxRepeatedAliasSize
+    })
+  }
+
   private def buildCleanedProjectList(
       upper: Seq[NamedExpression],
       lower: Seq[NamedExpression]): Seq[NamedExpression] = {
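A stand-alone sketch of the counting logic in hasOversizedRepeatedAliases, with plain strings standing in for Catalyst attributes and an integer standing in for Expression.treeSize (both stand-ins are assumptions for illustration, not the optimizer's types):

object OversizedAliasCheckSketch {
  def hasOversizedRepeatedAliases(
      upperRefs: Seq[String],          // attribute names referenced by the upper Project
      lowerAliases: Map[String, Int],  // alias name -> node count of the aliased expression
      maxRepeatedAliasSize: Int): Boolean = {
    // Count how many times each name is referenced above; a single use can never grow the tree.
    val referenceCounts = upperRefs.groupBy(identity).mapValues(_.size)
    // Only an alias that is both repeated and oversized blocks the collapse.
    lowerAliases.exists { case (name, size) =>
      referenceCounts.getOrElse(name, 0) > 1 && size > maxRepeatedAliasSize
    }
  }

  def main(args: Array[String]): Unit = {
    // "big" is referenced twice and its tree has 150 nodes, above the default limit of 100.
    println(hasOversizedRepeatedAliases(Seq("big", "big", "x"), Map("big" -> 150), 100)) // true
    // Referenced once: substitution cannot increase the total size, so collapsing proceeds.
    println(hasOversizedRepeatedAliases(Seq("big", "x"), Map("big" -> 150), 100))        // false
  }
}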

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 28 additions & 2 deletions
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -58,8 +59,13 @@ object PhysicalOperation extends PredicateHelper {
     plan match {
       case Project(fields, child) if fields.forall(_.deterministic) =>
         val (_, filters, other, aliases) = collectProjectsAndFilters(child)
-        val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
-        (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
+        if (hasOversizedRepeatedAliases(fields, aliases)) {
+          // Skip substitution if it could overly increase the overall tree size and risk OOMs
+          (None, Nil, plan, Map.empty)
+        } else {
+          val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
+          (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
+        }
 
       case Filter(condition, child) if condition.deterministic =>
         val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
@@ -77,6 +83,26 @@
       case a @ Alias(child, _) => a.toAttribute -> child
     }.toMap
 
+  private def hasOversizedRepeatedAliases(fields: Seq[Expression],
+      aliases: Map[Attribute, Expression]): Boolean = {
+    // Count how many times each alias is used in the fields.
+    // If an alias is only used once, we can safely substitute it without increasing the overall
+    // tree size
+    val referenceCounts = AttributeMap(
+      fields
+        .flatMap(_.collect { case a: Attribute => a })
+        .groupBy(identity)
+        .mapValues(_.size).toSeq
+    )
+
+    // Check for any aliases that are used more than once, and are larger than the configured
+    // maximum size
+    aliases.exists({ case (attribute, expression) =>
+      referenceCounts.getOrElse(attribute, 0) > 1 &&
+        expression.treeSize > SQLConf.get.maxRepeatedAliasSize
+    })
+  }
+
   private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
     expr.transform {
       case a @ Alias(ref: AttributeReference, name) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 2 additions & 0 deletions
@@ -89,6 +89,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   lazy val containsChild: Set[TreeNode[_]] = children.toSet
 
+  lazy val treeSize: Long = children.map(_.treeSize).sum + 1
+
   private lazy val _hashCode: Int = scala.util.hashing.MurmurHash3.productHash(this)
   override def hashCode(): Int = _hashCode
 
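The new treeSize is a plain recurrence: every node counts itself plus the sizes of its children, so a leaf has size 1. A toy model of the same arithmetic (not the Catalyst TreeNode itself; the Node/Leaf/Branch types are assumptions for illustration):

sealed trait Node {
  def children: Seq[Node]
  // Same recurrence as TreeNode.treeSize: one for this node plus all descendant counts.
  lazy val treeSize: Long = children.map(_.treeSize).sum + 1
}
case class Leaf(name: String) extends Node { val children: Seq[Node] = Nil }
case class Branch(children: Node*) extends Node

object TreeSizeSketch {
  def main(args: Array[String]): Unit = {
    // Branch(Leaf, Branch(Leaf, Leaf)) -> 1 + 1 + (1 + 1 + 1) = 5 nodes.
    println(Branch(Leaf("a"), Branch(Leaf("b"), Leaf("c"))).treeSize)
  }
}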

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 26 additions & 9 deletions
@@ -280,14 +280,19 @@
 
   val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS =
     buildConf("spark.sql.adaptive.minNumPostShufflePartitions")
-      .internal()
-      .doc("The advisory minimal number of post-shuffle partitions provided to " +
-        "ExchangeCoordinator. This setting is used in our test to make sure we " +
-        "have enough parallelism to expose issues that will not be exposed with a " +
-        "single partition. When the value is a non-positive value, this setting will " +
-        "not be provided to ExchangeCoordinator.")
+      .doc("The advisory minimum number of post-shuffle partitions used in adaptive execution.")
+      .intConf
+      .checkValue(numPartitions => numPartitions > 0, "The minimum shuffle partition number " +
+        "must be a positive integer.")
+      .createWithDefault(1)
+
+  val SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS =
+    buildConf("spark.sql.adaptive.maxNumPostShufflePartitions")
+      .doc("The advisory maximum number of post-shuffle partitions used in adaptive execution.")
       .intConf
-      .createWithDefault(-1)
+      .checkValue(numPartitions => numPartitions > 0, "The maximum shuffle partition number " +
+        "must be a positive integer.")
+      .createWithDefault(500)
 
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
@@ -1621,6 +1626,15 @@
         """ "... N more fields" placeholder.""")
       .intConf
       .createWithDefault(25)
+
+  val MAX_REPEATED_ALIAS_SIZE =
+    buildConf("spark.sql.maxRepeatedAliasSize")
+      .internal()
+      .doc("The maximum size of alias expression that will be substituted multiple times " +
+        "(size defined by the number of nodes in the expression tree). " +
+        "Used by the CollapseProject optimizer, and PhysicalOperation.")
+      .intConf
+      .createWithDefault(100)
 }
 
 /**
@@ -1726,8 +1740,9 @@
 
   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
 
-  def minNumPostShufflePartitions: Int =
-    getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+  def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS)
+
+  def maxNumPostShufflePartitions: Int = getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS)
 
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)
 
@@ -2053,6 +2068,8 @@
 
   def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS)
 
+  def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
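A usage sketch for the keys touched above, on a session built against this fork (the local master, app name, and the specific values are illustrative assumptions, not recommendations from this commit):

import org.apache.spark.sql.SparkSession

object AdaptiveConfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("adaptive-conf-sketch")
      .config("spark.sql.adaptive.enabled", "true")
      // Advisory bounds on the number of post-shuffle partitions (defaults: 1 and 500).
      .config("spark.sql.adaptive.minNumPostShufflePartitions", "1")
      .config("spark.sql.adaptive.maxNumPostShufflePartitions", "200")
      // Internal knob: largest alias expression tree that may be substituted more than once.
      .config("spark.sql.maxRepeatedAliasSize", "100")
      .getOrCreate()

    println(spark.conf.get("spark.sql.adaptive.maxNumPostShufflePartitions")) // 200
    spark.stop()
  }
}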

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -138,4 +138,22 @@ class CollapseProjectSuite extends PlanTest {
     assert(projects.size === 1)
     assert(hasMetadata(optimized))
   }
+
+  test("ensure oversize aliases are not repeatedly substituted") {
+    var query: LogicalPlan = testRelation
+    for( a <- 1 to 100) {
+      query = query.select(('a + 'b).as('a), ('a - 'b).as('b))
+    }
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size >= 12)
+  }
+
+  test("ensure oversize aliases are still substituted once") {
+    var query: LogicalPlan = testRelation
+    for( a <- 1 to 20) {
+      query = query.select(('a + 'b).as('a), 'b)
+    }
+    val projects = Optimize.execute(query.analyze).collect { case p: Project => p }
+    assert(projects.size === 1)
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 15 additions & 1 deletion
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.execution.adaptive.PlanQueryStage
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
@@ -96,7 +97,11 @@
    * row format conversions as needed.
    */
   protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
-    preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+    if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
+      adaptivePreparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp)}
+    } else {
+      preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp)}
+    }
   }
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
@@ -107,6 +112,15 @@
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
 
+  protected def adaptivePreparations: Seq[Rule[SparkPlan]] = Seq(
+    PlanSubqueries(sparkSession),
+    EnsureRequirements(sparkSession.sessionState.conf),
+    ReuseSubquery(sparkSession.sessionState.conf),
+    // PlanQueryStage needs to be the last rule because it divides the plan into multiple sub-trees
+    // by inserting leaf node QueryStageInput. Transforming the plan after applying this rule will
+    // only transform node in a sub-tree.
+    PlanQueryStage(sparkSession.sessionState.conf))
+
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: AnalysisException => e.toString }
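The preparation rules above are ordinary plan-to-plan transformations folded over the physical plan in order, which is why PlanQueryStage must come last. A stand-alone sketch of that foldLeft pattern with string stand-ins for plans and rules (the rule names are only labels here, not the real rule objects):

object RuleFoldSketch {
  type Plan = String
  type PlanRule = Plan => Plan

  def prepareForExecution(plan: Plan, adaptiveEnabled: Boolean): Plan = {
    val preparations: Seq[PlanRule] =
      Seq(p => s"EnsureRequirements($p)", p => s"ReuseExchange($p)")
    val adaptivePreparations: Seq[PlanRule] =
      Seq(p => s"EnsureRequirements($p)", p => s"PlanQueryStage($p)") // PlanQueryStage stays last
    val rules = if (adaptiveEnabled) adaptivePreparations else preparations
    // Same shape as prepareForExecution: thread the plan through each rule in sequence.
    rules.foldLeft(plan) { case (sp, rule) => rule(sp) }
  }

  def main(args: Array[String]): Unit = {
    println(prepareForExecution("Scan", adaptiveEnabled = true))  // PlanQueryStage(EnsureRequirements(Scan))
    println(prepareForExecution("Scan", adaptiveEnabled = false)) // ReuseExchange(EnsureRequirements(Scan))
  }
}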

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.adaptive.QueryStageInput
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
 
@@ -51,6 +52,7 @@ private[execution] object SparkPlanInfo {
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
     val children = plan match {
       case ReusedExchangeExec(_, child) => child :: Nil
+      case i: QueryStageInput => i.childStage :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Divide the spark plan into multiple QueryStages. For each Exchange in the plan, it adds a
+ * QueryStage and a QueryStageInput. If reusing Exchange is enabled, it finds duplicated exchanges
+ * and uses the same QueryStage for all the references.
+ */
+case class PlanQueryStage(conf: SQLConf) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+
+    val newPlan = if (!conf.exchangeReuseEnabled) {
+      plan.transformUp {
+        case e: ShuffleExchangeExec =>
+          ShuffleQueryStageInput(ShuffleQueryStage(e), e.output)
+        case e: BroadcastExchangeExec =>
+          BroadcastQueryStageInput(BroadcastQueryStage(e), e.output)
+      }
+    } else {
+      // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
+      val stages = mutable.HashMap[StructType, ArrayBuffer[QueryStage]]()
+
+      plan.transformUp {
+        case exchange: Exchange =>
+          val sameSchema = stages.getOrElseUpdate(exchange.schema, ArrayBuffer[QueryStage]())
+          val samePlan = sameSchema.find { s =>
+            exchange.sameResult(s.child)
+          }
+          if (samePlan.isDefined) {
+            // Keep the output of this exchange, the following plans require that to resolve
+            // attributes.
+            exchange match {
+              case e: ShuffleExchangeExec => ShuffleQueryStageInput(
+                samePlan.get.asInstanceOf[ShuffleQueryStage], exchange.output)
+              case e: BroadcastExchangeExec => BroadcastQueryStageInput(
+                samePlan.get.asInstanceOf[BroadcastQueryStage], exchange.output)
+            }
+          } else {
+            val queryStageInput = exchange match {
+              case e: ShuffleExchangeExec =>
+                ShuffleQueryStageInput(ShuffleQueryStage(e), e.output)
+              case e: BroadcastExchangeExec =>
+                BroadcastQueryStageInput(BroadcastQueryStage(e), e.output)
+            }
+            sameSchema += queryStageInput.childStage
+            queryStageInput
+          }
+      }
+    }
+    ResultQueryStage(newPlan)
+  }
+}
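The reuse branch above avoids O(N*N) sameResult comparisons by bucketing exchanges under their schema and only comparing plans within a bucket. A stand-alone sketch of that lookup, with toy types where a string schema and plan-string equality stand in for StructType and sameResult (all of these stand-ins are assumptions for illustration):

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object ReuseBucketSketch {
  final case class Stage(schema: String, plan: String)

  def main(args: Array[String]): Unit = {
    val stages = mutable.HashMap[String, ArrayBuffer[Stage]]()
    val exchanges = Seq(
      Stage("int,string", "scanA"), Stage("int,string", "scanA"), Stage("int", "scanB"))

    val resolved = exchanges.map { e =>
      // Cheap key first: only candidates with the same schema are compared at all.
      val bucket = stages.getOrElseUpdate(e.schema, ArrayBuffer[Stage]())
      bucket.find(_.plan == e.plan) match {
        case Some(existing) => s"reuse(${existing.plan})"
        case None =>
          bucket += e
          s"new(${e.plan})"
      }
    }
    println(resolved) // List(new(scanA), reuse(scanA), new(scanB))
  }
}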
