From 3221cc69aa5ed2b50461942bb170d2732760df50 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 1 Oct 2025 14:34:16 -0700 Subject: [PATCH 1/2] Avoid re-generating intermediate states when executing plans --- .../AbstractPhysicalOperationProviders.java | 49 +++++++++++++------ .../xpack/esql/planner/AggregateMapper.java | 40 ++++++++------- .../esql/planner/LocalExecutionPlanner.java | 8 +++ 3 files changed, 63 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index 50a8deb23130b..14f6a674ce3e0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -38,8 +38,10 @@ import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.function.Consumer; @@ -47,7 +49,6 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders { - private final AggregateMapper aggregateMapper = new AggregateMapper(); private final FoldContext foldContext; private final AnalysisRegistry analysisRegistry; @@ -76,7 +77,7 @@ public final PhysicalOperation groupingPhysicalOperation( // append channels to the layout if (aggregatorMode.isOutputPartial()) { - layout.append(aggregateMapper.mapNonGrouping(aggregates)); + layout.append(aggregateExec.output()); } else { layout.append(aggregates); } @@ -148,7 +149,10 @@ else if (aggregatorMode.isOutputPartial()) { } if (aggregatorMode.isOutputPartial()) { - layout.append(aggregateMapper.mapGrouping(aggregates)); + List output = aggregateExec.output(); + for (int i = aggregateExec.groupings().size(); i < output.size(); i++) { + layout.append(output.get(i)); + } } else { for (var agg : aggregates) { if (Alias.unwrap(agg) instanceof AggregateFunction) { @@ -156,7 +160,6 @@ else if (aggregatorMode.isOutputPartial()) { } } } - // create the agg factories aggregatesToFactory( aggregateExec, @@ -203,13 +206,12 @@ public static List intermediateAttributes(List attrs = new ArrayList<>(); // no groups if (groupings.isEmpty()) { - attrs = Expressions.asAttributes(aggregateMapper.mapNonGrouping(aggregates)); + attrs = Expressions.asAttributes(AggregateMapper.mapNonGrouping(aggregates)); } // groups else { @@ -241,13 +243,34 @@ public static List intermediateAttributes(List channels, AggregatorMode mode) {} + private static class IntermediateInputs { + private final List inputAttributes; + private int nextOffset; + private final Map offsets = new HashMap<>(); + + IntermediateInputs(AggregateExec aggregateExec) { + inputAttributes = aggregateExec.child().output(); + nextOffset = aggregateExec.groupings().size(); // skip grouping attributes + } + + List nextInputAttributes(AggregateFunction af, boolean grouping) { + int intermediateStateSize = AggregateMapper.intermediateStateDesc(af, grouping).size(); + int offset = offsets.computeIfAbsent(af, unused -> { + int v = nextOffset; + nextOffset += intermediateStateSize; + return v; + }); + return inputAttributes.subList(offset, offset + intermediateStateSize); + } + } + private void aggregatesToFactory( AggregateExec aggregateExec, List aggregates, @@ -257,21 +280,16 @@ private void aggregatesToFactory( Consumer consumer, LocalExecutionPlannerContext context ) { + IntermediateInputs intermediateInputs = mode.isInputPartial() ? new IntermediateInputs(aggregateExec) : null; // extract filtering channels - and wrap the aggregation with the new evaluator expression only during the init phase for (NamedExpression ne : aggregates) { // a filter can only appear on aggregate function, not on the grouping columns - if (ne instanceof Alias alias) { var child = alias.child(); if (child instanceof AggregateFunction aggregateFunction) { - List sourceAttr = new ArrayList<>(); - + final List sourceAttr; if (mode.isInputPartial()) { - if (grouping) { - sourceAttr = aggregateMapper.mapGrouping(ne); - } else { - sourceAttr = aggregateMapper.mapNonGrouping(ne); - } + sourceAttr = intermediateInputs.nextInputAttributes(aggregateFunction, grouping); } else { // TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1) Expression field = aggregateFunction.field(); @@ -287,6 +305,7 @@ private void aggregatesToFactory( } } else { // extra dependencies like TS ones (that require a timestamp) + sourceAttr = new ArrayList<>(); for (Expression input : aggregateFunction.aggregateInputReferences(aggregateExec.child()::output)) { Attribute attr = Expressions.attribute(input); if (attr == null) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index baf49a9c35493..29edfc621e8d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -23,8 +23,9 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; -import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.stream.Stream; /** @@ -32,34 +33,35 @@ */ final class AggregateMapper { - // TODO: Do we need this cache? - /** Cache of aggregates to intermediate expressions. */ - private final HashMap> cache = new HashMap<>(); - - public List mapNonGrouping(List aggregates) { + public static List mapNonGrouping(List aggregates) { return doMapping(aggregates, false); } - public List mapNonGrouping(NamedExpression aggregate) { - return map(aggregate, false).toList(); - } - - public List mapGrouping(List aggregates) { + public static List mapGrouping(List aggregates) { return doMapping(aggregates, true); } - private List doMapping(List aggregates, boolean grouping) { + private static List doMapping(List aggregates, boolean grouping) { + Set seen = new HashSet<>(); AttributeMap.Builder attrToExpressionsBuilder = AttributeMap.builder(); - aggregates.stream().flatMap(ne -> map(ne, grouping)).forEach(ne -> attrToExpressionsBuilder.put(ne.toAttribute(), ne)); + for (NamedExpression agg : aggregates) { + Expression inner = Alias.unwrap(agg); + if (seen.add(inner)) { + for (var ne : computeEntryForAgg(agg.name(), inner, grouping)) { + attrToExpressionsBuilder.put(ne.toAttribute(), ne); + } + } + } return attrToExpressionsBuilder.build().values().stream().toList(); } - public List mapGrouping(NamedExpression aggregate) { - return map(aggregate, true).toList(); - } - - private Stream map(NamedExpression ne, boolean grouping) { - return cache.computeIfAbsent(Alias.unwrap(ne), aggKey -> computeEntryForAgg(ne.name(), aggKey, grouping)).stream(); + public static List intermediateStateDesc(AggregateFunction fn, boolean grouping) { + if (fn instanceof ToAggregator toAggregator) { + var supplier = toAggregator.supplier(); + return grouping ? supplier.groupingIntermediateStateDesc() : supplier.nonGroupingIntermediateStateDesc(); + } else { + throw new EsqlIllegalArgumentException("Aggregate has no defined intermediate state: " + fn); + } } private static List computeEntryForAgg(String aggAlias, Expression aggregate, boolean grouping) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 7778ef601eeb3..b179b4575d25f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -56,6 +56,7 @@ import org.elasticsearch.compute.operator.topn.TopNEncoder; import org.elasticsearch.compute.operator.topn.TopNOperator; import org.elasticsearch.compute.operator.topn.TopNOperator.TopNOperatorFactory; +import org.elasticsearch.core.Assertions; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexMode; @@ -443,6 +444,13 @@ private PhysicalOperation planExchangeSink(ExchangeSinkExec exchangeSink, LocalE Objects.requireNonNull(exchangeSinkSupplier, "ExchangeSinkHandler wasn't provided"); var child = exchangeSink.child(); PhysicalOperation source = plan(child, context); + if (Assertions.ENABLED) { + List inputAttributes = exchangeSink.child().output(); + for (Attribute attr : inputAttributes) { + assert source.layout.get(attr.id()) != null + : "input attribute [" + attr + "] does not exist in the source layout [" + source.layout + "]"; + } + } return source.withSink(new ExchangeSinkOperatorFactory(exchangeSinkSupplier), source.layout); } From c5ab8d8c861531ad6636295e2f2daa917552461d Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 2 Oct 2025 12:42:59 -0700 Subject: [PATCH 2/2] append output directly --- .../esql/planner/AbstractPhysicalOperationProviders.java | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index 14f6a674ce3e0..087d03ea4c10e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -75,12 +75,7 @@ public final PhysicalOperation groupingPhysicalOperation( // not grouping List aggregatorFactories = new ArrayList<>(); - // append channels to the layout - if (aggregatorMode.isOutputPartial()) { - layout.append(aggregateExec.output()); - } else { - layout.append(aggregates); - } + layout.append(aggregateExec.output()); // create the agg factories aggregatesToFactory(