Skip to content

Commit 0e2c487

Browse files
committed
[SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expressions for final aggregate
## What changes were proposed in this pull request? A followup of apache#23388. `AggUtils.createAggregate` is not the right place to normalize the grouping expressions, as the final aggregate is also created by it. The grouping expressions of the final aggregate should be attributes which refer to the grouping expressions in the partial aggregate. This PR moves the normalization to the caller side of `AggUtils`. ## How was this patch tested? existing tests Closes apache#23692 from cloud-fan/follow. Authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 0d77d57 commit 0e2c487

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
9898
}
9999

100100
private[sql] def normalize(expr: Expression): Expression = expr match {
101-
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
102-
NormalizeNaNAndZero(expr)
101+
case _ if !needNormalize(expr.dataType) => expr
102+
103+
case a: Alias =>
104+
a.withNewChildren(Seq(normalize(a.child)))
103105

104106
case CreateNamedStruct(children) =>
105107
CreateNamedStruct(children.map(normalize))
@@ -113,22 +115,22 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
113115
case CreateMap(children) =>
114116
CreateMap(children.map(normalize))
115117

116-
case a: Alias if needNormalize(a.dataType) =>
117-
a.withNewChildren(Seq(normalize(a.child)))
118+
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
119+
NormalizeNaNAndZero(expr)
118120

119-
case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) =>
121+
case _ if expr.dataType.isInstanceOf[StructType] =>
120122
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
121123
normalize(GetStructField(expr, i))
122124
}
123125
CreateStruct(fields)
124126

125-
case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) =>
127+
case _ if expr.dataType.isInstanceOf[ArrayType] =>
126128
val ArrayType(et, containsNull) = expr.dataType
127129
val lv = NamedLambdaVariable("arg", et, containsNull)
128130
val function = normalize(lv)
129131
ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
130132

131-
case _ => expr
133+
case _ => throw new IllegalStateException(s"fail to normalize $expr")
132134
}
133135
}
134136

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.encoders.RowEncoder
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
26+
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
2627
import org.apache.spark.sql.catalyst.planning._
2728
import org.apache.spark.sql.catalyst.plans._
2829
import org.apache.spark.sql.catalyst.plans.logical._
@@ -331,8 +332,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
331332

332333
val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
333334

335+
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
336+
// `groupingExpressions` is not extracted during logical phase.
337+
val normalizedGroupingExpressions = namedGroupingExpressions.map { e =>
338+
NormalizeFloatingNumbers.normalize(e) match {
339+
case n: NamedExpression => n
340+
case other => Alias(other, e.name)(exprId = e.exprId)
341+
}
342+
}
343+
334344
aggregate.AggUtils.planStreamingAggregation(
335-
namedGroupingExpressions,
345+
normalizedGroupingExpressions,
336346
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
337347
rewrittenResultExpressions,
338348
stateVersion,
@@ -414,16 +424,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
414424
"Spark user mailing list.")
415425
}
416426

427+
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
428+
// `groupingExpressions` is not extracted during logical phase.
429+
val normalizedGroupingExpressions = groupingExpressions.map { e =>
430+
NormalizeFloatingNumbers.normalize(e) match {
431+
case n: NamedExpression => n
432+
case other => Alias(other, e.name)(exprId = e.exprId)
433+
}
434+
}
435+
417436
val aggregateOperator =
418437
if (functionsWithDistinct.isEmpty) {
419438
aggregate.AggUtils.planAggregateWithoutDistinct(
420-
groupingExpressions,
439+
normalizedGroupingExpressions,
421440
aggregateExpressions,
422441
resultExpressions,
423442
planLater(child))
424443
} else {
425444
aggregate.AggUtils.planAggregateWithOneDistinct(
426-
groupingExpressions,
445+
normalizedGroupingExpressions,
427446
functionsWithDistinct,
428447
functionsWithoutDistinct,
429448
resultExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,12 @@ object AggUtils {
3535
initialInputBufferOffset: Int = 0,
3636
resultExpressions: Seq[NamedExpression] = Nil,
3737
child: SparkPlan): SparkPlan = {
38-
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
39-
// `groupingExpressions` is not extracted during logical phase.
40-
val normalizedGroupingExpressions = groupingExpressions.map { e =>
41-
NormalizeFloatingNumbers.normalize(e) match {
42-
case n: NamedExpression => n
43-
case other => Alias(other, e.name)(exprId = e.exprId)
44-
}
45-
}
4638
val useHash = HashAggregateExec.supportsAggregate(
4739
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
4840
if (useHash) {
4941
HashAggregateExec(
5042
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
51-
groupingExpressions = normalizedGroupingExpressions,
43+
groupingExpressions = groupingExpressions,
5244
aggregateExpressions = aggregateExpressions,
5345
aggregateAttributes = aggregateAttributes,
5446
initialInputBufferOffset = initialInputBufferOffset,
@@ -61,7 +53,7 @@ object AggUtils {
6153
if (objectHashEnabled && useObjectHash) {
6254
ObjectHashAggregateExec(
6355
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
64-
groupingExpressions = normalizedGroupingExpressions,
56+
groupingExpressions = groupingExpressions,
6557
aggregateExpressions = aggregateExpressions,
6658
aggregateAttributes = aggregateAttributes,
6759
initialInputBufferOffset = initialInputBufferOffset,
@@ -70,7 +62,7 @@ object AggUtils {
7062
} else {
7163
SortAggregateExec(
7264
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
73-
groupingExpressions = normalizedGroupingExpressions,
65+
groupingExpressions = groupingExpressions,
7466
aggregateExpressions = aggregateExpressions,
7567
aggregateAttributes = aggregateAttributes,
7668
initialInputBufferOffset = initialInputBufferOffset,

0 commit comments

Comments (0)