Skip to content

Commit d4c3415

Browse files
dbtsai authored and gatorsmile committed
[SPARK-24890][SQL] Short circuiting the if condition when trueValue and falseValue are the same
## What changes were proposed in this pull request?

When `trueValue` and `falseValue` are semantically equivalent, the condition expression in `if` can be removed to avoid extra computation at runtime.

## How was this patch tested?

Test added.

Author: DB Tsai <[email protected]>

Closes apache#21848 from dbtsai/short-circuit-if.
1 parent c26b092 commit d4c3415

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
390390
case If(TrueLiteral, trueValue, _) => trueValue
391391
case If(FalseLiteral, _, falseValue) => falseValue
392392
case If(Literal(null, _), _, falseValue) => falseValue
393+
case If(cond, trueValue, falseValue)
394+
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
393395

394396
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
395397
// If there are branches that are always false, remove them.
@@ -403,14 +405,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
403405
e.copy(branches = newBranches)
404406
}
405407

406-
case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
408+
case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) =>
407409
// If the first branch is a true literal, remove the entire CaseWhen and use the value
408410
// from that. Note that CaseWhen.branches should never be empty, and as a result the
409411
// headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
410412
branches.head._2
411413

412414
case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) =>
413-
// a branc with a TRue condition eliminates all following branches,
415+
// a branch with a true condition eliminates all following branches,
414416
// these branches can be pruned away
415417
val (h, t) = branches.span(_._1 != TrueLiteral)
416418
CaseWhen( h :+ t.head, None)
@@ -651,6 +653,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
651653
}
652654
}
653655

656+
654657
/**
655658
* Combine nested [[Concat]] expressions.
656659
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
2022
import org.apache.spark.sql.catalyst.dsl.plans._
2123
import org.apache.spark.sql.catalyst.expressions._
2224
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
2931
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
3032

3133
object Optimize extends RuleExecutor[LogicalPlan] {
32-
val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
34+
val batches = Batch("SimplifyConditionals", FixedPoint(50),
35+
BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
3336
}
3437

3538
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
@@ -43,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4346
private val unreachableBranch = (FalseLiteral, Literal(20))
4447
private val nullBranch = (Literal.create(null, NullType), Literal(30))
4548

49+
private val testRelation = LocalRelation('a.int)
50+
4651
test("simplify if") {
4752
assertEquivalent(
4853
If(TrueLiteral, Literal(10), Literal(20)),
@@ -57,6 +62,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
5762
Literal(20))
5863
}
5964

65+
test("remove unnecessary if when the outputs are semantic equivalence") {
66+
assertEquivalent(
67+
If(IsNotNull(UnresolvedAttribute("a")),
68+
Subtract(Literal(10), Literal(1)),
69+
Add(Literal(6), Literal(3))),
70+
Literal(9))
71+
72+
// For non-deterministic condition, we don't remove the `If` statement.
73+
assertEquivalent(
74+
If(GreaterThan(Rand(0), Literal(0.5)),
75+
Subtract(Literal(10), Literal(1)),
76+
Add(Literal(6), Literal(3))),
77+
If(GreaterThan(Rand(0), Literal(0.5)),
78+
Literal(9),
79+
Literal(9)))
80+
}
81+
6082
test("remove unreachable branches") {
6183
// i.e. removing branches whose conditions are always false
6284
assertEquivalent(

sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase
393393
}
394394

395395
/**
396-
* Returns full path to the given file in the resouce folder
396+
* Returns full path to the given file in the resource folder
397397
*/
398398
protected def testFile(fileName: String): String = {
399399
Thread.currentThread().getContextClassLoader.getResource(fileName).toString

0 commit comments

Comments
 (0)