Skip to content

Commit 5ed7660

Browse files
maryannxue authored and gatorsmile committed
[SPARK-24802][SQL][FOLLOW-UP] Add a new config for Optimization Rule Exclusion
## What changes were proposed in this pull request?

This is an extension to the original PR, in which rule exclusion did not work for classes derived from Optimizer, e.g., SparkOptimizer. To solve this issue, Optimizer and its derived classes will define/override `defaultBatches` and `nonExcludableRules` in order to define its default rule set as well as rules that cannot be excluded by the SQL config. In the meantime, Optimizer's `batches` method is dedicated to the rule exclusion logic and is defined "final".

## How was this patch tested?

Added UT.

Author: maryannxue <[email protected]>

Closes apache#21876 from maryannxue/rule-exclusion.
1 parent 58353d7 commit 5ed7660

File tree

5 files changed

+69
-17
lines changed

5 files changed

+69
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
4646

4747
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
4848

49+
/**
50+
* Defines the default rule batches in the Optimizer.
51+
*
52+
* Implementations of this class should override this method, and [[nonExcludableRules]] if
53+
* necessary, instead of [[batches]]. The rule batches that eventually run in the Optimizer,
54+
* i.e., returned by [[batches]], will be (defaultBatches - (excludedRules - nonExcludableRules)).
55+
*/
4956
def defaultBatches: Seq[Batch] = {
5057
val operatorOptimizationRuleSet =
5158
Seq(
@@ -160,6 +167,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
160167
UpdateNullabilityInAttributeReferences)
161168
}
162169

170+
/**
171+
* Defines rules that cannot be excluded from the Optimizer even if they are specified in
172+
* SQL config "excludedRules".
173+
*
174+
* Implementations of this class can override this method if necessary. The rule batches
175+
* that eventually run in the Optimizer, i.e., returned by [[batches]], will be
176+
* (defaultBatches - (excludedRules - nonExcludableRules)).
177+
*/
163178
def nonExcludableRules: Seq[String] =
164179
EliminateDistinct.ruleName ::
165180
EliminateSubqueryAliases.ruleName ::
@@ -202,7 +217,14 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
202217
*/
203218
def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil
204219

205-
override def batches: Seq[Batch] = {
220+
/**
221+
* Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that
222+
* eventually run in the Optimizer.
223+
*
224+
* Implementations of this class should override [[defaultBatches]], and [[nonExcludableRules]]
225+
* if necessary, instead of this method.
226+
*/
227+
final override def batches: Seq[Batch] = {
206228
val excludedRulesConf =
207229
SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq)
208230
val excludedRules = excludedRulesConf.filter { ruleName =>

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class OptimizerExtendableSuite extends SparkFunSuite {
4747
DummyRule) :: Nil
4848
}
4949

50-
override def batches: Seq[Batch] = super.batches ++ myBatches
50+
override def defaultBatches: Seq[Batch] = super.defaultBatches ++ myBatches
5151
}
5252

5353
test("Extending batches possible") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ class OptimizerRuleExclusionSuite extends PlanTest {
2828

2929
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
3030

31-
private def verifyExcludedRules(excludedRuleNames: Seq[String]) {
32-
val optimizer = new SimpleTestOptimizer()
31+
private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]) {
32+
val nonExcludableRules = optimizer.nonExcludableRules
33+
34+
val excludedRuleNames = rulesToExclude.filter(!nonExcludableRules.contains(_))
3335
// Batches whose rules are all to be excluded should be removed as a whole.
3436
val excludedBatchNames = optimizer.batches
3537
.filter(batch => batch.rules.forall(rule => excludedRuleNames.contains(rule.ruleName)))
@@ -38,21 +40,31 @@ class OptimizerRuleExclusionSuite extends PlanTest {
3840
withSQLConf(
3941
OPTIMIZER_EXCLUDED_RULES.key -> excludedRuleNames.foldLeft("")((l, r) => l + "," + r)) {
4042
val batches = optimizer.batches
43+
// Verify removed batches.
4144
assert(batches.forall(batch => !excludedBatchNames.contains(batch.name)))
45+
// Verify removed rules.
4246
assert(
4347
batches
4448
.forall(batch => batch.rules.forall(rule => !excludedRuleNames.contains(rule.ruleName))))
49+
// Verify non-excludable rules retained.
50+
nonExcludableRules.foreach { nonExcludableRule =>
51+
assert(
52+
optimizer.batches
53+
.exists(batch => batch.rules.exists(rule => rule.ruleName == nonExcludableRule)))
54+
}
4555
}
4656
}
4757

4858
test("Exclude a single rule from multiple batches") {
4959
verifyExcludedRules(
60+
new SimpleTestOptimizer(),
5061
Seq(
5162
PushPredicateThroughJoin.ruleName))
5263
}
5364

5465
test("Exclude multiple rules from single or multiple batches") {
5566
verifyExcludedRules(
67+
new SimpleTestOptimizer(),
5668
Seq(
5769
CombineUnions.ruleName,
5870
RemoveLiteralFromGroupExpressions.ruleName,
@@ -61,27 +73,42 @@ class OptimizerRuleExclusionSuite extends PlanTest {
6173

6274
test("Exclude non-existent rule with other valid rules") {
6375
verifyExcludedRules(
76+
new SimpleTestOptimizer(),
6477
Seq(
6578
LimitPushDown.ruleName,
6679
InferFiltersFromConstraints.ruleName,
6780
"DummyRuleName"))
6881
}
6982

7083
test("Try to exclude a non-excludable rule") {
71-
val excludedRules = Seq(
72-
ReplaceIntersectWithSemiJoin.ruleName,
73-
PullupCorrelatedPredicates.ruleName)
84+
verifyExcludedRules(
85+
new SimpleTestOptimizer(),
86+
Seq(
87+
ReplaceIntersectWithSemiJoin.ruleName,
88+
PullupCorrelatedPredicates.ruleName))
89+
}
7490

75-
val optimizer = new SimpleTestOptimizer()
91+
test("Custom optimizer") {
92+
val optimizer = new SimpleTestOptimizer() {
93+
override def defaultBatches: Seq[Batch] =
94+
Batch("push", Once,
95+
PushDownPredicate,
96+
PushPredicateThroughJoin,
97+
PushProjectionThroughUnion) ::
98+
Batch("pull", Once,
99+
PullupCorrelatedPredicates) :: Nil
76100

77-
withSQLConf(
78-
OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) {
79-
excludedRules.foreach { excludedRule =>
80-
assert(
81-
optimizer.batches
82-
.exists(batch => batch.rules.exists(rule => rule.ruleName == excludedRule)))
83-
}
101+
override def nonExcludableRules: Seq[String] =
102+
PushDownPredicate.ruleName ::
103+
PullupCorrelatedPredicates.ruleName :: Nil
84104
}
105+
106+
verifyExcludedRules(
107+
optimizer,
108+
Seq(
109+
PushDownPredicate.ruleName,
110+
PushProjectionThroughUnion.ruleName,
111+
PullupCorrelatedPredicates.ruleName))
85112
}
86113

87114
test("Verify optimized plan after excluding CombineUnions rule") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
4444
EmptyFunctionRegistry,
4545
new SQLConf())) {
4646
val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
47-
override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches
47+
override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches
4848
}
4949

5050
test("check for invalid plan after execution of rule") {

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@ class SparkOptimizer(
2828
experimentalMethods: ExperimentalMethods)
2929
extends Optimizer(catalog) {
3030

31-
override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+
31+
override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
3232
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
3333
Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
3434
Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
3535
postHocOptimizationBatches :+
3636
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
3737

38+
override def nonExcludableRules: Seq[String] =
39+
super.nonExcludableRules :+ ExtractPythonUDFFromAggregate.ruleName
40+
3841
/**
3942
* Optimization batches that are executed before the regular optimization batches (also before
4043
* the finish analysis batch).

0 commit comments

Comments (0)