Skip to content

Commit 0c3c78d

Browse files
committed
Run single phase aggregation when possible
1 parent 771aaff commit 0c3c78d

File tree

5 files changed

+83
-27
lines changed

5 files changed

+83
-27
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.xpack.esql.common.Failures;
1212
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1313
import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
14+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.SinglePhaseAggregate;
1415
import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
1516
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
1617
import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor;
@@ -25,7 +26,8 @@
2526
public class PhysicalPlanOptimizer extends ParameterizedRuleExecutor<PhysicalPlan, PhysicalOptimizerContext> {
2627

2728
private static final List<RuleExecutor.Batch<PhysicalPlan>> RULES = List.of(
28-
new Batch<>("Plan Boundary", Limiter.ONCE, new ProjectAwayColumns())
29+
new Batch<>("Plan Boundary", Limiter.ONCE, new ProjectAwayColumns()),
30+
new Batch<>("Single aggregation", Limiter.ONCE, new SinglePhaseAggregate())
2931
);
3032

3133
private final PhysicalVerifier verifier = PhysicalVerifier.INSTANCE;
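
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/SinglePhaseAggregate.java

Lines changed: 36 additions & 0 deletions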
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer.rules.physical;
9+
10+
import org.elasticsearch.compute.aggregation.AggregatorMode;
11+
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
12+
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
13+
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
14+
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
15+
16+
/**
17+
* Collapses two-phase aggregation into a single phase when possible.
18+
* For example, in FROM .. | STATS first | STATS second, the STATS second aggregation
19+
* can be executed in a single phase on the coordinator instead of two phases.
20+
*/
21+
public class SinglePhaseAggregate extends PhysicalOptimizerRules.OptimizerRule<AggregateExec> {
22+
@Override
23+
protected PhysicalPlan rule(AggregateExec plan) {
24+
if (plan instanceof AggregateExec parent
25+
&& parent.getMode() == AggregatorMode.FINAL
26+
&& parent.child() instanceof AggregateExec child
27+
&& child.getMode() == AggregatorMode.INITIAL) {
28+
if (parent.groupings()
29+
.stream()
30+
.noneMatch(group -> group.anyMatch(expr -> expr instanceof GroupingFunction.NonEvaluatableGroupingFunction))) {
31+
return child.withMode(AggregatorMode.SINGLE);
32+
}
33+
}
34+
return plan;
35+
}
36+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ public final PhysicalOperation groupingPhysicalOperation(
7575
List<Aggregator.Factory> aggregatorFactories = new ArrayList<>();
7676

7777
// append channels to the layout
78-
if (aggregatorMode == AggregatorMode.FINAL) {
79-
layout.append(aggregates);
80-
} else {
78+
if (aggregatorMode.isOutputPartial()) {
8179
layout.append(aggregateMapper.mapNonGrouping(aggregates));
80+
} else {
81+
layout.append(aggregates);
8282
}
8383

8484
// create the agg factories
@@ -147,14 +147,14 @@ else if (aggregatorMode.isOutputPartial()) {
147147
groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group));
148148
}
149149

150-
if (aggregatorMode == AggregatorMode.FINAL) {
150+
if (aggregatorMode.isOutputPartial()) {
151+
layout.append(aggregateMapper.mapGrouping(aggregates));
152+
} else {
151153
for (var agg : aggregates) {
152154
if (Alias.unwrap(agg) instanceof AggregateFunction) {
153155
layout.append(agg);
154156
}
155157
}
156-
} else {
157-
layout.append(aggregateMapper.mapGrouping(aggregates));
158158
}
159159

160160
// create the agg factories
@@ -266,7 +266,13 @@ private void aggregatesToFactory(
266266
if (child instanceof AggregateFunction aggregateFunction) {
267267
List<NamedExpression> sourceAttr = new ArrayList<>();
268268

269-
if (mode == AggregatorMode.INITIAL) {
269+
if (mode.isInputPartial()) {
270+
if (grouping) {
271+
sourceAttr = aggregateMapper.mapGrouping(ne);
272+
} else {
273+
sourceAttr = aggregateMapper.mapNonGrouping(ne);
274+
}
275+
} else {
270276
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
271277
Expression field = aggregateFunction.field();
272278
// Only count can now support literals - all the other aggs should be optimized away
@@ -294,16 +300,6 @@ private void aggregatesToFactory(
294300
}
295301
}
296302
}
297-
// coordinator/exchange phase
298-
else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
299-
if (grouping) {
300-
sourceAttr = aggregateMapper.mapGrouping(ne);
301-
} else {
302-
sourceAttr = aggregateMapper.mapNonGrouping(ne);
303-
}
304-
} else {
305-
throw new EsqlIllegalArgumentException("illegal aggregation mode");
306-
}
307303

308304
AggregatorFunctionSupplier aggSupplier = supplier(aggregateFunction);
309305

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.common.util.BigArrays;
1515
import org.elasticsearch.common.util.Maps;
1616
import org.elasticsearch.compute.Describable;
17-
import org.elasticsearch.compute.aggregation.AggregatorMode;
1817
import org.elasticsearch.compute.data.Block;
1918
import org.elasticsearch.compute.data.BlockFactory;
2019
import org.elasticsearch.compute.data.ElementType;
@@ -226,7 +225,7 @@ public LocalExecutionPlan plan(String description, FoldContext foldCtx, Physical
226225
// workaround for https://github.com/elastic/elasticsearch/issues/99782
227226
localPhysicalPlan = localPhysicalPlan.transformUp(
228227
AggregateExec.class,
229-
a -> a.getMode() == AggregatorMode.FINAL ? new ProjectExec(a.source(), a, Expressions.asAttributes(a.aggregates())) : a
228+
a -> a.getMode().isOutputPartial() ? a : new ProjectExec(a.source(), a, Expressions.asAttributes(a.aggregates()))
230229
);
231230
PhysicalOperation physicalOperation = plan(localPhysicalPlan, context);
232231

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@
164164
import static java.util.Arrays.asList;
165165
import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL;
166166
import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL;
167+
import static org.elasticsearch.compute.aggregation.AggregatorMode.SINGLE;
167168
import static org.elasticsearch.core.Tuple.tuple;
168169
import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
169170
import static org.elasticsearch.index.query.QueryBuilders.existsQuery;
@@ -3793,8 +3794,7 @@ public void testMixedSpatialBoundsAndPointsExtracted() {
37933794
* After local optimizations we expect no changes because field is extracted:
37943795
* <code>
37953796
* LimitExec[1000[INTEGER]]
3796-
* \_AggregateExec[[],[SPATIALCENTROID(__centroid_SPATIALCENTROID@7ff910a{r}#7) AS centroid],FINAL,50]
3797-
* \_AggregateExec[[],[SPATIALCENTROID(__centroid_SPATIALCENTROID@7ff910a{r}#7) AS centroid],PARTIAL,50]
3797+
* \_AggregateExec[[],[SPATIALCENTROID(__centroid_SPATIALCENTROID@7ff910a{r}#7) AS centroid],SINGLE,50]
37983798
* \_EvalExec[[[1 1 0 0 0 0 0 30 e2 4c 7c 45 40 0 0 e0 92 b0 82 2d 40][GEO_POINT] AS __centroid_SPATIALCENTROID@7ff910a]]
37993799
* \_RowExec[[[50 4f 49 4e 54 28 34 32 2e 39 37 31 30 39 36 32 39 39 35 38 38 36 38 20 31 34 2e 37 35 35 32 35 33 34 30 30
38003800
* 36 35 33 36 29][KEYWORD] AS wkt]]
@@ -3822,11 +3822,7 @@ public void testSpatialTypesAndStatsUseDocValuesNestedLiteral() {
38223822
var optimized = optimizedPlan(plan);
38233823
limit = as(optimized, LimitExec.class);
38243824
agg = as(limit.child(), AggregateExec.class);
3825-
assertThat("Aggregation is FINAL", agg.getMode(), equalTo(FINAL));
3826-
assertThat("No groupings in aggregation", agg.groupings().size(), equalTo(0));
3827-
assertAggregation(agg, "centroid", SpatialCentroid.class, GEO_POINT, FieldExtractPreference.NONE);
3828-
agg = as(agg.child(), AggregateExec.class);
3829-
assertThat("Aggregation is PARTIAL", agg.getMode(), equalTo(INITIAL));
3825+
assertThat("Aggregation is SINGLE", agg.getMode(), equalTo(SINGLE));
38303826
assertThat("No groupings in aggregation", agg.groupings().size(), equalTo(0));
38313827
assertAggregation(agg, "centroid", SpatialCentroid.class, GEO_POINT, FieldExtractPreference.NONE);
38323828
eval = as(agg.child(), EvalExec.class);
@@ -7815,6 +7811,33 @@ public void testLookupJoinFieldLoadingDropAllFields() throws Exception {
78157811
assertLookupJoinFieldNames(query, data, List.of(Set.of(), Set.of("foo", "bar", "baz")));
78167812
}
78177813

7814+
/**
7815+
* LimitExec[1000[INTEGER],null]
7816+
* \_AggregateExec[[last_name{r}#8],[COUNT(first_name{r}#5,true[BOOLEAN]) AS count(first_name)#11, last_name{r}#8],SINGLE,[last_name
7817+
* {r}#8, $$count(first_name)$count{r}#25, $$count(first_name)$seen{r}#26],null]
7818+
* \_AggregateExec[[emp_no{f}#12],[VALUES(first_name{f}#13,true[BOOLEAN]) AS first_name#5, VALUES(last_name{f}#16,true[BOOLEAN]) A
7819+
* S last_name#8],FINAL,[emp_no{f}#12, $$first_name$values{r}#23, $$last_name$values{r}#24],null]
7820+
* \_ExchangeExec[[emp_no{f}#12, $$first_name$values{r}#23, $$last_name$values{r}#24],true]
7821+
* \_FragmentExec[filter=null, estimatedRowSize=0, reducer=[], fragment=[
7822+
* Aggregate[[emp_no{f}#12],[VALUES(first_name{f}#13,true[BOOLEAN]) AS first_name#5, VALUES(last_name{f}#16,true[BOOLEAN]) A
7823+
* S last_name#8]]
7824+
* \_EsRelation[test][_meta_field{f}#18, emp_no{f}#12, first_name{f}#13, ..]]]
7825+
*/
7826+
public void testSingleModeAggregate() {
7827+
String q = """
7828+
FROM test
7829+
| STATS first_name = VALUES(first_name), last_name = VALUES(last_name) BY emp_no
7830+
| STATS count(first_name) BY last_name""";
7831+
PhysicalPlan plan = physicalPlan(q);
7832+
PhysicalPlan optimized = physicalPlanOptimizer.optimize(plan);
7833+
LimitExec limit = as(optimized, LimitExec.class);
7834+
AggregateExec second = as(limit.child(), AggregateExec.class);
7835+
assertThat(second.getMode(), equalTo(SINGLE));
7836+
AggregateExec first = as(second.child(), AggregateExec.class);
7837+
assertThat(first.getMode(), equalTo(FINAL));
7838+
as(first.child(), ExchangeExec.class);
7839+
}
7840+
78187841
private void assertLookupJoinFieldNames(String query, TestDataSource data, List<Set<String>> expectedFieldNames) {
78197842
assertLookupJoinFieldNames(query, data, expectedFieldNames, false);
78207843
}

0 commit comments

Comments
 (0)