Skip to content

Commit bc9f9b4

Browse files
aokolnychyidbtsai
authored andcommitted
[SPARK-25860][SQL] Replace Literal(null, _) with FalseLiteral whenever possible
## What changes were proposed in this pull request? This PR proposes a new optimization rule that replaces `Literal(null, _)` with `FalseLiteral` in conditions in `Join` and `Filter`, predicates in `If`, conditions in `CaseWhen`. The idea is that some expressions evaluate to `false` if the underlying expression is `null` (as an example see `GeneratePredicate$create` or `doGenCode` and `eval` methods in `If` and `CaseWhen`). Therefore, we can replace `Literal(null, _)` with `FalseLiteral`, which can lead to more optimizations later on. Let’s consider a few examples. ``` val df = spark.range(1, 100).select($"id".as("l"), ($"id" > 50).as("b")) df.createOrReplaceTempView("t") df.createOrReplaceTempView("p") ``` **Case 1** ``` spark.sql("SELECT * FROM t WHERE if(l > 10, false, NULL)").explain(true) // without the new rule … == Optimized Logical Plan == Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- Filter if ((id#0L > 10)) false else null +- Range (1, 100, step=1, splits=Some(12)) == Physical Plan == *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- *(1) Filter if ((id#0L > 10)) false else null +- *(1) Range (1, 100, step=1, splits=12) // with the new rule … == Optimized Logical Plan == LocalRelation <empty>, [l#2L, s#3] == Physical Plan == LocalTableScan <empty>, [l#2L, s#3] ``` **Case 2** ``` spark.sql("SELECT * FROM t WHERE CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END”).explain(true) // without the new rule ... == Optimized Logical Plan == Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END +- Range (1, 100, step=1, splits=Some(12)) == Physical Plan == *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] +- *(1) Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END +- *(1) Range (1, 100, step=1, splits=12) // with the new rule ... == Optimized Logical Plan == LocalRelation <empty>, [l#2L, s#3] == Physical Plan == LocalTableScan <empty>, [l#2L, s#3] ``` **Case 3** ``` spark.sql("SELECT * FROM t JOIN p ON IF(t.l > p.l, null, false)").explain(true) // without the new rule ... == Optimized Logical Plan == Join Inner, if ((l#2L > l#37L)) null else false :- Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] : +- Range (1, 100, step=1, splits=Some(12)) +- Project [id#0L AS l#37L, cast(id#0L as string) AS s#38] +- Range (1, 100, step=1, splits=Some(12)) == Physical Plan == BroadcastNestedLoopJoin BuildRight, Inner, if ((l#2L > l#37L)) null else false :- *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3] : +- *(1) Range (1, 100, step=1, splits=12) +- BroadcastExchange IdentityBroadcastMode +- *(2) Project [id#0L AS l#37L, cast(id#0L as string) AS s#38] +- *(2) Range (1, 100, step=1, splits=12) // with the new rule ... == Optimized Logical Plan == LocalRelation <empty>, [l#2L, s#3, l#37L, s#38] ``` ## How was this patch tested? This PR comes with a set of dedicated tests. Closes apache#22857 from aokolnychyi/spark-25860. Authored-by: Anton Okolnychyi <[email protected]> Signed-off-by: DB Tsai <[email protected]>
1 parent 68dde34 commit bc9f9b4

File tree

5 files changed

+454
-2
lines changed

5 files changed

+454
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
8484
SimplifyConditionals,
8585
RemoveDispensableExpressions,
8686
SimplifyBinaryComparison,
87+
ReplaceNullWithFalse,
8788
PruneFilters,
8889
EliminateSorts,
8990
SimplifyCasts,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] {
736736
flattenConcats(concat)
737737
}
738738
}
739+
740+
/**
741+
* A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
742+
*
743+
* This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
744+
* in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
745+
*
746+
* For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
747+
*
748+
* Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
749+
* this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
750+
* `Filter(FalseLiteral)`.
751+
*
752+
* As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
753+
* benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
754+
* can be simplified into `Project(Literal(2))`.
755+
*
756+
* As a result, many unnecessary computations can be removed in the query optimization phase.
757+
*/
758+
object ReplaceNullWithFalse extends Rule[LogicalPlan] {
759+
760+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
761+
case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
762+
case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
763+
case p: LogicalPlan => p transformExpressions {
764+
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
765+
case cw @ CaseWhen(branches, _) =>
766+
val newBranches = branches.map { case (cond, value) =>
767+
replaceNullWithFalse(cond) -> value
768+
}
769+
cw.copy(branches = newBranches)
770+
}
771+
}
772+
773+
/**
774+
* Recursively replaces `Literal(null, _)` with `FalseLiteral`.
775+
*
776+
* Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
777+
* an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
778+
*/
779+
private def replaceNullWithFalse(e: Expression): Expression = e match {
780+
case cw: CaseWhen if cw.dataType == BooleanType =>
781+
val newBranches = cw.branches.map { case (cond, value) =>
782+
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
783+
}
784+
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
785+
CaseWhen(newBranches, newElseValue)
786+
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
787+
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
788+
case And(left, right) =>
789+
And(replaceNullWithFalse(left), replaceNullWithFalse(right))
790+
case Or(left, right) =>
791+
Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
792+
case Literal(null, _) => FalseLiteral
793+
case _ => e
794+
}
795+
}
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
import org.apache.spark.sql.catalyst.dsl.plans._
23+
import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or}
24+
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
25+
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
26+
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
27+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
28+
import org.apache.spark.sql.types.{BooleanType, IntegerType}
29+
30+
class ReplaceNullWithFalseSuite extends PlanTest {
31+
32+
object Optimize extends RuleExecutor[LogicalPlan] {
33+
val batches =
34+
Batch("Replace null literals", FixedPoint(10),
35+
NullPropagation,
36+
ConstantFolding,
37+
BooleanSimplification,
38+
SimplifyConditionals,
39+
ReplaceNullWithFalse) :: Nil
40+
}
41+
42+
private val testRelation = LocalRelation('i.int, 'b.boolean)
43+
private val anotherTestRelation = LocalRelation('d.int)
44+
45+
test("replace null inside filter and join conditions") {
46+
testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
47+
testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
48+
}
49+
50+
test("replace null in branches of If") {
51+
val originalCond = If(
52+
UnresolvedAttribute("i") > Literal(10),
53+
FalseLiteral,
54+
Literal(null, BooleanType))
55+
testFilter(originalCond, expectedCond = FalseLiteral)
56+
testJoin(originalCond, expectedCond = FalseLiteral)
57+
}
58+
59+
test("replace nulls in nested expressions in branches of If") {
60+
val originalCond = If(
61+
UnresolvedAttribute("i") > Literal(10),
62+
TrueLiteral && Literal(null, BooleanType),
63+
UnresolvedAttribute("b") && Literal(null, BooleanType))
64+
testFilter(originalCond, expectedCond = FalseLiteral)
65+
testJoin(originalCond, expectedCond = FalseLiteral)
66+
}
67+
68+
test("replace null in elseValue of CaseWhen") {
69+
val branches = Seq(
70+
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
71+
(UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
72+
val originalCond = CaseWhen(branches, Literal(null, BooleanType))
73+
val expectedCond = CaseWhen(branches, FalseLiteral)
74+
testFilter(originalCond, expectedCond)
75+
testJoin(originalCond, expectedCond)
76+
}
77+
78+
test("replace null in branch values of CaseWhen") {
79+
val branches = Seq(
80+
(UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType),
81+
(UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
82+
val originalCond = CaseWhen(branches, Literal(null))
83+
testFilter(originalCond, expectedCond = FalseLiteral)
84+
testJoin(originalCond, expectedCond = FalseLiteral)
85+
}
86+
87+
test("replace null in branches of If inside CaseWhen") {
88+
val originalBranches = Seq(
89+
(UnresolvedAttribute("i") < Literal(10)) ->
90+
If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral),
91+
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
92+
val originalCond = CaseWhen(originalBranches)
93+
94+
val expectedBranches = Seq(
95+
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
96+
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
97+
val expectedCond = CaseWhen(expectedBranches)
98+
99+
testFilter(originalCond, expectedCond)
100+
testJoin(originalCond, expectedCond)
101+
}
102+
103+
test("replace null in complex CaseWhen expressions") {
104+
val originalBranches = Seq(
105+
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
106+
(Literal(6) <= Literal(1)) -> FalseLiteral,
107+
(Literal(4) === Literal(5)) -> FalseLiteral,
108+
(UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType),
109+
(Literal(4) === Literal(4)) -> TrueLiteral)
110+
val originalCond = CaseWhen(originalBranches)
111+
112+
val expectedBranches = Seq(
113+
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
114+
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
115+
TrueLiteral -> TrueLiteral)
116+
val expectedCond = CaseWhen(expectedBranches)
117+
118+
testFilter(originalCond, expectedCond)
119+
testJoin(originalCond, expectedCond)
120+
}
121+
122+
test("replace null in Or") {
123+
val originalCond = Or(UnresolvedAttribute("b"), Literal(null))
124+
val expectedCond = UnresolvedAttribute("b")
125+
testFilter(originalCond, expectedCond)
126+
testJoin(originalCond, expectedCond)
127+
}
128+
129+
test("replace null in And") {
130+
val originalCond = And(UnresolvedAttribute("b"), Literal(null))
131+
testFilter(originalCond, expectedCond = FalseLiteral)
132+
testJoin(originalCond, expectedCond = FalseLiteral)
133+
}
134+
135+
test("replace nulls in nested And/Or expressions") {
136+
val originalCond = And(
137+
And(UnresolvedAttribute("b"), Literal(null)),
138+
Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null)))))
139+
testFilter(originalCond, expectedCond = FalseLiteral)
140+
testJoin(originalCond, expectedCond = FalseLiteral)
141+
}
142+
143+
test("replace null in And inside branches of If") {
144+
val originalCond = If(
145+
UnresolvedAttribute("i") > Literal(10),
146+
FalseLiteral,
147+
And(UnresolvedAttribute("b"), Literal(null, BooleanType)))
148+
testFilter(originalCond, expectedCond = FalseLiteral)
149+
testJoin(originalCond, expectedCond = FalseLiteral)
150+
}
151+
152+
test("replace null in branches of If inside And") {
153+
val originalCond = And(
154+
UnresolvedAttribute("b"),
155+
If(
156+
UnresolvedAttribute("i") > Literal(10),
157+
Literal(null),
158+
And(FalseLiteral, UnresolvedAttribute("b"))))
159+
testFilter(originalCond, expectedCond = FalseLiteral)
160+
testJoin(originalCond, expectedCond = FalseLiteral)
161+
}
162+
163+
test("replace null in branches of If inside another If") {
164+
val originalCond = If(
165+
If(UnresolvedAttribute("b"), Literal(null), FalseLiteral),
166+
TrueLiteral,
167+
Literal(null))
168+
testFilter(originalCond, expectedCond = FalseLiteral)
169+
testJoin(originalCond, expectedCond = FalseLiteral)
170+
}
171+
172+
test("replace null in CaseWhen inside another CaseWhen") {
173+
val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null))
174+
val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null))
175+
testFilter(originalCond, expectedCond = FalseLiteral)
176+
testJoin(originalCond, expectedCond = FalseLiteral)
177+
}
178+
179+
test("inability to replace null in non-boolean branches of If") {
180+
val condition = If(
181+
UnresolvedAttribute("i") > Literal(10),
182+
Literal(5) > If(
183+
UnresolvedAttribute("i") === Literal(15),
184+
Literal(null, IntegerType),
185+
Literal(3)),
186+
FalseLiteral)
187+
testFilter(originalCond = condition, expectedCond = condition)
188+
testJoin(originalCond = condition, expectedCond = condition)
189+
}
190+
191+
test("inability to replace null in non-boolean values of CaseWhen") {
192+
val nestedCaseWhen = CaseWhen(
193+
Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)),
194+
Literal(null, IntegerType))
195+
val branchValue = If(
196+
Literal(2) === nestedCaseWhen,
197+
TrueLiteral,
198+
FalseLiteral)
199+
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
200+
val condition = CaseWhen(branches)
201+
testFilter(originalCond = condition, expectedCond = condition)
202+
testJoin(originalCond = condition, expectedCond = condition)
203+
}
204+
205+
test("inability to replace null in non-boolean branches of If inside another If") {
206+
val condition = If(
207+
Literal(5) > If(
208+
UnresolvedAttribute("i") === Literal(15),
209+
Literal(null, IntegerType),
210+
Literal(3)),
211+
TrueLiteral,
212+
FalseLiteral)
213+
testFilter(originalCond = condition, expectedCond = condition)
214+
testJoin(originalCond = condition, expectedCond = condition)
215+
}
216+
217+
test("replace null in If used as a join condition") {
218+
// this test is only for joins as the condition involves columns from different relations
219+
val originalCond = If(
220+
UnresolvedAttribute("d") > UnresolvedAttribute("i"),
221+
Literal(null),
222+
FalseLiteral)
223+
testJoin(originalCond, expectedCond = FalseLiteral)
224+
}
225+
226+
test("replace null in CaseWhen used as a join condition") {
227+
// this test is only for joins as the condition involves columns from different relations
228+
val originalBranches = Seq(
229+
(UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null),
230+
(UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
231+
232+
val expectedBranches = Seq(
233+
(UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral,
234+
(UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
235+
236+
testJoin(
237+
originalCond = CaseWhen(originalBranches, FalseLiteral),
238+
expectedCond = CaseWhen(expectedBranches, FalseLiteral))
239+
}
240+
241+
test("inability to replace null in CaseWhen inside EqualTo used as a join condition") {
242+
// this test is only for joins as the condition involves columns from different relations
243+
val branches = Seq(
244+
(UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType),
245+
(UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
246+
val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral)
247+
testJoin(originalCond = condition, expectedCond = condition)
248+
}
249+
250+
test("replace null in predicates of If") {
251+
val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null))
252+
testProjection(
253+
originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
254+
expectedExpr = Literal(1).as("out"))
255+
}
256+
257+
test("replace null in predicates of If inside another If") {
258+
val predicate = If(
259+
And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)),
260+
TrueLiteral,
261+
FalseLiteral)
262+
testProjection(
263+
originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
264+
expectedExpr = Literal(1).as("out"))
265+
}
266+
267+
test("inability to replace null in non-boolean expressions inside If predicates") {
268+
val predicate = GreaterThan(
269+
UnresolvedAttribute("i"),
270+
If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
271+
val column = If(predicate, Literal(5), Literal(1)).as("out")
272+
testProjection(originalExpr = column, expectedExpr = column)
273+
}
274+
275+
test("replace null in conditions of CaseWhen") {
276+
val branches = Seq(
277+
And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5))
278+
testProjection(
279+
originalExpr = CaseWhen(branches, Literal(2)).as("out"),
280+
expectedExpr = Literal(2).as("out"))
281+
}
282+
283+
test("replace null in conditions of CaseWhen inside another CaseWhen") {
284+
val nestedCaseWhen = CaseWhen(
285+
Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)),
286+
Literal(2))
287+
val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1))
288+
testProjection(
289+
originalExpr = CaseWhen(branches).as("out"),
290+
expectedExpr = Literal(1).as("out"))
291+
}
292+
293+
test("inability to replace null in non-boolean exprs inside CaseWhen conditions") {
294+
val condition = GreaterThan(
295+
UnresolvedAttribute("i"),
296+
If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
297+
val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out")
298+
testProjection(originalExpr = column, expectedExpr = column)
299+
}
300+
301+
private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
302+
test((rel, exp) => rel.where(exp), originalCond, expectedCond)
303+
}
304+
305+
private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = {
306+
test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, expectedCond)
307+
}
308+
309+
private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = {
310+
test((rel, exp) => rel.select(exp), originalExpr, expectedExpr)
311+
}
312+
313+
private def test(
314+
func: (LogicalPlan, Expression) => LogicalPlan,
315+
originalExpr: Expression,
316+
expectedExpr: Expression): Unit = {
317+
318+
val originalPlan = func(testRelation, originalExpr).analyze
319+
val optimizedPlan = Optimize.execute(originalPlan)
320+
val expectedPlan = func(testRelation, expectedExpr).analyze
321+
comparePlans(optimizedPlan, expectedPlan)
322+
}
323+
}

0 commit comments

Comments
 (0)