Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 1a98574

Browse files
gengliangwanggatorsmile
authored andcommitted
[SPARK-21979][SQL] Improve QueryPlanConstraints framework
## What changes were proposed in this pull request? Improve QueryPlanConstraints framework, make it robust and simple. In apache#15319, constraints for expressions like `a = f(b, c)` is resolved. However, for expressions like ```scala a = f(b, c) && c = g(a, b) ``` The current QueryPlanConstraints framework will produce non-converging constraints. Essentially, the problem is caused by having both the name and child of aliases in the same constraint set. We infer constraints, and push down constraints as predicates in filters, later on these predicates are propagated as constraints, etc.. Simply using the alias names only can resolve these problems. The size of constraints is reduced without losing any information. We can always get these inferred constraints on child of aliases when pushing down filters. Also, the EqualNullSafe between name and child in propagating alias is meaningless ```scala allConstraints += EqualNullSafe(e, a.toAttribute) ``` It just produces redundant constraints. ## How was this patch tested? Unit test Author: Wang Gengliang <[email protected]> Closes apache#19201 from gengliangwang/QueryPlanConstraints.
1 parent c5f9b89 commit 1a98574

File tree

4 files changed

+65
-87
lines changed

4 files changed

+65
-87
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ abstract class UnaryNode extends LogicalPlan {
297297
case expr: Expression if expr.semanticEquals(e) =>
298298
a.toAttribute
299299
})
300-
allConstraints += EqualNullSafe(e, a.toAttribute)
301300
case _ => // Don't change.
302301
}
303302

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 30 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -106,91 +106,48 @@ trait QueryPlanConstraints { self: LogicalPlan =>
106106
* Infers an additional set of constraints from a given set of equality constraints.
107107
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
108108
* additional constraint of the form `b = 5`.
109-
*
110-
* [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
111-
* as they are often useless and can lead to a non-converging set of constraints.
112109
*/
113110
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
114-
val constraintClasses = generateEquivalentConstraintClasses(constraints)
115-
111+
val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints)
116112
var inferredConstraints = Set.empty[Expression]
117-
constraints.foreach {
113+
aliasedConstraints.foreach {
118114
case eq @ EqualTo(l: Attribute, r: Attribute) =>
119-
val candidateConstraints = constraints - eq
120-
inferredConstraints ++= candidateConstraints.map(_ transform {
121-
case a: Attribute if a.semanticEquals(l) &&
122-
!isRecursiveDeduction(r, constraintClasses) => r
123-
})
124-
inferredConstraints ++= candidateConstraints.map(_ transform {
125-
case a: Attribute if a.semanticEquals(r) &&
126-
!isRecursiveDeduction(l, constraintClasses) => l
127-
})
115+
val candidateConstraints = aliasedConstraints - eq
116+
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
117+
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
128118
case _ => // No inference
129119
}
130120
inferredConstraints -- constraints
131121
}
132122

133123
/**
134-
* Generate a sequence of expression sets from constraints, where each set stores an equivalence
135-
* class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
136-
* expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
137-
* to an selected attribute.
124+
* Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints.
125+
* Thus non-converging inference can be prevented.
126+
* E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions.
127+
* Also, the size of constraints is reduced without losing any information.
128+
* When the inferred filters are pushed down the operators that generate the alias,
129+
* the alias names used in filters are replaced by the aliased expressions.
138130
*/
139-
private def generateEquivalentConstraintClasses(
140-
constraints: Set[Expression]): Seq[Set[Expression]] = {
141-
var constraintClasses = Seq.empty[Set[Expression]]
142-
constraints.foreach {
143-
case eq @ EqualTo(l: Attribute, r: Attribute) =>
144-
// Transform [[Alias]] to its child.
145-
val left = aliasMap.getOrElse(l, l)
146-
val right = aliasMap.getOrElse(r, r)
147-
// Get the expression set for an equivalence constraint class.
148-
val leftConstraintClass = getConstraintClass(left, constraintClasses)
149-
val rightConstraintClass = getConstraintClass(right, constraintClasses)
150-
if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
151-
// Combine the two sets.
152-
constraintClasses = constraintClasses
153-
.diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
154-
(leftConstraintClass ++ rightConstraintClass)
155-
} else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
156-
// Update equivalence class of `left` expression.
157-
constraintClasses = constraintClasses
158-
.diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
159-
} else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
160-
// Update equivalence class of `right` expression.
161-
constraintClasses = constraintClasses
162-
.diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
163-
} else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
164-
// Create new equivalence constraint class since neither expression presents
165-
// in any classes.
166-
constraintClasses = constraintClasses :+ Set(left, right)
167-
}
168-
case _ => // Skip
131+
private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression])
132+
: Set[Expression] = {
133+
val attributesInEqualTo = constraints.flatMap {
134+
case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil
135+
case _ => Nil
169136
}
170-
171-
constraintClasses
172-
}
173-
174-
/**
175-
* Get all expressions equivalent to the selected expression.
176-
*/
177-
private def getConstraintClass(
178-
expr: Expression,
179-
constraintClasses: Seq[Set[Expression]]): Set[Expression] =
180-
constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
181-
182-
/**
183-
* Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
184-
* has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
185-
* Here we first get all expressions equal to `attr` and then check whether at least one of them
186-
* is a child of the referenced expression.
187-
*/
188-
private def isRecursiveDeduction(
189-
attr: Attribute,
190-
constraintClasses: Seq[Set[Expression]]): Boolean = {
191-
val expr = aliasMap.getOrElse(attr, attr)
192-
getConstraintClass(expr, constraintClasses).exists { e =>
193-
expr.children.exists(_.semanticEquals(e))
137+
var aliasedConstraints = constraints
138+
attributesInEqualTo.foreach { a =>
139+
if (aliasMap.contains(a)) {
140+
val child = aliasMap.get(a).get
141+
aliasedConstraints = replaceConstraints(aliasedConstraints, child, a)
142+
}
194143
}
144+
aliasedConstraints
195145
}
146+
147+
private def replaceConstraints(
148+
constraints: Set[Expression],
149+
source: Expression,
150+
destination: Attribute): Set[Expression] = constraints.map(_ transform {
151+
case e: Expression if e.semanticEquals(source) => destination
152+
})
196153
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
151151
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
152152
.analyze
153153
val correctAnswer = t1
154-
.where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
154+
.where(IsNotNull('a) && IsNotNull('b) &&'a === 'b)
155155
.select('a, 'b.as('d)).as("t")
156-
.join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
156+
.join(t2.where(IsNotNull('a)), Inner,
157157
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
158158
.analyze
159159
val optimized = Optimize.execute(originalQuery)
@@ -176,24 +176,48 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
176176
&& "t.int_col".attr === "t2.a".attr))
177177
.analyze
178178
val correctAnswer = t1
179-
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
180-
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
181-
&& Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b)))
182-
&& 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b
183-
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
184-
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
179+
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a)))
180+
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b)))
181+
&& 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b))
182+
&& 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b))
183+
&& 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b)))
185184
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
186185
.select('int_col, 'd, 'a).as("t")
187-
.join(t2
188-
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
189-
&& 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
186+
.join(
187+
t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) &&
188+
'a === Coalesce(Seq('a, 'a))),
189+
Inner,
190190
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
191191
&& "t.int_col".attr === "t2.a".attr))
192192
.analyze
193193
val optimized = Optimize.execute(originalQuery)
194194
comparePlans(optimized, correctAnswer)
195195
}
196196

197+
test("inner join with EqualTo expressions containing part of each other: don't generate " +
198+
"constraints for recursive functions") {
199+
val t1 = testRelation.subquery('t1)
200+
val t2 = testRelation.subquery('t2)
201+
202+
// We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating
203+
// complicated constraints through the constraint inference procedure.
204+
val originalQuery = t1
205+
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
206+
.where('a === 'd && 'c === 'e)
207+
.join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
208+
.analyze
209+
val correctAnswer = t1
210+
.where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) &&
211+
'c === Coalesce(Seq('a, 'b)))
212+
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
213+
.join(t2.where(IsNotNull('a) && IsNotNull('c)),
214+
Inner,
215+
Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
216+
.analyze
217+
val optimized = Optimize.execute(originalQuery)
218+
comparePlans(optimized, correctAnswer)
219+
}
220+
197221
test("generate correct filters for alias that don't produce recursive constraints") {
198222
val t1 = testRelation.subquery('t1)
199223

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,6 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
134134
verifyConstraints(aliasedRelation.analyze.constraints,
135135
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
136136
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
137-
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
138-
resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
139137
resolveColumn(aliasedRelation.analyze, "z") > 10,
140138
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
141139

0 commit comments

Comments
 (0)