Skip to content

Commit 7021588

Browse files
yeshengmcloud-fan
authored andcommitted
[SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent
## What changes were proposed in this pull request? The optimizer rule `NormalizeFloatingNumbers` is not idempotent. It will generate multiple `NormalizeNaNAndZero` and `ArrayTransform` expression nodes for multiple runs. This patch fixed this non-idempotence by adding a marking tag above normalized expressions. It also adds missing UTs for `NormalizeFloatingNumbers`. ## How was this patch tested? New UTs. Closes apache#25080 from yeshengm/spark-28306. Authored-by: Yesheng Ma <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 0197628 commit 7021588

File tree

3 files changed

+108
-12
lines changed

3 files changed

+108
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
2222
import org.apache.spark.sql.types.DataType
2323

24-
case class KnownNotNull(child: Expression) extends UnaryExpression {
25-
override def nullable: Boolean = false
24+
trait TaggingExpression extends UnaryExpression {
25+
override def nullable: Boolean = child.nullable
2626
override def dataType: DataType = child.dataType
2727

28+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.genCode(ctx)
29+
30+
override def eval(input: InternalRow): Any = child.eval(input)
31+
}
32+
33+
case class KnownNotNull(child: Expression) extends TaggingExpression {
34+
override def nullable: Boolean = false
35+
2836
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
2937
child.genCode(ctx).copy(isNull = FalseLiteral)
3038
}
31-
32-
override def eval(input: InternalRow): Any = {
33-
child.eval(input)
34-
}
3539
}
40+
41+
case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression}
20+
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
2222
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
2323
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
6161
case _: Subquery => plan
6262

6363
case _ => plan transform {
64-
case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
64+
case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
6565
// Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
6666
// to normalize the `windowExpressions`, as they are executed per input row and should take
6767
// the input row as it is.
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
7373
case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
7474
// The analyzer guarantees left and right joins keys are of the same data type. Here we
7575
// only need to check join keys of one side.
76-
if leftKeys.exists(k => needNormalize(k.dataType)) =>
76+
if leftKeys.exists(k => needNormalize(k)) =>
7777
val newLeftJoinKeys = leftKeys.map(normalize)
7878
val newRightJoinKeys = rightKeys.map(normalize)
7979
val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
8787
}
8888
}
8989

90+
/**
91+
* Short circuit if the underlying expression is already normalized
92+
*/
93+
private def needNormalize(expr: Expression): Boolean = expr match {
94+
case KnownFloatingPointNormalized(_) => false
95+
case _ => needNormalize(expr.dataType)
96+
}
97+
9098
private def needNormalize(dt: DataType): Boolean = dt match {
9199
case FloatType | DoubleType => true
92100
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
98106
}
99107

100108
private[sql] def normalize(expr: Expression): Expression = expr match {
101-
case _ if !needNormalize(expr.dataType) => expr
109+
case _ if !needNormalize(expr) => expr
102110

103111
case a: Alias =>
104112
a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
116124
CreateMap(children.map(normalize))
117125

118126
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
119-
NormalizeNaNAndZero(expr)
127+
KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
120128

121129
case _ if expr.dataType.isInstanceOf[StructType] =>
122130
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
128136
val ArrayType(et, containsNull) = expr.dataType
129137
val lv = NamedLambdaVariable("arg", et, containsNull)
130138
val function = normalize(lv)
131-
ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
139+
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))
132140

133141
case _ => throw new IllegalStateException(s"fail to normalize $expr")
134142
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.dsl.plans._
22+
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
23+
import org.apache.spark.sql.catalyst.plans.PlanTest
24+
import org.apache.spark.sql.catalyst.plans.logical._
25+
import org.apache.spark.sql.catalyst.rules.RuleExecutor
26+
27+
class NormalizeFloatingPointNumbersSuite extends PlanTest {
28+
29+
object Optimize extends RuleExecutor[LogicalPlan] {
30+
val batches = Batch("NormalizeFloatingPointNumbers", Once, NormalizeFloatingNumbers) :: Nil
31+
}
32+
33+
val testRelation1 = LocalRelation('a.double)
34+
val a = testRelation1.output(0)
35+
val testRelation2 = LocalRelation('a.double)
36+
val b = testRelation2.output(0)
37+
38+
test("normalize floating points in window function expressions") {
39+
val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
40+
41+
val optimized = Optimize.execute(query)
42+
val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
43+
Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))
44+
45+
comparePlans(optimized, correctAnswer)
46+
}
47+
48+
test("normalize floating points in window function expressions - idempotence") {
49+
val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
50+
51+
val optimized = Optimize.execute(query)
52+
val doubleOptimized = Optimize.execute(optimized)
53+
val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
54+
Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))
55+
56+
comparePlans(doubleOptimized, correctAnswer)
57+
}
58+
59+
test("normalize floating points in join keys") {
60+
val query = testRelation1.join(testRelation2, condition = Some(a === b))
61+
62+
val optimized = Optimize.execute(query)
63+
val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
64+
=== KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
65+
val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)
66+
67+
comparePlans(optimized, correctAnswer)
68+
}
69+
70+
test("normalize floating points in join keys - idempotence") {
71+
val query = testRelation1.join(testRelation2, condition = Some(a === b))
72+
73+
val optimized = Optimize.execute(query)
74+
val doubleOptimized = Optimize.execute(optimized)
75+
val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
76+
=== KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
77+
val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)
78+
79+
comparePlans(doubleOptimized, correctAnswer)
80+
}
81+
}
82+

0 commit comments

Comments
 (0)