@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case _: Subquery => plan
 
     case _ => plan transform {
-      case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
+      case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
         // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
         // to normalize the `windowExpressions`, as they are executed per input row and should take
         // the input row as it is.
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
           // The analyzer guarantees left and right joins keys are of the same data type. Here we
           // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k.dataType)) =>
+          if leftKeys.exists(k => needNormalize(k)) =>
         val newLeftJoinKeys = leftKeys.map(normalize)
         val newRightJoinKeys = rightKeys.map(normalize)
         val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Short circuit if the underlying expression is already normalized
+   */
+  private def needNormalize(expr: Expression): Boolean = expr match {
+    case KnownFloatingPointNormalized(_) => false
+    case _ => needNormalize(expr.dataType)
+  }
+
   private def needNormalize(dt: DataType): Boolean = dt match {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if !needNormalize(expr.dataType) => expr
+    case _ if !needNormalize(expr) => expr
 
     case a: Alias =>
       a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       CreateMap(children.map(normalize))
 
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 
     case _ if expr.dataType.isInstanceOf[StructType] =>
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
-      ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
+      KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
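In short, `normalize` now tags every subtree it rewrites with `KnownFloatingPointNormalized`, and the new expression-level `needNormalize` overload treats that tag as a stop sign, so the rule never wraps an already-normalized subtree in `NormalizeNaNAndZero` a second time. A minimal standalone sketch of the short-circuit, using simplified stand-ins for the Catalyst classes (hypothetical types for illustration, not Spark's):

```scala
// Simplified stand-ins for Catalyst expression nodes (hypothetical).
sealed trait Expr { def dataType: String }
case class DoubleCol(name: String) extends Expr { val dataType = "double" }
case class NormalizeNaNAndZero(child: Expr) extends Expr { def dataType: String = child.dataType }
// Tag node: marks a subtree whose floating-point values are already normalized.
case class KnownFloatingPointNormalized(child: Expr) extends Expr { def dataType: String = child.dataType }

// Short circuit: a tagged subtree never needs normalizing again.
def needNormalize(expr: Expr): Boolean = expr match {
  case KnownFloatingPointNormalized(_) => false
  case _ => expr.dataType == "double"
}

def normalize(expr: Expr): Expr =
  if (!needNormalize(expr)) expr
  else KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))

val once = normalize(DoubleCol("a"))
assert(normalize(once) eq once) // idempotent: no nested NormalizeNaNAndZero
```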
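As background on why the rule normalizes join and partition keys at all: on the JVM, `0.0` and `-0.0` compare equal but carry different bit patterns (so they hash into different buckets), while `NaN` never equals itself even though its canonical bits do; `NormalizeNaNAndZero` makes hashing and equality agree. A quick self-contained check:

```scala
import java.lang.Double.doubleToLongBits

assert(0.0 == -0.0)                                     // == treats them as equal
assert(doubleToLongBits(0.0) != doubleToLongBits(-0.0)) // but the bits (and hashes) differ
assert(Double.NaN != Double.NaN)                        // NaN is never == itself
assert(doubleToLongBits(Double.NaN) == doubleToLongBits(Double.NaN)) // yet its canonical bits agree
```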