Skip to content

Commit 9d2daa8

Browse files
authored
[FLINK-38776][table] Fix incorrect auxiliary group field names
1 parent 457b3de commit 9d2daa8

File tree

5 files changed

+61
-7
lines changed

5 files changed

+61
-7
lines changed

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalWindowAggregateRule.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ class BatchPhysicalWindowAggregateRule
397397
case (udf, aggIndex) =>
398398
aggBufferFieldNames(aggIndex) = udf match {
399399
case _: AggregateFunction[_, _] =>
400-
Array(aggNames(aggIndex))
400+
Array(aggNames(aggIndex + auxGroupSet.length))
401401
case agf: DeclarativeAggregateFunction =>
402402
agf.aggBufferAttributes.map {
403403
attr =>

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RelExplainUtil.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ object RelExplainUtil {
825825

826826
val inNames = grouping.map(inFields(_)) ++ auxGrouping.map(inFields(_)) ++ aggStrings
827827
val outNames = grouping.indices.map(outFields(_)) ++
828-
(grouping.length + 1 until grouping.length + 1 + auxGrouping.length).map(outFields(_)) ++
828+
(grouping.length until grouping.length + auxGrouping.length).map(outFields(_)) ++
829829
outFieldNames
830830
inNames
831831
.zip(outNames)

flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/AggregateReduceGroupingTest.xml

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ LogicalProject(a4=[$0], b4=[$1], EXPR$2=[$3])
381381
</Resource>
382382
<Resource name="optimized exec plan">
383383
<![CDATA[
384-
HashWindowAggregate(groupBy=[a4], auxGrouping=[b4], window=[TumblingGroupWindow('w$, d4, 900000)], select=[a4, b4 AS EXPR$2, COUNT(c4) AS EXPR$2])
384+
HashWindowAggregate(groupBy=[a4], auxGrouping=[b4], window=[TumblingGroupWindow('w$, d4, 900000)], select=[a4, b4, COUNT(c4) AS EXPR$2])
385385
+- Exchange(distribution=[hash[a4]])
386386
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
387387
]]>
@@ -401,7 +401,7 @@ LogicalProject(a4=[$0], c4=[$1], EXPR$2=[$3], EXPR$3=[$4])
401401
</Resource>
402402
<Resource name="optimized exec plan">
403403
<![CDATA[
404-
HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], select=[a4, c4 AS EXPR$2, COUNT(b4) AS EXPR$2, AVG(b4) AS EXPR$3])
404+
HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], select=[a4, c4, COUNT(b4) AS EXPR$2, AVG(b4) AS EXPR$3])
405405
+- Exchange(distribution=[hash[a4]])
406406
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
407407
]]>
@@ -428,7 +428,7 @@ Calc(select=[a4, c4, s, EXPR$3])
428428
+- Exchange(distribution=[hash[a4, s]])
429429
+- LocalHashAggregate(groupBy=[a4, s], auxGrouping=[c4], select=[a4, s, c4, Partial_COUNT(b4) AS count$0])
430430
+- Calc(select=[a4, c4, w$start AS s, CAST((($f2 - (($f3 * $f3) / $f4)) / $f4) AS INTEGER) AS b4])
431-
+- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4 AS $f2, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
431+
+- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
432432
+- Exchange(distribution=[keep_input_as_is[hash[a4]]])
433433
+- Calc(select=[a4, c4, d4, b4, (b4 * b4) AS $f4])
434434
+- Exchange(distribution=[hash[a4]])
@@ -457,7 +457,7 @@ Calc(select=[a4, c4, e, EXPR$3])
457457
+- Exchange(distribution=[hash[a4, e]])
458458
+- LocalHashAggregate(groupBy=[a4, e], auxGrouping=[c4], select=[a4, e, c4, Partial_COUNT(b4) AS count$0])
459459
+- Calc(select=[a4, c4, w$end AS e, CAST((($f2 - (($f3 * $f3) / $f4)) / $f4) AS INTEGER) AS b4])
460-
+- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4 AS $f2, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
460+
+- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
461461
+- Exchange(distribution=[keep_input_as_is[hash[a4]]])
462462
+- Calc(select=[a4, c4, d4, b4, (b4 * b4) AS $f4])
463463
+- Exchange(distribution=[hash[a4]])
@@ -485,7 +485,7 @@ HashAggregate(isMerge=[true], groupBy=[a4, b4], auxGrouping=[c4], select=[a4, b4
485485
+- Exchange(distribution=[hash[a4, b4]])
486486
+- LocalHashAggregate(groupBy=[a4, b4], auxGrouping=[c4], select=[a4, b4, c4, Partial_COUNT(*) AS count1$0])
487487
+- Calc(select=[a4, CAST((($f2 - (($f3 * $f3) / $f4)) / $f4) AS INTEGER) AS b4, c4])
488-
+- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4 AS $f2, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
488+
+- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
489489
+- Exchange(distribution=[keep_input_as_is[hash[a4]]])
490490
+- Calc(select=[a4, c4, d4, b4, (b4 * b4) AS $f4])
491491
+- Exchange(distribution=[hash[a4]])
@@ -635,6 +635,31 @@ HashAggregate(isMerge=[true], groupBy=[a3, b3], select=[a3, b3, Final_COUNT(coun
635635
+- LocalHashAggregate(groupBy=[a3, b3], select=[a3, b3, Partial_COUNT(c3) AS count$0])
636636
+- Calc(select=[a3, b3, c3])
637637
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(a3, b3, c3, d3)]]], fields=[a3, b3, c3, d3])
638+
]]>
639+
</Resource>
640+
</TestCase>
641+
<TestCase name="testImperativeAggWithAuxiliaryGrouping">
642+
<Resource name="sql">
643+
<![CDATA[SELECT a4, c4, COUNT(b4) FROM (SELECT a4, c4, ARRAY_AGG(b4) AS b4 FROM T4 GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4]]>
644+
</Resource>
645+
<Resource name="ast">
646+
<![CDATA[
647+
LogicalAggregate(group=[{0, 1}], EXPR$2=[COUNT($2)])
648+
+- LogicalProject(a4=[$0], c4=[$1], b4=[$3])
649+
+- LogicalAggregate(group=[{0, 1, 2}], b4=[ARRAY_AGG($3)])
650+
+- LogicalProject(a4=[$0], c4=[$2], $f2=[$TUMBLE($3, 900000:INTERVAL MINUTE)], b4=[$1])
651+
+- LogicalTableScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]])
652+
]]>
653+
</Resource>
654+
<Resource name="optimized exec plan">
655+
<![CDATA[
656+
HashAggregate(isMerge=[false], groupBy=[a4], auxGrouping=[c4], select=[a4, c4, COUNT(b4) AS EXPR$2])
657+
+- Exchange(distribution=[hash[a4]])
658+
+- SortWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], select=[a4, c4, ARRAY_AGG(b4) AS b4])
659+
+- Exchange(distribution=[forward])
660+
+- Sort(orderBy=[a4 ASC, d4 ASC])
661+
+- Exchange(distribution=[hash[a4]])
662+
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
638663
]]>
639664
</Resource>
640665
</TestCase>

flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRuleTest.xml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,27 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[COUNT($2)])
572572
FlinkLogicalAggregate(group=[{0, 1}], EXPR$2=[COUNT($2)])
573573
+- FlinkLogicalCalc(select=[a3, b3, c3])
574574
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, T3, source: [TestTableSource(a3, b3, c3, d3)]]], fields=[a3, b3, c3, d3])
575+
]]>
576+
</Resource>
577+
</TestCase>
578+
<TestCase name="testImperativeAggWithAuxiliaryGrouping">
579+
<Resource name="sql">
580+
<![CDATA[SELECT a4, c4, COUNT(b4) FROM (SELECT a4, c4, ARRAY_AGG(b4) AS b4 FROM T4 GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4]]>
581+
</Resource>
582+
<Resource name="ast">
583+
<![CDATA[
584+
LogicalAggregate(group=[{0, 1}], EXPR$2=[COUNT($2)])
585+
+- LogicalProject(a4=[$0], c4=[$1], b4=[$3])
586+
+- LogicalAggregate(group=[{0, 1, 2}], b4=[ARRAY_AGG($3)])
587+
+- LogicalProject(a4=[$0], c4=[$2], $f2=[$TUMBLE($3, 900000:INTERVAL MINUTE)], b4=[$1])
588+
+- LogicalTableScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]])
589+
]]>
590+
</Resource>
591+
<Resource name="optimized rel plan">
592+
<![CDATA[
593+
FlinkLogicalAggregate(group=[{0}], c4=[AUXILIARY_GROUP($1)], EXPR$2=[COUNT($2)])
594+
+- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($2)], b4=[ARRAY_AGG($1)], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[])
595+
+- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
575596
]]>
576597
</Resource>
577598
</TestCase>

flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/common/AggregateReduceGroupingTestBase.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,14 @@ abstract class AggregateReduceGroupingTestBase(withExecPlan: Boolean) extends Ta
342342
"SELECT a1, d1, COUNT(DISTINCT c1), MAX(DISTINCT b1), SUM(b1) FROM T1 GROUP BY a1, d1")
343343
}
344344

345+
@Test
346+
def testImperativeAggWithAuxiliaryGrouping(): Unit = {
347+
verifyPlan(
348+
"SELECT a4, c4, COUNT(b4) FROM " +
349+
"(SELECT a4, c4, ARRAY_AGG(b4) AS b4 FROM T4 " +
350+
"GROUP BY a4, c4, TUMBLE(d4, INTERVAL '15' MINUTE)) t GROUP BY a4, c4")
351+
}
352+
345353
def verifyPlan(sqlQuery: String): Unit = {
346354
if (withExecPlan) {
347355
util.verifyExecPlan(sqlQuery)

0 commit comments

Comments
 (0)