@@ -361,154 +361,167 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
361361 _.containsAnyPattern(AND , OR , NOT ), ruleId) {
362362 case q : LogicalPlan => q.transformExpressionsUpWithPruning(
363363 _.containsAnyPattern(AND , OR , NOT ), ruleId) {
364- case TrueLiteral And e => e
365- case e And TrueLiteral => e
366- case FalseLiteral Or e => e
367- case e Or FalseLiteral => e
368-
369- case FalseLiteral And _ => FalseLiteral
370- case _ And FalseLiteral => FalseLiteral
371- case TrueLiteral Or _ => TrueLiteral
372- case _ Or TrueLiteral => TrueLiteral
373-
374- case a And b if Not (a).semanticEquals(b) =>
375- If ( IsNull (a), Literal .create( null , a.dataType), FalseLiteral )
376- case a And b if a.semanticEquals( Not (b)) =>
377- If ( IsNull (b), Literal .create( null , b.dataType), FalseLiteral )
378-
379- case a Or b if Not (a).semanticEquals(b) =>
380- If (IsNull (a), Literal .create(null , a.dataType), TrueLiteral )
381- case a Or b if a.semanticEquals(Not (b)) =>
382- If (IsNull (b), Literal .create(null , b.dataType), TrueLiteral )
383-
384- case a And b if a .semanticEquals(b) => a
385- case a Or b if a.semanticEquals(b) => a
386-
387- // The following optimizations are applicable only when the operands are not nullable,
388- // since the three-value logic of AND and OR are different in NULL handling.
389- // See the chart:
390- // +---------+---------+---------+---------+
391- // | operand | operand | OR | AND |
392- // +---------+---------+---------+---------+
393- // | TRUE | TRUE | TRUE | TRUE |
394- // | TRUE | FALSE | TRUE | FALSE |
395- // | FALSE | FALSE | FALSE | FALSE |
396- // | UNKNOWN | TRUE | TRUE | UNKNOWN |
397- // | UNKNOWN | FALSE | UNKNOWN | FALSE |
398- // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
399- // +---------+---------+---------+---------+
400-
401- // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
402- case a And (b Or c) if ! a.nullable && Not (a).semanticEquals(b) => And (a, c)
403- // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
404- case a And (b Or c) if ! a.nullable && Not (a).semanticEquals(c) => And (a, b)
405- // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
406- case (a Or b) And c if ! c.nullable && a.semanticEquals( Not (c)) => And (b, c)
407- // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus , c can't be nullable.
408- case (a Or b) And c if ! c.nullable && b.semanticEquals( Not (c)) => And (a, c)
409-
410- // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE ) = TRUE . Thus, a can't be nullable.
411- case a Or (b And c) if ! a .nullable && Not (a) .semanticEquals(b) => Or (a , c)
412- // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE ) = TRUE . Thus, a can't be nullable.
413- case a Or (b And c) if ! a .nullable && Not (a) .semanticEquals(c) => Or (a, b )
414- // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
415- case (a And b) Or c if ! c.nullable && a.semanticEquals( Not (c)) => Or (b, c)
416- // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus , c can't be nullable.
417- case (a And b) Or c if ! c.nullable && b.semanticEquals( Not (c)) => Or (a, c)
418-
419- // Common factor elimination for conjunction
420- case and @ (left And right) =>
421- // 1. Split left and right to get the disjunctive predicates,
422- // i.e. lhs = (a || b), rhs = (a || c)
423- // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
424- // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
425- // 4. If common is non-empty, apply the formula to get the optimized predicate:
426- // common || (ldiff && rdiff)
427- // 5. Else if common is empty, split left and right to get the conjunctive predicates.
428- // for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c )
429- // optimized predicate: (a && b && c)
430- val lhs = splitDisjunctivePredicates(left)
431- val rhs = splitDisjunctivePredicates(right )
432- val common = lhs.filter(e => rhs.exists(e.semanticEquals))
433- if (common.nonEmpty) {
434- val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals) )
435- val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals) )
436- if (ldiff.isEmpty || rdiff.isEmpty) {
437- // (a || b || c || ...) && (a || b) => (a || b )
438- common.reduce( Or )
439- } else {
440- // (a || b || c || ...) && (a || b || d || ...) =>
441- // a || b || ((c || ...) && (d || ...))
442- (common :+ And (ldiff.reduce( Or ), rdiff.reduce( Or ))).reduce( Or )
443- }
364+ actualExprTransformer
365+ }
366+ }
367+
/**
 * Expression-level rewrite used by this rule. Cases are tried from top to bottom, so their
 * relative order is significant. In order, the cases:
 *  - drop or short-circuit on boolean literals inside AND / OR,
 *  - collapse an operand combined with its own negation or with a semantically equal operand,
 *  - absorb a negated operand nested in the opposite junction (non-nullable operands only),
 *  - eliminate common factors between the two sides of AND / OR,
 *  - push NOT through literals, comparisons, junctions, double negation and null checks.
 */
val actualExprTransformer: PartialFunction[Expression, Expression] = {
  // TRUE is neutral for AND, FALSE is neutral for OR: keep the other operand.
  case And(TrueLiteral, rest) => rest
  case And(rest, TrueLiteral) => rest
  case Or(FalseLiteral, rest) => rest
  case Or(rest, FalseLiteral) => rest

  // FALSE absorbs AND, TRUE absorbs OR: the other operand is irrelevant.
  case And(FalseLiteral, _) => FalseLiteral
  case And(_, FalseLiteral) => FalseLiteral
  case Or(TrueLiteral, _) => TrueLiteral
  case Or(_, TrueLiteral) => TrueLiteral

  // x AND NOT(x): NULL when x is NULL, FALSE otherwise.
  case And(l, r) if Not(l).semanticEquals(r) => nullWhenNullElse(l, FalseLiteral)
  case And(l, r) if l.semanticEquals(Not(r)) => nullWhenNullElse(r, FalseLiteral)

  // x OR NOT(x): NULL when x is NULL, TRUE otherwise.
  case Or(l, r) if Not(l).semanticEquals(r) => nullWhenNullElse(l, TrueLiteral)
  case Or(l, r) if l.semanticEquals(Not(r)) => nullWhenNullElse(r, TrueLiteral)

  // Both operands are semantically equal: keep the left one.
  case And(l, r) if l.semanticEquals(r) => l
  case Or(l, r) if l.semanticEquals(r) => l

  // The rewrites below are only valid for a non-nullable operand, because AND and OR treat
  // NULL (UNKNOWN) differently under three-valued logic:
  //   +---------+---------+---------+---------+
  //   | operand | operand |   OR    |   AND   |
  //   +---------+---------+---------+---------+
  //   | TRUE    | TRUE    | TRUE    | TRUE    |
  //   | TRUE    | FALSE   | TRUE    | FALSE   |
  //   | FALSE   | FALSE   | FALSE   | FALSE   |
  //   | UNKNOWN | TRUE    | TRUE    | UNKNOWN |
  //   | UNKNOWN | FALSE   | UNKNOWN | FALSE   |
  //   | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
  //   +---------+---------+---------+---------+

  // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE, so x must be non-nullable.
  case And(x, Or(y, z)) if !x.nullable && Not(x).semanticEquals(y) => And(x, z)
  // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE, so x must be non-nullable.
  case And(x, Or(y, z)) if !x.nullable && Not(x).semanticEquals(z) => And(x, y)
  // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE, so z must be non-nullable.
  case And(Or(x, y), z) if !z.nullable && x.semanticEquals(Not(z)) => And(y, z)
  // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE, so z must be non-nullable.
  case And(Or(x, y), z) if !z.nullable && y.semanticEquals(Not(z)) => And(x, z)

  // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE, so x must be non-nullable.
  case Or(x, And(y, z)) if !x.nullable && Not(x).semanticEquals(y) => Or(x, z)
  // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE, so x must be non-nullable.
  case Or(x, And(y, z)) if !x.nullable && Not(x).semanticEquals(z) => Or(x, y)
  // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE, so z must be non-nullable.
  case Or(And(x, y), z) if !z.nullable && x.semanticEquals(Not(z)) => Or(y, z)
  // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE, so z must be non-nullable.
  case Or(And(x, y), z) if !z.nullable && y.semanticEquals(Not(z)) => Or(x, z)

  // Common factor elimination for conjunction:
  //   (a || b) && (a || c)     => a || (b && c)
  //   (a && b) && a && (a && c) => a && b && c
  case conj @ And(l, r) =>
    factorOutCommon(conj, l, r, splitDisjunctivePredicates, splitConjunctivePredicates, Or, And)

  // Common factor elimination for disjunction:
  //   (a && b) || (a && c)     => a && (b || c)
  //   (a || b) || a || (a || c) => a || b || c
  case disj @ Or(l, r) =>
    factorOutCommon(disj, l, r, splitConjunctivePredicates, splitDisjunctivePredicates, And, Or)

  case Not(TrueLiteral) => FalseLiteral
  case Not(FalseLiteral) => TrueLiteral

  // A negated comparison becomes the complementary comparison.
  case Not(GreaterThan(l, r)) => LessThanOrEqual(l, r)
  case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r)

  case Not(LessThan(l, r)) => GreaterThanOrEqual(l, r)
  case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r)

  // SPARK-54881: De Morgan's laws. The NOTs pushed onto the children are simplified right away,
  // before the new junction is attached to the main tree; this keeps the rule idempotent and
  // avoids an extra rule iteration.
  case Not(Or(l, r)) => simplifyNegations(And(Not(l), Not(r)))
  case Not(And(l, r)) => simplifyNegations(Or(Not(l), Not(r)))

  case Not(Not(inner)) => inner

  case Not(IsNull(child)) => IsNotNull(child)
  case Not(IsNotNull(child)) => IsNull(child)
}

/**
 * Builds `If(IsNull(operand), NULL, otherwise)`, where the NULL literal carries the data type of
 * `operand`.
 */
private def nullWhenNullElse(operand: Expression, otherwise: Expression): Expression =
  If(IsNull(operand), Literal.create(null, operand.dataType), otherwise)

/**
 * Re-applies [[actualExprTransformer]] top-down over `junction`, visiting only subtrees that
 * contain a NOT pattern.
 */
private def simplifyNegations(junction: Expression): Expression =
  junction.transformDownWithPruning(_.containsPattern(NOT), ruleId)(actualExprTransformer)

/**
 * Common-factor elimination shared by the AND and OR cases.
 *
 * `outer` is the junction being simplified (AND for a conjunction) and `inner` is the opposite
 * one (OR for a conjunction). Steps:
 *  1. Split both sides with `splitInner` and collect the terms of `left` that also occur
 *     (semantically) in `right`.
 *  2. If there are shared terms, remove them from both sides. When either side has nothing
 *     left, the result is just the shared terms joined with `inner`; otherwise it is
 *     `shared... inner (leftRest outer rightRest)`.
 *  3. If there are no shared terms, split both sides with `splitOuter`. When that yields
 *     duplicates, rebuild a balanced `outer` predicate from the distinct terms; otherwise return
 *     `original` untouched.
 *
 * @param original   the junction node being rewritten, returned as-is when nothing is shared
 * @param left       left operand of `original`
 * @param right      right operand of `original`
 * @param splitInner splits an expression into the operands of the opposite junction
 * @param splitOuter splits an expression into the operands of the same junction as `original`
 * @param inner      constructor of the opposite junction
 * @param outer      constructor of the same junction as `original`
 */
private def factorOutCommon(
    original: Expression,
    left: Expression,
    right: Expression,
    splitInner: Expression => Seq[Expression],
    splitOuter: Expression => Seq[Expression],
    inner: (Expression, Expression) => Expression,
    outer: (Expression, Expression) => Expression): Expression = {
  val leftTerms = splitInner(left)
  val rightTerms = splitInner(right)
  val shared = leftTerms.filter(term => rightTerms.exists(term.semanticEquals))
  if (shared.isEmpty) {
    // Nothing shared across the opposite junction; look for duplicates within this junction.
    val flattened = splitOuter(left) ++ splitOuter(right)
    val unique = ExpressionSet(flattened)
    if (flattened.size == unique.size) {
      // No duplicates either: keep the original predicate.
      original
    } else {
      buildBalancedPredicate(unique.toSeq, outer)
    }
  } else {
    val leftRest = leftTerms.filterNot(term => shared.exists(term.semanticEquals))
    val rightRest = rightTerms.filterNot(term => shared.exists(term.semanticEquals))
    if (leftRest.isEmpty || rightRest.isEmpty) {
      // One side consists solely of shared terms, so it subsumes the other side.
      shared.reduce(inner)
    } else {
      (shared :+ outer(leftRest.reduce(inner), rightRest.reduce(inner))).reduce(inner)
    }
  }
}
513526}
514527
0 commit comments