Skip to content

Commit c3843a0

Browse files
ahshahidpeter-toth
authored andcommitted
[SPARK-54881][SQL] Improve BooleanSimplification to handle negation of conjunction and disjunction in one pass
Fix to simplify boolean expression of form like !(expr1 || expr2) in a single pass, where expr1 and expr2 are binary comparison expression ### What changes were proposed in this pull request? In the rule BooleanSimplification , following two changes are done: 1) The current partial function passed as lambda to the transformExpressionUp api, is stored in a "val actualExprTransformer" 2) Instead of passing the lambda to the transformExpressionUp, the val actualExprTransformer, is passed. Till this point the code change is mere refactoring. The main change in the logic is 3) for the two cases case Not(a Or b) => And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) { actualExprTransformer } case Not(a And b) => Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) { actualExprTransformer } The new child node of AND and OR, are immediately acted upon by the partial function of expression transformer using transformExpressionDown, which will be efficient as the traversal on subtree will stop immediately if the node does not contain any NOT operator. ### Why are the changes needed? The change is needed because in the case of tramsformUp, the idempotency is not achieved in the optimal way ( single pass compared to double pass). The issue arises due to rule transforming Not (A || B) => (Not(A) AND Not(B)) Because the new child has added Not operations, they are not acted in that pass due to transformUp. With transformDown, the new children with Not, would be simplified in that pass itself. Please note that merely changing transformExpressionUp to transformExpressionDown, though will fix this issue, it will break idempotency for other cases ( as seen by failure in ConstantFoldingSuite. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added bug test ### Was this patch authored or co-authored using generative AI tooling? No Closes #53658 from ahshahid/SPARK-54881. Authored-by: Asif Hussain Shahid <asif.shahid@gmail.com> Signed-off-by: Peter Toth <peter.toth@gmail.com>
1 parent 5bcbc54 commit c3843a0

File tree

2 files changed

+181
-137
lines changed

2 files changed

+181
-137
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 150 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -361,154 +361,167 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
361361
_.containsAnyPattern(AND, OR, NOT), ruleId) {
362362
case q: LogicalPlan => q.transformExpressionsUpWithPruning(
363363
_.containsAnyPattern(AND, OR, NOT), ruleId) {
364-
case TrueLiteral And e => e
365-
case e And TrueLiteral => e
366-
case FalseLiteral Or e => e
367-
case e Or FalseLiteral => e
368-
369-
case FalseLiteral And _ => FalseLiteral
370-
case _ And FalseLiteral => FalseLiteral
371-
case TrueLiteral Or _ => TrueLiteral
372-
case _ Or TrueLiteral => TrueLiteral
373-
374-
case a And b if Not(a).semanticEquals(b) =>
375-
If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
376-
case a And b if a.semanticEquals(Not(b)) =>
377-
If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
378-
379-
case a Or b if Not(a).semanticEquals(b) =>
380-
If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
381-
case a Or b if a.semanticEquals(Not(b)) =>
382-
If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
383-
384-
case a And b if a.semanticEquals(b) => a
385-
case a Or b if a.semanticEquals(b) => a
386-
387-
// The following optimizations are applicable only when the operands are not nullable,
388-
// since the three-value logic of AND and OR are different in NULL handling.
389-
// See the chart:
390-
// +---------+---------+---------+---------+
391-
// | operand | operand | OR | AND |
392-
// +---------+---------+---------+---------+
393-
// | TRUE | TRUE | TRUE | TRUE |
394-
// | TRUE | FALSE | TRUE | FALSE |
395-
// | FALSE | FALSE | FALSE | FALSE |
396-
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
397-
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
398-
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
399-
// +---------+---------+---------+---------+
400-
401-
// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
402-
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
403-
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
404-
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
405-
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
406-
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
407-
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
408-
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
409-
410-
// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
411-
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
412-
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
413-
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
414-
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
415-
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
416-
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
417-
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
418-
419-
// Common factor elimination for conjunction
420-
case and @ (left And right) =>
421-
// 1. Split left and right to get the disjunctive predicates,
422-
// i.e. lhs = (a || b), rhs = (a || c)
423-
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
424-
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
425-
// 4. If common is non-empty, apply the formula to get the optimized predicate:
426-
// common || (ldiff && rdiff)
427-
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
428-
// for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c)
429-
// optimized predicate: (a && b && c)
430-
val lhs = splitDisjunctivePredicates(left)
431-
val rhs = splitDisjunctivePredicates(right)
432-
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
433-
if (common.nonEmpty) {
434-
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
435-
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
436-
if (ldiff.isEmpty || rdiff.isEmpty) {
437-
// (a || b || c || ...) && (a || b) => (a || b)
438-
common.reduce(Or)
439-
} else {
440-
// (a || b || c || ...) && (a || b || d || ...) =>
441-
// a || b || ((c || ...) && (d || ...))
442-
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
443-
}
364+
actualExprTransformer
365+
}
366+
}
367+
368+
val actualExprTransformer: PartialFunction[Expression, Expression] = {
369+
case TrueLiteral And e => e
370+
case e And TrueLiteral => e
371+
case FalseLiteral Or e => e
372+
case e Or FalseLiteral => e
373+
374+
case FalseLiteral And _ => FalseLiteral
375+
case _ And FalseLiteral => FalseLiteral
376+
case TrueLiteral Or _ => TrueLiteral
377+
case _ Or TrueLiteral => TrueLiteral
378+
379+
case a And b if Not(a).semanticEquals(b) =>
380+
If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
381+
case a And b if a.semanticEquals(Not(b)) =>
382+
If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
383+
384+
case a Or b if Not(a).semanticEquals(b) =>
385+
If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
386+
case a Or b if a.semanticEquals(Not(b)) =>
387+
If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
388+
389+
case a And b if a.semanticEquals(b) => a
390+
case a Or b if a.semanticEquals(b) => a
391+
392+
// The following optimizations are applicable only when the operands are not nullable,
393+
// since the three-value logic of AND and OR are different in NULL handling.
394+
// See the chart:
395+
// +---------+---------+---------+---------+
396+
// | operand | operand | OR | AND |
397+
// +---------+---------+---------+---------+
398+
// | TRUE | TRUE | TRUE | TRUE |
399+
// | TRUE | FALSE | TRUE | FALSE |
400+
// | FALSE | FALSE | FALSE | FALSE |
401+
// | UNKNOWN | TRUE | TRUE | UNKNOWN |
402+
// | UNKNOWN | FALSE | UNKNOWN | FALSE |
403+
// | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
404+
// +---------+---------+---------+---------+
405+
406+
// (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
407+
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
408+
// (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
409+
case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
410+
// ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
411+
case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
412+
// ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
413+
case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
414+
415+
// (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
416+
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
417+
// (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
418+
case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
419+
// ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
420+
case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
421+
// ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
422+
case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
423+
424+
// Common factor elimination for conjunction
425+
case and @ (left And right) =>
426+
// 1. Split left and right to get the disjunctive predicates,
427+
// i.e. lhs = (a || b), rhs = (a || c)
428+
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
429+
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
430+
// 4. If common is non-empty, apply the formula to get the optimized predicate:
431+
// common || (ldiff && rdiff)
432+
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
433+
// for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c)
434+
// optimized predicate: (a && b && c)
435+
val lhs = splitDisjunctivePredicates(left)
436+
val rhs = splitDisjunctivePredicates(right)
437+
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
438+
if (common.nonEmpty) {
439+
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
440+
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
441+
if (ldiff.isEmpty || rdiff.isEmpty) {
442+
// (a || b || c || ...) && (a || b) => (a || b)
443+
common.reduce(Or)
444444
} else {
445-
// No common factors from disjunctive predicates, reduce common factor from conjunction
446-
val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
447-
val distinct = ExpressionSet(all)
448-
if (all.size == distinct.size) {
449-
// No common factors, return the original predicate
450-
and
451-
} else {
452-
// (a && b) && a && (a && c) => a && b && c
453-
buildBalancedPredicate(distinct.toSeq, And)
454-
}
445+
// (a || b || c || ...) && (a || b || d || ...) =>
446+
// a || b || ((c || ...) && (d || ...))
447+
(common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
455448
}
449+
} else {
450+
// No common factors from disjunctive predicates, reduce common factor from conjunction
451+
val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
452+
val distinct = ExpressionSet(all)
453+
if (all.size == distinct.size) {
454+
// No common factors, return the original predicate
455+
and
456+
} else {
457+
// (a && b) && a && (a && c) => a && b && c
458+
buildBalancedPredicate(distinct.toSeq, And)
459+
}
460+
}
456461

457-
// Common factor elimination for disjunction
458-
case or @ (left Or right) =>
459-
// 1. Split left and right to get the conjunctive predicates,
460-
// i.e. lhs = (a && b), rhs = (a && c)
461-
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
462-
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
463-
// 4. If common is non-empty, apply the formula to get the optimized predicate:
464-
// common && (ldiff || rdiff)
465-
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
466-
// for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c)
467-
// optimized predicate: (a || b || c)
468-
val lhs = splitConjunctivePredicates(left)
469-
val rhs = splitConjunctivePredicates(right)
470-
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
471-
if (common.nonEmpty) {
472-
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
473-
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
474-
if (ldiff.isEmpty || rdiff.isEmpty) {
475-
// (a && b) || (a && b && c && ...) => a && b
476-
common.reduce(And)
477-
} else {
478-
// (a && b && c && ...) || (a && b && d && ...) =>
479-
// a && b && ((c && ...) || (d && ...))
480-
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
481-
}
462+
// Common factor elimination for disjunction
463+
case or @ (left Or right) =>
464+
// 1. Split left and right to get the conjunctive predicates,
465+
// i.e. lhs = (a && b), rhs = (a && c)
466+
// 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
467+
// 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
468+
// 4. If common is non-empty, apply the formula to get the optimized predicate:
469+
// common && (ldiff || rdiff)
470+
// 5. Else if common is empty, split left and right to get the conjunctive predicates.
471+
// for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c)
472+
// optimized predicate: (a || b || c)
473+
val lhs = splitConjunctivePredicates(left)
474+
val rhs = splitConjunctivePredicates(right)
475+
val common = lhs.filter(e => rhs.exists(e.semanticEquals))
476+
if (common.nonEmpty) {
477+
val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
478+
val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
479+
if (ldiff.isEmpty || rdiff.isEmpty) {
480+
// (a && b) || (a && b && c && ...) => a && b
481+
common.reduce(And)
482482
} else {
483-
// No common factors in conjunctive predicates, reduce common factor from disjunction
484-
val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
485-
val distinct = ExpressionSet(all)
486-
if (all.size == distinct.size) {
487-
// No common factors, return the original predicate
488-
or
489-
} else {
490-
// (a || b) || a || (a || c) => a || b || c
491-
buildBalancedPredicate(distinct.toSeq, Or)
492-
}
483+
// (a && b && c && ...) || (a && b && d && ...) =>
484+
// a && b && ((c && ...) || (d && ...))
485+
(common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
493486
}
487+
} else {
488+
// No common factors in conjunctive predicates, reduce common factor from disjunction
489+
val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
490+
val distinct = ExpressionSet(all)
491+
if (all.size == distinct.size) {
492+
// No common factors, return the original predicate
493+
or
494+
} else {
495+
// (a || b) || a || (a || c) => a || b || c
496+
buildBalancedPredicate(distinct.toSeq, Or)
497+
}
498+
}
494499

495-
case Not(TrueLiteral) => FalseLiteral
496-
case Not(FalseLiteral) => TrueLiteral
500+
case Not(TrueLiteral) => FalseLiteral
501+
case Not(FalseLiteral) => TrueLiteral
497502

498-
case Not(a GreaterThan b) => LessThanOrEqual(a, b)
499-
case Not(a GreaterThanOrEqual b) => LessThan(a, b)
503+
case Not(a GreaterThan b) => LessThanOrEqual(a, b)
504+
case Not(a GreaterThanOrEqual b) => LessThan(a, b)
500505

501-
case Not(a LessThan b) => GreaterThanOrEqual(a, b)
502-
case Not(a LessThanOrEqual b) => GreaterThan(a, b)
506+
case Not(a LessThan b) => GreaterThanOrEqual(a, b)
507+
case Not(a LessThanOrEqual b) => GreaterThan(a, b)
503508

504-
case Not(a Or b) => And(Not(a), Not(b))
505-
case Not(a And b) => Or(Not(a), Not(b))
509+
// SPARK-54881: push down the NOT operators on children, before attaching the junction Node
510+
// to the main tree. This ensures idempotency in an optimal way and avoids an extra rule
511+
// iteration.
512+
case Not(a Or b) =>
513+
And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) {
514+
actualExprTransformer
515+
}
516+
case Not(a And b) =>
517+
Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), ruleId) {
518+
actualExprTransformer
519+
}
506520

507-
case Not(Not(e)) => e
521+
case Not(Not(e)) => e
508522

509-
case Not(IsNull(e)) => IsNotNull(e)
510-
case Not(IsNotNull(e)) => IsNull(e)
511-
}
523+
case Not(IsNull(e)) => IsNotNull(e)
524+
case Not(IsNotNull(e)) => IsNull(e)
512525
}
513526
}
514527

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,37 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper {
291291
checkCondition(Not(IsNull($"b")), IsNotNull($"b"))
292292
}
293293

294+
test("SPARK-54881: simplify Not(Expr) in single pass") {
295+
def executeRuleOnce(exprToTest: Expression, optimizedExprExpected: Expression): Unit = {
296+
val planAfterRuleApp = BooleanSimplification.apply(testRelation.where(exprToTest).analyze)
297+
val expectedOptPlan = testRelation.where(optimizedExprExpected).analyze
298+
comparePlans(expectedOptPlan, planAfterRuleApp)
299+
}
300+
// check simplify Not(A <= B OR A >= B) to (a > b AND a < b) in single pass
301+
executeRuleOnce(
302+
Not(($"a" <= $"b") || ($"a" >= $"b")),
303+
$"a" > $"b" && $"a" < $"b"
304+
)
305+
306+
// check simplify Not((expr1 OR expr2) OR (expr3 AND expr4)) in single pass
307+
executeRuleOnce(
308+
Not(($"a" <= $"b" || $"c" > $"a" + 4) || ($"a" >= $"b" && $"c" < $"a")),
309+
And(
310+
And($"a" > $"b", $"c" <= $"a" + 4),
311+
Or($"a" < $"b", $"c" >= $"a")
312+
)
313+
)
314+
315+
// check simplify Not((expr1 OR expr2) AND (expr3 OR expr4)) in single pass
316+
executeRuleOnce(
317+
Not(($"a" <= $"b" || $"c" > $"a" + 4) && ($"a" >= $"b" || $"c" < $"a")),
318+
Or(
319+
And($"a" > $"b", $"c" <= $"a" + 4),
320+
And($"a" < $"b", $"c" >= $"a")
321+
)
322+
)
323+
}
324+
294325
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
295326
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
296327
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)

0 commit comments

Comments
 (0)