Skip to content

Commit d87d30e

Browse files
committed
[SPARK-23564][SQL] infer additional filters from constraints for join's children
## What changes were proposed in this pull request? The existing query constraints framework has 2 steps: 1. propagate constraints bottom up. 2. use constraints to infer additional filters for better data pruning. For step 2, it mostly helps with Join, because we can connect the constraints from children to the join condition and infer powerful filters to prune the data of the join sides. e.g., the left side has constraints `a = 1`, the join condition is `left.a = right.a`, then we can infer `right.a = 1` to the right side and prune the right side a lot. However, the current logic of inferring filters from constraints for Join is pretty weak. It infers the filters from Join's constraints. Some joins like left semi/anti exclude output from right side and the right side constraints will be lost here. This PR propose to check the left and right constraints individually, expand the constraints with join condition and add filters to children of join directly, instead of adding to the join condition. This reverts apache#20670 , covers apache#20717 and apache#20816 This is inspired by the original PRs and the tests are all from these PRs. Thanks to the authors mgaido91 maryannxue KaiXinXiaoLei ! ## How was this patch tested? new tests Author: Wenchen Fan <[email protected]> Closes apache#21083 from cloud-fan/join.
1 parent f70f46d commit d87d30e

File tree

3 files changed

+124
-121
lines changed

3 files changed

+124
-121
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -637,13 +637,11 @@ object CollapseWindow extends Rule[LogicalPlan] {
637637
* constraints. These filters are currently inserted to the existing conditions in the Filter
638638
* operators and on either side of Join operators.
639639
*
640-
* In addition, for left/right outer joins, infer predicate from the preserved side of the Join
641-
* operator and push the inferred filter over to the null-supplying side. For example, if the
642-
* preserved side has constraints of the form 'a > 5' and the join condition is 'a = b', in
643-
* which 'b' is an attribute from the null-supplying side, a [[Filter]] operator of 'b > 5' will
644-
* be applied to the null-supplying side.
640+
* Note: While this optimization is applicable to a lot of types of join, it primarily benefits
641+
* Inner and LeftSemi joins.
645642
*/
646-
object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
643+
object InferFiltersFromConstraints extends Rule[LogicalPlan]
644+
with PredicateHelper with ConstraintHelper {
647645

648646
def apply(plan: LogicalPlan): LogicalPlan = {
649647
if (SQLConf.get.constraintPropagationEnabled) {
@@ -664,53 +662,52 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
664662
}
665663

666664
case join @ Join(left, right, joinType, conditionOpt) =>
667-
// Only consider constraints that can be pushed down completely to either the left or the
668-
// right child
669-
val constraints = join.allConstraints.filter { c =>
670-
c.references.subsetOf(left.outputSet) || c.references.subsetOf(right.outputSet)
671-
}
672-
// Remove those constraints that are already enforced by either the left or the right child
673-
val additionalConstraints = constraints -- (left.constraints ++ right.constraints)
674-
val newConditionOpt = conditionOpt match {
675-
case Some(condition) =>
676-
val newFilters = additionalConstraints -- splitConjunctivePredicates(condition)
677-
if (newFilters.nonEmpty) Option(And(newFilters.reduce(And), condition)) else conditionOpt
678-
case None =>
679-
additionalConstraints.reduceOption(And)
680-
}
681-
// Infer filter for left/right outer joins
682-
val newLeftOpt = joinType match {
683-
case RightOuter if newConditionOpt.isDefined =>
684-
val inferredConstraints = left.getRelevantConstraints(
685-
left.constraints
686-
.union(right.constraints)
687-
.union(splitConjunctivePredicates(newConditionOpt.get).toSet))
688-
val newFilters = inferredConstraints
689-
.filterNot(left.constraints.contains)
690-
.reduceLeftOption(And)
691-
newFilters.map(Filter(_, left))
692-
case _ => None
693-
}
694-
val newRightOpt = joinType match {
695-
case LeftOuter if newConditionOpt.isDefined =>
696-
val inferredConstraints = right.getRelevantConstraints(
697-
right.constraints
698-
.union(left.constraints)
699-
.union(splitConjunctivePredicates(newConditionOpt.get).toSet))
700-
val newFilters = inferredConstraints
701-
.filterNot(right.constraints.contains)
702-
.reduceLeftOption(And)
703-
newFilters.map(Filter(_, right))
704-
case _ => None
705-
}
665+
joinType match {
666+
// For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
667+
// inner join, it just drops the right side in the final output.
668+
case _: InnerLike | LeftSemi =>
669+
val allConstraints = getAllConstraints(left, right, conditionOpt)
670+
val newLeft = inferNewFilter(left, allConstraints)
671+
val newRight = inferNewFilter(right, allConstraints)
672+
join.copy(left = newLeft, right = newRight)
706673

707-
if ((newConditionOpt.isDefined && (newConditionOpt ne conditionOpt))
708-
|| newLeftOpt.isDefined || newRightOpt.isDefined) {
709-
Join(newLeftOpt.getOrElse(left), newRightOpt.getOrElse(right), joinType, newConditionOpt)
710-
} else {
711-
join
674+
// For right outer join, we can only infer additional filters for left side.
675+
case RightOuter =>
676+
val allConstraints = getAllConstraints(left, right, conditionOpt)
677+
val newLeft = inferNewFilter(left, allConstraints)
678+
join.copy(left = newLeft)
679+
680+
// For left join, we can only infer additional filters for right side.
681+
case LeftOuter | LeftAnti =>
682+
val allConstraints = getAllConstraints(left, right, conditionOpt)
683+
val newRight = inferNewFilter(right, allConstraints)
684+
join.copy(right = newRight)
685+
686+
case _ => join
712687
}
713688
}
689+
690+
private def getAllConstraints(
691+
left: LogicalPlan,
692+
right: LogicalPlan,
693+
conditionOpt: Option[Expression]): Set[Expression] = {
694+
val baseConstraints = left.constraints.union(right.constraints)
695+
.union(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil).toSet)
696+
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
697+
}
698+
699+
private def inferNewFilter(plan: LogicalPlan, constraints: Set[Expression]): LogicalPlan = {
700+
val newPredicates = constraints
701+
.union(constructIsNotNullConstraints(constraints, plan.output))
702+
.filter { c =>
703+
c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
704+
} -- plan.constraints
705+
if (newPredicates.isEmpty) {
706+
plan
707+
} else {
708+
Filter(newPredicates.reduce(And), plan)
709+
}
710+
}
714711
}
715712

716713
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 39 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,28 @@ package org.apache.spark.sql.catalyst.plans.logical
2020
import org.apache.spark.sql.catalyst.expressions._
2121

2222

23-
trait QueryPlanConstraints { self: LogicalPlan =>
23+
trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
2424

2525
/**
26-
* An [[ExpressionSet]] that contains an additional set of constraints, such as equality
27-
* constraints and `isNotNull` constraints, etc.
26+
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
27+
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
28+
* evaluate to `true` for all rows produced.
2829
*/
29-
lazy val allConstraints: ExpressionSet = {
30+
lazy val constraints: ExpressionSet = {
3031
if (conf.constraintPropagationEnabled) {
31-
ExpressionSet(validConstraints
32-
.union(inferAdditionalConstraints(validConstraints))
33-
.union(constructIsNotNullConstraints(validConstraints)))
32+
ExpressionSet(
33+
validConstraints
34+
.union(inferAdditionalConstraints(validConstraints))
35+
.union(constructIsNotNullConstraints(validConstraints, output))
36+
.filter { c =>
37+
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
38+
}
39+
)
3440
} else {
3541
ExpressionSet(Set.empty)
3642
}
3743
}
3844

39-
/**
40-
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
41-
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
42-
* evaluate to `true` for all rows produced.
43-
*/
44-
lazy val constraints: ExpressionSet = ExpressionSet(allConstraints.filter(selfReferenceOnly))
45-
4645
/**
4746
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
4847
* based on the given operator's constraint propagation logic. These constraints are then
@@ -52,30 +51,42 @@ trait QueryPlanConstraints { self: LogicalPlan =>
5251
* See [[Canonicalize]] for more details.
5352
*/
5453
protected def validConstraints: Set[Expression] = Set.empty
54+
}
55+
56+
trait ConstraintHelper {
5557

5658
/**
57-
* Returns an [[ExpressionSet]] that contains an additional set of constraints, such as
58-
* equality constraints and `isNotNull` constraints, etc., and that only contains references
59-
* to this [[LogicalPlan]] node.
59+
* Infers an additional set of constraints from a given set of equality constraints.
60+
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
61+
* additional constraint of the form `b = 5`.
6062
*/
61-
def getRelevantConstraints(constraints: Set[Expression]): ExpressionSet = {
62-
val allRelevantConstraints =
63-
if (conf.constraintPropagationEnabled) {
64-
constraints
65-
.union(inferAdditionalConstraints(constraints))
66-
.union(constructIsNotNullConstraints(constraints))
67-
} else {
68-
constraints
69-
}
70-
ExpressionSet(allRelevantConstraints.filter(selfReferenceOnly))
63+
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
64+
var inferredConstraints = Set.empty[Expression]
65+
constraints.foreach {
66+
case eq @ EqualTo(l: Attribute, r: Attribute) =>
67+
val candidateConstraints = constraints - eq
68+
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
69+
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
70+
case _ => // No inference
71+
}
72+
inferredConstraints -- constraints
7173
}
7274

75+
private def replaceConstraints(
76+
constraints: Set[Expression],
77+
source: Expression,
78+
destination: Attribute): Set[Expression] = constraints.map(_ transform {
79+
case e: Expression if e.semanticEquals(source) => destination
80+
})
81+
7382
/**
7483
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
7584
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
7685
* returns a constraint of the form `isNotNull(a)`
7786
*/
78-
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
87+
def constructIsNotNullConstraints(
88+
constraints: Set[Expression],
89+
output: Seq[Attribute]): Set[Expression] = {
7990
// First, we propagate constraints from the null intolerant expressions.
8091
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)
8192

@@ -111,32 +122,4 @@ trait QueryPlanConstraints { self: LogicalPlan =>
111122
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
112123
case _ => Seq.empty[Attribute]
113124
}
114-
115-
/**
116-
* Infers an additional set of constraints from a given set of equality constraints.
117-
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
118-
* additional constraint of the form `b = 5`.
119-
*/
120-
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
121-
var inferredConstraints = Set.empty[Expression]
122-
constraints.foreach {
123-
case eq @ EqualTo(l: Attribute, r: Attribute) =>
124-
val candidateConstraints = constraints - eq
125-
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
126-
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
127-
case _ => // No inference
128-
}
129-
inferredConstraints -- constraints
130-
}
131-
132-
private def replaceConstraints(
133-
constraints: Set[Expression],
134-
source: Expression,
135-
destination: Attribute): Set[Expression] = constraints.map(_ transform {
136-
case e: Expression if e.semanticEquals(source) => destination
137-
})
138-
139-
private def selfReferenceOnly(e: Expression): Boolean = {
140-
e.references.nonEmpty && e.references.subsetOf(outputSet) && e.deterministic
141-
}
142125
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,25 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
3535
InferFiltersFromConstraints,
3636
CombineFilters,
3737
SimplifyBinaryComparison,
38-
BooleanSimplification) :: Nil
38+
BooleanSimplification,
39+
PruneFilters) :: Nil
3940
}
4041

4142
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
4243

44+
private def testConstraintsAfterJoin(
45+
x: LogicalPlan,
46+
y: LogicalPlan,
47+
expectedLeft: LogicalPlan,
48+
expectedRight: LogicalPlan,
49+
joinType: JoinType) = {
50+
val condition = Some("x.a".attr === "y.a".attr)
51+
val originalQuery = x.join(y, joinType, condition).analyze
52+
val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
53+
val optimized = Optimize.execute(originalQuery)
54+
comparePlans(optimized, correctAnswer)
55+
}
56+
4357
test("filter: filter out constraints in condition") {
4458
val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
4559
val correctAnswer = testRelation
@@ -196,13 +210,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
196210
test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
197211
val x = testRelation.subquery('x)
198212
val y = testRelation.subquery('y)
199-
val condition = Some("x.a".attr === "y.a".attr)
200-
val originalQuery = x.join(y, LeftSemi, condition).analyze
201-
val left = x.where(IsNotNull('a))
202-
val right = y.where(IsNotNull('a))
203-
val correctAnswer = left.join(right, LeftSemi, condition).analyze
204-
val optimized = Optimize.execute(originalQuery)
205-
comparePlans(optimized, correctAnswer)
213+
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi)
206214
}
207215

208216
test("SPARK-21479: Outer join after-join filters push down to null-supplying side") {
@@ -232,12 +240,27 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
232240
test("SPARK-21479: Outer join no filter push down to preserved side") {
233241
val x = testRelation.subquery('x)
234242
val y = testRelation.subquery('y)
235-
val condition = Some("x.a".attr === "y.a".attr)
236-
val originalQuery = x.join(y.where("y.a".attr === 1), LeftOuter, condition).analyze
237-
val left = x
238-
val right = y.where(IsNotNull('a) && 'a === 1)
239-
val correctAnswer = left.join(right, LeftOuter, condition).analyze
240-
val optimized = Optimize.execute(originalQuery)
241-
comparePlans(optimized, correctAnswer)
243+
testConstraintsAfterJoin(
244+
x, y.where("a".attr === 1),
245+
x, y.where(IsNotNull('a) && 'a === 1),
246+
LeftOuter)
247+
}
248+
249+
test("SPARK-23564: left anti join should filter out null join keys on right side") {
250+
val x = testRelation.subquery('x)
251+
val y = testRelation.subquery('y)
252+
testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti)
253+
}
254+
255+
test("SPARK-23564: left outer join should filter out null join keys on right side") {
256+
val x = testRelation.subquery('x)
257+
val y = testRelation.subquery('y)
258+
testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter)
259+
}
260+
261+
test("SPARK-23564: right outer join should filter out null join keys on left side") {
262+
val x = testRelation.subquery('x)
263+
val y = testRelation.subquery('y)
264+
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
242265
}
243266
}

0 commit comments

Comments
 (0)