Skip to content

Commit b24f476

Browse files
iwanttobepowerfulmihaibudiu
authored andcommitted
[CALCITE-7320] AggregateProjectMergeRule throws AssertionError when Project maps multiple grouping keys to the same field
1 parent 41774e2 commit b24f476

File tree

4 files changed

+104
-1
lines changed

4 files changed

+104
-1
lines changed

core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectMergeRule.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ public AggregateProjectMergeRule(
110110
newGroupingSets =
111111
ImmutableBitSet.ORDERING.immutableSortedCopy(
112112
ImmutableBitSet.permute(aggregate.getGroupSets(), map));
113+
for (int i = 0; i < newGroupingSets.size() - 1; i++) {
114+
if (newGroupingSets.get(i).equals(newGroupingSets.get(i + 1))) {
115+
// If the project merges two columns that are both in the grouping sets,
116+
// we might get duplicate grouping sets. Aggregate does not allow
117+
// duplicate grouping sets, so we abort the rule.
118+
return null;
119+
}
120+
}
113121
}
114122

115123
final ImmutableList.Builder<AggregateCall> aggCalls =

core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ private RelNode rewriteScalarAggregate(Aggregate oldRel,
992992
for (int i1 = 0; i1 < oldRel.getAggCallList().size(); i1++) {
993993
AggregateCall aggCall = oldRel.getAggCallList().get(i1);
994994
if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
995-
int index = requireNonNull(outputMap.get(i1 + oldRel.getGroupSet().size()));
995+
int index = requireNonNull(outputMap.get(i1 + oldRel.getGroupCount()));
996996
final RexInputRef ref = RexInputRef.of(index + valueGenFieldCount, joinRowType);
997997
ImmutableList<RexNode> exprs =
998998
ImmutableList.of(relBuilder.isNotNull(ref), ref, relBuilder.literal(0));

core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,4 +1438,77 @@ public static Frameworks.ConfigBuilder config() {
14381438
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
14391439
assertThat(after, hasTree(planAfter));
14401440
}
1441+
1442+
/** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-7320">[CALCITE-7320]
1443+
* AggregateProjectMergeRule throws AssertionError when Project maps multiple grouping keys
1444+
* to the same field</a>. */
1445+
@Test void test7320() {
1446+
final FrameworkConfig frameworkConfig = config().build();
1447+
final RelBuilder builder = RelBuilder.create(frameworkConfig);
1448+
final RelOptCluster cluster = builder.getCluster();
1449+
final Planner planner = Frameworks.getPlanner(frameworkConfig);
1450+
final String sql = ""
1451+
+ "SELECT deptno,\n"
1452+
+ " (SELECT SUM(cnt)\n"
1453+
+ " FROM (\n"
1454+
+ " SELECT COUNT(*) AS cnt\n"
1455+
+ " FROM emp\n"
1456+
+ " WHERE emp.deptno = dept.deptno\n"
1457+
+ " GROUP BY GROUPING SETS ((deptno), ())\n"
1458+
+ "))\n"
1459+
+ "FROM dept";
1460+
final RelNode originalRel;
1461+
try {
1462+
final SqlNode parse = planner.parse(sql);
1463+
final SqlNode validate = planner.validate(parse);
1464+
originalRel = planner.rel(validate).rel;
1465+
} catch (Exception e) {
1466+
throw TestUtil.rethrow(e);
1467+
}
1468+
1469+
final HepProgram hepProgram = HepProgram.builder()
1470+
.addRuleCollection(
1471+
ImmutableList.of(
1472+
// SubQuery program rules
1473+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
1474+
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
1475+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
1476+
.build();
1477+
final Program program =
1478+
Programs.of(hepProgram, true,
1479+
requireNonNull(cluster.getMetadataProvider()));
1480+
final RelNode before =
1481+
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
1482+
Collections.emptyList(), Collections.emptyList());
1483+
final String planBefore = ""
1484+
+ "LogicalProject(DEPTNO=[$0], EXPR$1=[$3])\n"
1485+
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
1486+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1487+
+ " LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])\n"
1488+
+ " LogicalProject(CNT=[$1])\n"
1489+
+ " LogicalAggregate(group=[{0}], groups=[[{0}, {}]], CNT=[COUNT()])\n"
1490+
+ " LogicalProject(DEPTNO=[$7])\n"
1491+
+ " LogicalFilter(condition=[=($7, $cor0.DEPTNO)])\n"
1492+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1493+
assertThat(before, hasTree(planBefore));
1494+
1495+
// Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
1496+
final RelNode after =
1497+
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
1498+
RuleSets.ofList(Collections.emptyList()));
1499+
final String planAfter = ""
1500+
+ "LogicalProject(DEPTNO=[$0], EXPR$1=[$4])\n"
1501+
+ " LogicalJoin(condition=[=($0, $3)], joinType=[left])\n"
1502+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1503+
+ " LogicalAggregate(group=[{0}], EXPR$0=[SUM($1)])\n"
1504+
+ " LogicalProject(DEPTNO1=[$0], CNT=[CASE(IS NOT NULL($3), $3, 0)])\n"
1505+
+ " LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $2)], joinType=[left])\n"
1506+
+ " LogicalProject(DEPTNO=[$0])\n"
1507+
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
1508+
+ " LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {1}]], CNT=[COUNT()])\n"
1509+
+ " LogicalProject(DEPTNO=[$7], DEPTNO1=[$7])\n"
1510+
+ " LogicalFilter(condition=[IS NOT NULL($7)])\n"
1511+
+ " LogicalTableScan(table=[[scott, EMP]])\n";
1512+
assertThat(after, hasTree(planAfter));
1513+
}
14411514
}

core/src/test/resources/sql/sub-query.iq

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8467,5 +8467,27 @@ ON t2.a = foo.a
84678467
+---+---+---+
84688468
(1 row)
84698469

8470+
!ok
8471+
8472+
# [CALCITE-7320] AggregateProjectMergeRule throws AssertionError when Project maps multiple grouping keys to the same field
8473+
SELECT deptno,
8474+
(SELECT SUM(cnt)
8475+
FROM (
8476+
SELECT COUNT(*) AS cnt
8477+
FROM emp
8478+
WHERE emp.deptno = dept.deptno
8479+
GROUP BY GROUPING SETS ((deptno), ())
8480+
))
8481+
FROM dept;
8482+
+--------+--------+
8483+
| DEPTNO | EXPR$1 |
8484+
+--------+--------+
8485+
| 10 | 6 |
8486+
| 20 | 10 |
8487+
| 30 | 12 |
8488+
| 40 | 0 |
8489+
+--------+--------+
8490+
(4 rows)
8491+
84708492
!ok
84718493
# End sub-query.iq

0 commit comments

Comments
 (0)