Skip to content

Commit 3da0e14

Browse files
committed
Fix the partial aggregation pushdown for system table for native execution
1 parent edfebc2 commit 3da0e14

File tree

5 files changed

+91
-50
lines changed

5 files changed

+91
-50
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.facebook.presto.spi.relation.VariableReferenceExpression;
4545
import com.facebook.presto.sql.analyzer.Field;
4646
import com.facebook.presto.sql.planner.iterative.Lookup;
47+
import com.facebook.presto.sql.planner.plan.ExchangeNode;
4748
import com.facebook.presto.sql.planner.planPrinter.PlanPrinter;
4849
import com.facebook.presto.sql.relational.FunctionResolution;
4950
import com.facebook.presto.sql.tree.ComparisonExpression;
@@ -511,6 +512,25 @@ public static boolean containsSystemTableScan(PlanNode plan, Lookup lookup)
511512
.matches();
512513
}
513514

515+
/// Checks whether a node is directly on top of a system table scan, without an exchange in between
516+
public static boolean directlyOnSystemTableScan(PlanNode plan, Lookup lookup)
517+
{
518+
plan = lookup.resolve(plan);
519+
for (PlanNode source : plan.getSources()) {
520+
source = lookup.resolve(source);
521+
if (source instanceof TableScanNode && isInternalSystemConnector(((TableScanNode) source).getTable().getConnectorId())) {
522+
return true;
523+
}
524+
if (source instanceof ExchangeNode) {
525+
continue;
526+
}
527+
if (directlyOnSystemTableScan(source, lookup)) {
528+
return true;
529+
}
530+
}
531+
return false;
532+
}
533+
514534
public static boolean isConstant(RowExpression expression, Type type, Object value)
515535
{
516536
return expression instanceof ConstantExpression &&

presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public interface Lookup
3535
default PlanNode resolve(PlanNode node)
3636
{
3737
if (node instanceof GroupReference) {
38-
return resolveGroup(node).collect(toOptional()).get();
38+
return resolveGroup(node).collect(toOptional()).orElse(node);
3939
}
4040
return node;
4141
}

presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import com.facebook.presto.spi.relation.RowExpression;
3434
import com.facebook.presto.spi.relation.VariableReferenceExpression;
3535
import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy;
36+
import com.facebook.presto.sql.planner.PlannerUtils;
3637
import com.facebook.presto.sql.planner.iterative.Rule;
3738
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
3839
import com.facebook.presto.sql.planner.plan.ExchangeNode;
@@ -60,7 +61,6 @@
6061
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW;
6162
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.AUTOMATIC;
6263
import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER;
63-
import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan;
6464
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER;
6565
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
6666
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
@@ -166,8 +166,13 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
166166
return Result.empty();
167167
}
168168

169-
// System table scan must be run in Java on coordinator and partial aggregation output may not be compatible with Velox
170-
if (nativeExecution && containsSystemTableScan(exchangeNode, context.getLookup())) {
169+
// For native execution:
170+
// The partial aggregation result from a Java coordinator task is not compatible with a native worker.
171+
// A system table scan must run on the coordinator, and addExchange always adds a GatherExchange on top of it.
172+
// We should never push partial aggregation past the GatherExchange.
173+
if (nativeExecution
174+
&& exchangeNode.getType() == GATHER
175+
&& PlannerUtils.directlyOnSystemTableScan(exchangeNode, context.getLookup())) {
171176
return Result.empty();
172177
}
173178

presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeGeneralQueries.java

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import com.facebook.presto.spi.SchemaTableName;
2626
import com.facebook.presto.spi.plan.SortNode;
2727
import com.facebook.presto.sql.analyzer.FeaturesConfig;
28+
import com.facebook.presto.sql.planner.Symbol;
29+
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
2830
import com.facebook.presto.sql.planner.plan.ExchangeNode;
2931
import com.facebook.presto.testing.MaterializedResult;
3032
import com.facebook.presto.testing.MaterializedRow;
@@ -33,6 +35,7 @@
3335
import com.facebook.presto.tests.DistributedQueryRunner;
3436
import com.google.common.collect.ImmutableList;
3537
import com.google.common.collect.ImmutableMap;
38+
import com.google.common.collect.ImmutableSet;
3639
import org.intellij.lang.annotations.Language;
3740
import org.testng.annotations.Test;
3841

@@ -71,10 +74,13 @@
7174
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion;
7275
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createSupplier;
7376
import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createTableToTestHiddenColumns;
74-
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
77+
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
78+
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
7579
import static com.facebook.presto.spi.plan.ExchangeEncoding.COLUMNAR;
7680
import static com.facebook.presto.spi.plan.ExchangeEncoding.ROW_WISE;
81+
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.GroupingSetDescriptor;
7782
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
83+
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol;
7884
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
7985
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
8086
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
@@ -1520,66 +1526,72 @@ public void testSystemTables()
15201526
"AS " +
15211527
"SELECT nationkey, name, comment, regionkey FROM nation", tableName));
15221528

1523-
String filter = format("SELECT regionkey FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName);
1529+
String groupingSet = format("SELECT count(*) FROM \"%s\" GROUP BY GROUPING SETS ((regionkey), ())", partitionsTableName);
15241530
assertPlan(
1525-
filter,
1526-
anyTree(
1527-
exchange(REMOTE_STREAMING, GATHER,
1528-
filter(
1529-
"REGION_KEY % 3 = 1",
1530-
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))));
1531-
assertQuery(filter);
1532-
1533-
String project = format("SELECT regionkey + 1 FROM \"%s\"", partitionsTableName);
1534-
assertPlan(
1535-
project,
1536-
anyTree(
1537-
exchange(REMOTE_STREAMING, GATHER,
1538-
project(
1539-
ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")),
1540-
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))));
1541-
assertQuery(project);
1542-
1543-
String filterProject = format("SELECT regionkey + 1 FROM \"%s\" WHERE regionkey %% 3 = 1", partitionsTableName);
1544-
assertPlan(
1545-
filterProject,
1546-
anyTree(
1547-
exchange(REMOTE_STREAMING, GATHER,
1548-
project(
1549-
ImmutableMap.of("EXPRESSION", expression("REGION_KEY + CAST(1 AS bigint)")),
1550-
filter(
1551-
"REGION_KEY % 3 = 1",
1552-
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))));
1553-
assertQuery(filterProject);
1531+
groupingSet,
1532+
PlanMatchPattern.output(project(
1533+
aggregation(
1534+
new PlanMatchPattern.GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)),
1535+
ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of(anySymbol()))),
1536+
ImmutableMap.of(),
1537+
Optional.of(new Symbol("groupid")),
1538+
FINAL,
1539+
exchange(LOCAL, REPARTITION,
1540+
aggregation(
1541+
new GroupingSetDescriptor(ImmutableList.of("regionkey$gid", "groupid"), 2, ImmutableSet.of(1)),
1542+
ImmutableMap.of(Optional.empty(), functionCall("count", false, ImmutableList.of())),
1543+
ImmutableList.of(),
1544+
ImmutableMap.of(),
1545+
Optional.of(new Symbol("groupid")),
1546+
PARTIAL,
1547+
PlanMatchPattern.groupingSet(
1548+
ImmutableList.of(ImmutableList.of("REGION_KEY"), ImmutableList.of()),
1549+
ImmutableMap.of(),
1550+
"groupid",
1551+
ImmutableMap.of("regionkey$gid", expression("REGION_KEY")),
1552+
exchange(REMOTE_STREAMING, GATHER,
1553+
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))))));
15541554

15551555
String aggregation = format("SELECT count(*), sum(regionkey) FROM \"%s\"", partitionsTableName);
15561556
assertPlan(
15571557
aggregation,
1558-
anyTree(
1558+
PlanMatchPattern.output(
15591559
aggregation(
15601560
ImmutableMap.of(
1561-
"FINAL_COUNT", functionCall("count", ImmutableList.of()),
1562-
"FINAL_SUM", functionCall("sum", ImmutableList.of("REGION_KEY"))),
1563-
SINGLE,
1561+
"FINAL_COUNT", functionCall("count", false, ImmutableList.of(anySymbol())),
1562+
"FINAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))),
1563+
FINAL,
15641564
exchange(LOCAL, GATHER,
1565-
exchange(REMOTE_STREAMING, GATHER,
1566-
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))));
1565+
aggregation(
1566+
ImmutableMap.of(
1567+
"PARTIAL_COUNT", functionCall("count", false, ImmutableList.of()),
1568+
"PARTIAL_SUM", functionCall("sum", false, ImmutableList.of(anySymbol()))),
1569+
PARTIAL,
1570+
exchange(REMOTE_STREAMING, GATHER,
1571+
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))));
15671572
assertQuery(aggregation);
15681573

15691574
String groupBy = format("SELECT regionkey, count(*) FROM \"%s\" GROUP BY regionkey", partitionsTableName);
15701575
assertPlan(
15711576
groupBy,
1572-
anyTree(
1577+
PlanMatchPattern.output(
15731578
aggregation(
15741579
singleGroupingSet("REGION_KEY"),
15751580
ImmutableMap.of(
1576-
Optional.of("FINAL_COUNT"), functionCall("count", ImmutableList.of())),
1581+
Optional.of("FINAL_COUNT"), functionCall("count", false, ImmutableList.of(anySymbol()))),
15771582
ImmutableMap.of(),
15781583
Optional.empty(),
1579-
SINGLE,
1584+
FINAL,
15801585
exchange(LOCAL, REPARTITION,
1581-
exchange(REMOTE_STREAMING, GATHER,
1582-
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey")))))));
1586+
aggregation(
1587+
singleGroupingSet("REGION_KEY"),
1588+
ImmutableMap.of(
1589+
Optional.of("PARTIAL_COUNT"), functionCall("count", false, ImmutableList.of())),
1590+
ImmutableMap.of(),
1591+
Optional.empty(),
1592+
PARTIAL,
1593+
exchange(REMOTE_STREAMING, GATHER,
1594+
tableScan(partitionsTableName, ImmutableMap.of("REGION_KEY", "regionkey"))))))));
15831595
assertQuery(groupBy);
15841596

15851597
String join = format("SELECT * " +

presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeSystemQueries.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual;
2323
import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeQueryRunnerParameters;
24-
import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE;
24+
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
25+
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
2526
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
2627
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree;
2728
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange;
@@ -67,10 +68,13 @@ public void testTasks()
6768
anyTree(
6869
aggregation(
6970
Collections.emptyMap(),
70-
SINGLE,
71+
FINAL,
7172
exchange(LOCAL, GATHER,
72-
exchange(REMOTE_STREAMING, GATHER,
73-
tableScan("tasks"))))));
73+
aggregation(
74+
Collections.emptyMap(),
75+
PARTIAL,
76+
exchange(REMOTE_STREAMING, GATHER,
77+
tableScan("tasks")))))));
7478
}
7579

7680
@Test

0 commit comments

Comments
 (0)