|
25 | 25 | import com.facebook.presto.spi.SchemaTableName; |
26 | 26 | import com.facebook.presto.spi.plan.SortNode; |
27 | 27 | import com.facebook.presto.sql.analyzer.FeaturesConfig; |
| 28 | +import com.facebook.presto.sql.planner.Symbol; |
| 29 | +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; |
28 | 30 | import com.facebook.presto.sql.planner.plan.ExchangeNode; |
29 | 31 | import com.facebook.presto.testing.MaterializedResult; |
30 | 32 | import com.facebook.presto.testing.MaterializedRow; |
|
33 | 35 | import com.facebook.presto.tests.DistributedQueryRunner; |
34 | 36 | import com.google.common.collect.ImmutableList; |
35 | 37 | import com.google.common.collect.ImmutableMap; |
| 38 | +import com.google.common.collect.ImmutableSet; |
36 | 39 | import org.intellij.lang.annotations.Language; |
37 | 40 | import org.testng.annotations.Test; |
38 | 41 |
|
|
71 | 74 | import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion; |
72 | 75 | import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createSupplier; |
73 | 76 | import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createTableToTestHiddenColumns; |
74 | | -import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; |
| 77 | +import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; |
| 78 | +import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; |
75 | 79 | import static com.facebook.presto.spi.plan.ExchangeEncoding.COLUMNAR; |
76 | 80 | import static com.facebook.presto.spi.plan.ExchangeEncoding.ROW_WISE; |
| 81 | +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.GroupingSetDescriptor; |
77 | 82 | import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; |
| 83 | +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol; |
78 | 84 | import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; |
79 | 85 | import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; |
80 | 86 | import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; |
@@ -1520,66 +1526,72 @@ public void testSystemTables() |
1520 | 1526 | "AS " + |
1521 | 1527 | "SELECT nationkey, name, comment, regionkey FROM nation", tableName)); |
1522 | 1528 |
|
1523 | | - String filter = format("SELECT regionkey FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName); |
| 1529 | + String groupingSet = format("SELECT count(*) FROM \"%s\" GROUP BY GROUPING SETS ((regionkey), ())", partitionsTableName); |
1524 | 1530 | assertPlan( |
1525 | | - filter, |
1526 | | - anyTree( |
1527 | | - exchange(REMOTE_STREAMING, GATHER, |
1528 | | - filter( |
1529 | | - "REGION_KEY % 3 = 1", |
1530 | | - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))); |
1531 | | - assertQuery(filter); |
1532 | | - |
1533 | | - String project = format("SELECT regionkey + 1 FROM \"%s\"", partitionsTableName); |
1534 | | - assertPlan( |
1535 | | - project, |
1536 | | - anyTree( |
1537 | | - exchange(REMOTE_STREAMING, GATHER, |
1538 | | - project( |
1539 | | - ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")), |
1540 | | - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))); |
1541 | | - assertQuery(project); |
1542 | | - |
1543 | | - String filterProject = format("SELECT regionkey + 1 FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName); |
1544 | | - assertPlan( |
1545 | | - filterProject, |
1546 | | - anyTree( |
1547 | | - exchange(REMOTE_STREAMING, GATHER, |
1548 | | - project( |
1549 | | - ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")), |
1550 | | - filter( |
1551 | | - "REGION_KEY % 3 = 1", |
1552 | | - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))); |
1553 | | - assertQuery(filterProject); |
| 1531 | + groupingSet, |
| 1532 | + PlanMatchPattern.output(project( |
| 1533 | + aggregation( |
| 1534 | + new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)), |
| 1535 | + ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of(anySymbol()))), |
| 1536 | + ImmutableMap.of(), |
| 1537 | + Optional.of(new Symbol("groupid")), |
| 1538 | + FINAL, |
| 1539 | + exchange(LOCAL, REPARTITION, |
| 1540 | + aggregation( |
| 1541 | + new GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)), |
| 1542 | + ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of())), |
| 1543 | + ImmutableList.of(), |
| 1544 | + ImmutableMap.of(), |
| 1545 | + Optional.of(new Symbol("groupid")), |
| 1546 | + PARTIAL, |
| 1547 | + PlanMatchPattern.groupingSet( |
| 1548 | + ImmutableList.of(ImmutableList.of("REGION_KEY"), ImmutableList.of()), |
| 1549 | + ImmutableMap.of(), |
| 1550 | + "groupid", |
| 1551 | + ImmutableMap.of("regionkey$gid", expression("REGION_KEY")), |
| 1552 | + exchange(REMOTE_STREAMING, GATHER, |
| 1553 | + tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))))))); |
1554 | 1554 |
|
1555 | 1555 | String aggregation = format("SELECT count(*), sum(regionkey) FROM \"%s\"", partitionsTableName); |
1556 | 1556 | assertPlan( |
1557 | 1557 | aggregation, |
1558 | | - anyTree( |
| 1558 | + PlanMatchPattern.output( |
1559 | 1559 | aggregation( |
1560 | 1560 | ImmutableMap.of( |
1561 | | - "FINAL_COUNT", functionCall("count", ImmutableList.of()), |
1562 | | - "FINAL_SUM", functionCall("sum", ImmutableList.of("REGION_KEY"))), |
1563 | | - SINGLE, |
| 1561 | + "FINAL_COUNT", functionCall("count", false, ImmutableList.of(anySymbol())), |
| 1562 | + "FINAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))), |
| 1563 | + FINAL, |
1564 | 1564 | exchange(LOCAL, GATHER, |
1565 | | - exchange(REMOTE_STREAMING, GATHER, |
1566 | | - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))); |
| 1565 | + aggregation( |
| 1566 | + ImmutableMap.of( |
| 1567 | + "PARTIAL_COUNT", functionCall("count", false, ImmutableList.of()), |
| 1568 | + "PARTIAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))), |
| 1569 | + PARTIAL, |
| 1570 | + exchange(REMOTE_STREAMING, GATHER, |
| 1571 | + tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))))); |
1567 | 1572 | assertQuery(aggregation); |
1568 | 1573 |
|
1569 | 1574 | String groupBy = format("SELECT regionkey, count(*) FROM \"%s\" GROUP BY regionkey", partitionsTableName); |
1570 | 1575 | assertPlan( |
1571 | 1576 | groupBy, |
1572 | | - anyTree( |
| 1577 | + PlanMatchPattern.output( |
1573 | 1578 | aggregation( |
1574 | 1579 | singleGroupingSet("REGION_KEY"), |
1575 | 1580 | ImmutableMap.of( |
1576 | | - Optional.of("FINAL_COUNT"), functionCall("count", ImmutableList.of())), |
| 1581 | + Optional.of("FINAL_COUNT"), functionCall("count", false, ImmutableList.of(anySymbol()))), |
1577 | 1582 | ImmutableMap.of(), |
1578 | 1583 | Optional.empty(), |
1579 | | - SINGLE, |
| 1584 | + FINAL, |
1580 | 1585 | exchange(LOCAL, REPARTITION, |
1581 | | - exchange(REMOTE_STREAMING, GATHER, |
1582 | | - tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))); |
| 1586 | + aggregation( |
| 1587 | + singleGroupingSet("REGION_KEY"), |
| 1588 | + ImmutableMap.of( |
| 1589 | + Optional.of("PARTIAL_COUNT"), functionCall("count", false, ImmutableList.of())), |
| 1590 | + ImmutableMap.of(), |
| 1591 | + Optional.empty(), |
| 1592 | + PARTIAL, |
| 1593 | + exchange(REMOTE_STREAMING, GATHER, |
| 1594 | + tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))))); |
1583 | 1595 | assertQuery(groupBy); |
1584 | 1596 |
|
1585 | 1597 | String join = format("SELECT * " + |
|
0 commit comments