Skip to content

Commit aed5516

Browse files
committed
[SPARK-53155][SQL] Global lower aggregation should not be replaced with a project
This patch fixes the optimization rule `RemoveRedundantAggregates`. The optimizer rule `RemoveRedundantAggregates` removes a redundant lower aggregation from a query plan and replaces it with a project of the referenced non-aggregate expressions. However, if the removed aggregation is a global one, that is not correct because a project differs semantically from a global aggregation. For example, if the input relation is empty, a project might be optimized to an empty relation, while a global aggregation will return a single row. Yes, this fixes a user-facing bug. Previously, a global aggregation under another aggregation might be treated as redundant and replaced with a project of non-aggregate expressions. If the input relation is empty, the replacement is incorrect and might produce an incorrect result. This patch adds a new unit test to show the difference. Unit test, manual test. No. Closes #51884 from viirya/fix_remove_redundant_agg. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Liang-Chi Hsieh <[email protected]> (cherry picked from commit 3aa8c9d) Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 2a83c63 commit aed5516

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,13 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
5454
.map(_.toAttribute)
5555
))
5656

57-
upperHasNoDuplicateSensitiveAgg && upperRefsOnlyDeterministicNonAgg
57+
// If the lower aggregation is global, it is not redundant because a project with
58+
// non-aggregate expressions is different from a global aggregation in semantics.
59+
// E.g., if the input relation is empty, a project might be optimized to an empty
60+
// relation, while a global aggregation will return a single row.
61+
lazy val lowerIsGlobalAgg = lower.groupingExpressions.isEmpty
62+
63+
upperHasNoDuplicateSensitiveAgg && upperRefsOnlyDeterministicNonAgg && !lowerIsGlobalAgg
5864
}
5965

6066
private def isDuplicateSensitive(ae: AggregateExpression): Boolean = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.dsl.plans._
22-
import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF}
22+
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PythonUDAF}
2323
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
2424
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
2525
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, LogicalPlan}
@@ -289,4 +289,23 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
289289
val originalQuery = Distinct(x.groupBy($"a", $"b")($"a", TrueLiteral)).analyze
290290
comparePlans(Optimize.execute(originalQuery), originalQuery)
291291
}
292+
293+
test("SPARK-53155: global lower aggregation should not be removed") {
294+
object OptimizeNonRemovedRedundantAgg extends RuleExecutor[LogicalPlan] {
295+
val batches = Batch("RemoveRedundantAggregates", FixedPoint(10),
296+
PropagateEmptyRelation,
297+
RemoveRedundantAggregates) :: Nil
298+
}
299+
300+
val query = relation
301+
.groupBy()(Literal(1).as("col1"), Literal(2).as("col2"), Literal(3).as("col3"))
302+
.groupBy($"col1")(max($"col1"))
303+
.analyze
304+
val expected = relation
305+
.groupBy()(Literal(1).as("col1"), Literal(2).as("col2"), Literal(3).as("col3"))
306+
.groupBy($"col1")(max($"col1"))
307+
.analyze
308+
val optimized = OptimizeNonRemovedRedundantAgg.execute(query)
309+
comparePlans(optimized, expected)
310+
}
292311
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,6 +2567,13 @@ class DataFrameAggregateSuite extends QueryTest
25672567
checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil)
25682568
}
25692569
}
2570+
2571+
test("SPARK-53155: global lower aggregation should not be removed") {
2572+
val df = emptyTestData
2573+
.groupBy().agg(lit(1).as("col1"), lit(2).as("col2"), lit(3).as("col3"))
2574+
.groupBy($"col1").agg(max("col1"))
2575+
checkAnswer(df, Seq(Row(1, 1)))
2576+
}
25702577
}
25712578

25722579
case class B(c: Option[Double])

0 commit comments

Comments
 (0)