Skip to content

Commit 85383d2

Browse files
gatorsmiledbtsai
authored andcommitted
[SPARK-25860][SPARK-26107][FOLLOW-UP] Rule ReplaceNullWithFalseInPredicate
## What changes were proposed in this pull request? Based on apache#22857 and apache#23079, this PR did a few updates - Limit the data types of NULL to Boolean. - Limit the input data type of replaceNullWithFalse to Boolean; throw an exception in the testing mode. - Create a new file for the rule ReplaceNullWithFalseInPredicate - Update the description of this rule. ## How was this patch tested? Added a test case Closes apache#23139 from gatorsmile/followupSpark-25860. Authored-by: gatorsmile <[email protected]> Signed-off-by: DB Tsai <[email protected]>
1 parent 1c487f7 commit 85383d2

File tree

3 files changed

+119
-68
lines changed

3 files changed

+119
-68
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
21+
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
22+
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
23+
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
import org.apache.spark.sql.types.BooleanType
26+
import org.apache.spark.util.Utils
27+
28+
29+
/**
30+
* A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search
31+
* condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator
32+
* "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is
33+
* semantically equivalent to `FalseLiteral` when evaluating the whole search condition.
34+
*
35+
* Please note that FALSE and NULL are not exchangeable in most cases, when the search condition
36+
* contains NOT and NULL-tolerant expressions. Thus, the rule is very conservative and applicable
37+
* in very limited cases.
38+
*
39+
* For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`.
40+
*
41+
* Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
42+
* this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
43+
* `Filter(FalseLiteral)`.
44+
*
45+
* Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch
46+
* conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions.
47+
*
48+
* For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified
49+
* into `Project(Literal(2))`.
50+
*/
51+
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
52+
53+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
54+
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
55+
case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
56+
case p: LogicalPlan => p transformExpressions {
57+
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
58+
case cw @ CaseWhen(branches, _) =>
59+
val newBranches = branches.map { case (cond, value) =>
60+
replaceNullWithFalse(cond) -> value
61+
}
62+
cw.copy(branches = newBranches)
63+
case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
64+
val newLambda = lf.copy(function = replaceNullWithFalse(func))
65+
af.copy(function = newLambda)
66+
case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
67+
val newLambda = lf.copy(function = replaceNullWithFalse(func))
68+
ae.copy(function = newLambda)
69+
case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
70+
val newLambda = lf.copy(function = replaceNullWithFalse(func))
71+
mf.copy(function = newLambda)
72+
}
73+
}
74+
75+
/**
76+
* Recursively traverse the Boolean-type expression to replace
77+
* `Literal(null, BooleanType)` with `FalseLiteral`, if possible.
78+
*
79+
* Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
80+
* an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or
81+
* `Literal(null, BooleanType)`.
82+
*/
83+
private def replaceNullWithFalse(e: Expression): Expression = e match {
84+
case Literal(null, BooleanType) =>
85+
FalseLiteral
86+
case And(left, right) =>
87+
And(replaceNullWithFalse(left), replaceNullWithFalse(right))
88+
case Or(left, right) =>
89+
Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
90+
case cw: CaseWhen if cw.dataType == BooleanType =>
91+
val newBranches = cw.branches.map { case (cond, value) =>
92+
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
93+
}
94+
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
95+
CaseWhen(newBranches, newElseValue)
96+
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
97+
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
98+
case e if e.dataType == BooleanType =>
99+
e
100+
case e =>
101+
val message = "Expected a Boolean type expression in replaceNullWithFalse, " +
102+
s"but got the type `${e.dataType.catalogString}` in `${e.sql}`."
103+
if (Utils.isTesting) {
104+
throw new IllegalArgumentException(message)
105+
} else {
106+
logWarning(message)
107+
e
108+
}
109+
}
110+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -736,69 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] {
736736
flattenConcats(concat)
737737
}
738738
}
739-
740-
/**
741-
* A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
742-
*
743-
* This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
744-
* in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
745-
*
746-
* For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
747-
*
748-
* Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
749-
* this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
750-
* `Filter(FalseLiteral)`.
751-
*
752-
* As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
753-
* benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
754-
* can be simplified into `Project(Literal(2))`.
755-
*
756-
* As a result, many unnecessary computations can be removed in the query optimization phase.
757-
*/
758-
object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
759-
760-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
761-
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
762-
case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
763-
case p: LogicalPlan => p transformExpressions {
764-
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
765-
case cw @ CaseWhen(branches, _) =>
766-
val newBranches = branches.map { case (cond, value) =>
767-
replaceNullWithFalse(cond) -> value
768-
}
769-
cw.copy(branches = newBranches)
770-
case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
771-
val newLambda = lf.copy(function = replaceNullWithFalse(func))
772-
af.copy(function = newLambda)
773-
case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
774-
val newLambda = lf.copy(function = replaceNullWithFalse(func))
775-
ae.copy(function = newLambda)
776-
case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
777-
val newLambda = lf.copy(function = replaceNullWithFalse(func))
778-
mf.copy(function = newLambda)
779-
}
780-
}
781-
782-
/**
783-
* Recursively replaces `Literal(null, _)` with `FalseLiteral`.
784-
*
785-
* Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
786-
* an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
787-
*/
788-
private def replaceNullWithFalse(e: Expression): Expression = e match {
789-
case cw: CaseWhen if cw.dataType == BooleanType =>
790-
val newBranches = cw.branches.map { case (cond, value) =>
791-
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
792-
}
793-
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
794-
CaseWhen(newBranches, newElseValue)
795-
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
796-
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
797-
case And(left, right) =>
798-
And(replaceNullWithFalse(left), replaceNullWithFalse(right))
799-
case Or(left, right) =>
800-
Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
801-
case Literal(null, _) => FalseLiteral
802-
case _ => e
803-
}
804-
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,15 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
4444
private val anotherTestRelation = LocalRelation('d.int)
4545

4646
test("replace null inside filter and join conditions") {
47-
testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
48-
testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
47+
testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
48+
testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
49+
}
50+
51+
test("Not expected type - replaceNullWithFalse") {
52+
val e = intercept[IllegalArgumentException] {
53+
testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral)
54+
}.getMessage
55+
assert(e.contains("but got the type `int` in `CAST(NULL AS INT)"))
4956
}
5057

5158
test("replace null in branches of If") {

0 commit comments

Comments
 (0)