Skip to content

Commit ffb362a

Browse files
committed
[SPARK-19712][SQL][FOLLOW-UP] reduce code duplication
## What changes were proposed in this pull request? abstract some common code into a method. ## How was this patch tested? existing tests Closes apache#24281 from cloud-fan/minor. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent d04a737 commit ffb362a

File tree

1 file changed

+47
-112
lines changed

1 file changed

+47
-112
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala

Lines changed: 47 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
3535
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
3636
// LeftSemi/LeftAnti over Project
3737
case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
38-
if pList.forall(_.deterministic) &&
38+
if pList.forall(_.deterministic) &&
3939
!pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
4040
canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
4141
if (joinCond.isEmpty) {
@@ -52,101 +52,29 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
5252
}
5353

5454
// LeftSemi/LeftAnti over Aggregate
55-
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
56-
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
55+
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
56+
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
5757
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
58-
if (joinCond.isEmpty) {
59-
// No join condition, just push down Join below Aggregate
60-
agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
61-
} else {
62-
val aliasMap = PushDownPredicate.getAliasMap(agg)
63-
64-
// For each join condition, expand the alias and check if the condition can be evaluated
65-
// using attributes produced by the aggregate operator's child operator.
66-
val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
67-
val replaced = replaceAlias(cond, aliasMap)
68-
cond.references.nonEmpty &&
69-
replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
70-
}
71-
72-
// Check if the remaining predicates do not contain columns from the right
73-
// hand side of the join. Since the remaining predicates will be kept
74-
// as a filter over aggregate, this check is necessary after the left semi
75-
// or left anti join is moved below aggregate. The reason is, for this kind
76-
// of join, we only output from the left leg of the join.
77-
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
78-
79-
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
80-
val pushDownPredicate = pushDown.reduce(And)
81-
val replaced = replaceAlias(pushDownPredicate, aliasMap)
82-
val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint))
83-
// If there is no more filter to stay up, just return the Aggregate over Join.
84-
// Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)".
85-
if (stayUp.isEmpty) {
86-
newAgg
87-
} else {
88-
joinType match {
89-
// In case of Left semi join, the part of the join condition which does not refer to
90-
// to child attributes of the aggregate operator are kept as a Filter over window.
91-
case LeftSemi => Filter(stayUp.reduce(And), newAgg)
92-
// In case of left anti join, the join is pushed down when the entire join condition
93-
// is eligible to be pushed down to preserve the semantics of left anti join.
94-
case _ => join
95-
}
96-
}
97-
} else {
98-
// The join condition is not a subset of the Aggregate's GROUP BY columns,
99-
// no push down.
100-
join
101-
}
58+
val aliasMap = PushDownPredicate.getAliasMap(agg)
59+
val canPushDownPredicate = (predicate: Expression) => {
60+
val replaced = replaceAlias(predicate, aliasMap)
61+
predicate.references.nonEmpty &&
62+
replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
63+
}
64+
val makeJoinCondition = (predicates: Seq[Expression]) => {
65+
replaceAlias(predicates.reduce(And), aliasMap)
10266
}
67+
pushDownJoin(join, canPushDownPredicate, makeJoinCondition)
10368

10469
// LeftSemi/LeftAnti over Window
105-
case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
106-
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
107-
if (joinCond.isEmpty) {
108-
// No join condition, just push down Join below Window
109-
w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint))
110-
} else {
111-
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++
112-
rightOp.outputSet
113-
114-
val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
115-
cond.references.subsetOf(partitionAttrs)
116-
}
117-
118-
// Check if the remaining predicates do not contain columns from the right
119-
// hand side of the join. Since the remaining predicates will be kept
120-
// as a filter over window, this check is necessary after the left semi
121-
// or left anti join is moved below window. The reason is, for this kind
122-
// of join, we only output from the left leg of the join.
123-
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
124-
125-
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
126-
val predicate = pushDown.reduce(And)
127-
val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint))
128-
if (stayUp.isEmpty) {
129-
newPlan
130-
} else {
131-
joinType match {
132-
// In case of Left semi join, the part of the join condition which does not refer to
133-
// to partition attributes of the window operator are kept as a Filter over window.
134-
case LeftSemi => Filter(stayUp.reduce(And), newPlan)
135-
// In case of left anti join, the join is pushed down when the entire join condition
136-
// is eligible to be pushed down to preserve the semantics of left anti join.
137-
case _ => join
138-
}
139-
}
140-
} else {
141-
// The join condition is not a subset of the Window's PARTITION BY clause,
142-
// no push down.
143-
join
144-
}
145-
}
70+
case join @ Join(w: Window, rightOp, LeftSemiOrAnti(_), _, _)
71+
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
72+
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet
73+
pushDownJoin(join, _.references.subsetOf(partitionAttrs), _.reduce(And))
14674

14775
// LeftSemi/LeftAnti over Union
148-
case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
149-
if canPushThroughCondition(union.children, joinCond, rightOp) =>
76+
case Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
77+
if canPushThroughCondition(union.children, joinCond, rightOp) =>
15078
if (joinCond.isEmpty) {
15179
// Push down the Join below Union
15280
val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) }
@@ -165,11 +93,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
16593
}
16694

16795
// LeftSemi/LeftAnti over UnaryNode
168-
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
169-
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
170-
pushDownJoin(join, u.child) { joinCond =>
171-
u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint)))
172-
}
96+
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), _, _)
97+
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
98+
val validAttrs = u.child.outputSet ++ rightOp.outputSet
99+
pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And))
173100
}
174101

175102
/**
@@ -192,35 +119,43 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
192119
}
193120
}
194121

195-
196122
private def pushDownJoin(
197123
join: Join,
198-
grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = {
124+
canPushDownPredicate: Expression => Boolean,
125+
makeJoinCondition: Seq[Expression] => Expression): LogicalPlan = {
126+
assert(join.left.children.length == 1)
127+
199128
if (join.condition.isEmpty) {
200-
insertJoin(None)
129+
join.left.withNewChildren(Seq(join.copy(left = join.left.children.head)))
201130
} else {
202131
val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
203-
.partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
132+
.partition(canPushDownPredicate)
133+
134+
// Check if the remaining predicates do not contain columns from the right hand side of the
135+
// join. Since the remaining predicates will be kept as a filter over the operator under join,
136+
// this check is necessary after the left-semi/anti join is pushed down. The reason is, for
137+
// this kind of join, we only output from the left leg of the join.
138+
val referRightSideCols = AttributeSet(stayUp.toSet).intersect(join.right.outputSet).nonEmpty
204139

205-
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet)
206-
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
207-
val newChild = insertJoin(Option(pushDown.reduceLeft(And)))
208-
if (stayUp.nonEmpty) {
140+
if (pushDown.isEmpty || referRightSideCols) {
141+
join
142+
} else {
143+
val newPlan = join.left.withNewChildren(Seq(join.copy(
144+
left = join.left.children.head, condition = Some(makeJoinCondition(pushDown)))))
145+
// If there is no more filter to stay up, return the new plan that has join pushed down.
146+
if (stayUp.isEmpty) {
147+
newPlan
148+
} else {
209149
join.joinType match {
210150
// In case of Left semi join, the part of the join condition which does not refer to
211-
// to attributes of the grandchild are kept as a Filter over window.
212-
case LeftSemi => Filter(stayUp.reduce(And), newChild)
213-
// In case of left anti join, the join is pushed down when the entire join condition
214-
// is eligible to be pushed down to preserve the semantics of left anti join.
151+
// to attributes of the grandchild are kept as a Filter above.
152+
case LeftSemi => Filter(stayUp.reduce(And), newPlan)
153+
// In case of left-anti join, the join is pushed down only when the entire join
154+
// condition is eligible to be pushed down to preserve the semantics of left-anti join.
215155
case _ => join
216156
}
217-
} else {
218-
newChild
219157
}
220-
} else {
221-
join
222158
}
223159
}
224160
}
225161
}
226-

0 commit comments

Comments
 (0)