Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,17 @@
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import static java.util.Collections.emptyList;

public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {

private final AggregateMapper aggregateMapper = new AggregateMapper();
private final FoldContext foldContext;
private final AnalysisRegistry analysisRegistry;

Expand Down Expand Up @@ -76,7 +77,7 @@ public final PhysicalOperation groupingPhysicalOperation(

// append channels to the layout
if (aggregatorMode.isOutputPartial()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, AggregateExec#output has its own isOutputPartial check. I think it is probably correct to remove this if-else and always call layout.append(aggregateExec.output()); the only thing AggregateExec#output does differently in the non-partial case is de-duplicating based on name, but we shouldn't have duplicates to begin with. And if we have, we shouldn't be adding them to the output layout if they're not in the agg's output to begin with.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can use aggregateExec.output() - I pushed c5ab8d8

layout.append(aggregateMapper.mapNonGrouping(aggregates));
layout.append(aggregateExec.output());
} else {
layout.append(aggregates);
}
Expand Down Expand Up @@ -148,15 +149,17 @@ else if (aggregatorMode.isOutputPartial()) {
}

if (aggregatorMode.isOutputPartial()) {
layout.append(aggregateMapper.mapGrouping(aggregates));
List<Attribute> output = aggregateExec.output();
for (int i = aggregateExec.groupings().size(); i < output.size(); i++) {
layout.append(output.get(i));
}
} else {
for (var agg : aggregates) {
if (Alias.unwrap(agg) instanceof AggregateFunction) {
layout.append(agg);
}
}
}

// create the agg factories
aggregatesToFactory(
aggregateExec,
Expand Down Expand Up @@ -203,13 +206,12 @@ public static List<Attribute> intermediateAttributes(List<? extends NamedExpress
// TODO: This should take CATEGORIZE into account:
// it currently works because the CATEGORIZE intermediate state is just 1 block with the same type as the function return,
// so the attribute generated here is the expected one
var aggregateMapper = new AggregateMapper();

List<Attribute> attrs = new ArrayList<>();

// no groups
if (groupings.isEmpty()) {
attrs = Expressions.asAttributes(aggregateMapper.mapNonGrouping(aggregates));
attrs = Expressions.asAttributes(AggregateMapper.mapNonGrouping(aggregates));
}
// groups
else {
Expand Down Expand Up @@ -241,13 +243,34 @@ public static List<Attribute> intermediateAttributes(List<? extends NamedExpress
attrs.add(groupAttribute);
}

attrs.addAll(Expressions.asAttributes(aggregateMapper.mapGrouping(aggregates)));
attrs.addAll(Expressions.asAttributes(AggregateMapper.mapGrouping(aggregates)));
}
return attrs;
}

// Bundles everything needed to build one aggregator: the function supplier, the input channels it
// reads from, and the AggregatorMode (raw/partial input, partial/final output) it should run in.
private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, List<Integer> channels, AggregatorMode mode) {}

/**
 * Tracks which contiguous slice of the child's output attributes feeds each aggregate function
 * when this node consumes partial (intermediate) aggregation state. Slices are handed out in
 * first-seen order, starting just past the grouping attributes; asking for the same function
 * again returns the same slice.
 */
private static class IntermediateInputs {
    private final List<Attribute> childOutput;
    private final Map<AggregateFunction, Integer> startByFunction = new HashMap<>();
    // Position where the next not-yet-seen function's intermediate state begins.
    private int cursor;

    IntermediateInputs(AggregateExec aggregateExec) {
        childOutput = aggregateExec.child().output();
        cursor = aggregateExec.groupings().size(); // skip grouping attributes
    }

    List<Attribute> nextInputAttributes(AggregateFunction af, boolean grouping) {
        int stateSize = AggregateMapper.intermediateStateDesc(af, grouping).size();
        Integer start = startByFunction.get(af);
        if (start == null) {
            // First time we see this function: claim the next stateSize columns.
            start = cursor;
            startByFunction.put(af, start);
            cursor += stateSize;
        }
        return childOutput.subList(start, start + stateSize);
    }
}

private void aggregatesToFactory(
AggregateExec aggregateExec,
List<? extends NamedExpression> aggregates,
Expand All @@ -257,21 +280,16 @@ private void aggregatesToFactory(
Consumer<AggFunctionSupplierContext> consumer,
LocalExecutionPlannerContext context
) {
IntermediateInputs intermediateInputs = mode.isInputPartial() ? new IntermediateInputs(aggregateExec) : null;
// extract filtering channels - and wrap the aggregation with the new evaluator expression only during the init phase
for (NamedExpression ne : aggregates) {
// a filter can only appear on aggregate function, not on the grouping columns

if (ne instanceof Alias alias) {
var child = alias.child();
if (child instanceof AggregateFunction aggregateFunction) {
List<NamedExpression> sourceAttr = new ArrayList<>();

final List<Attribute> sourceAttr;
if (mode.isInputPartial()) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(ne);
} else {
sourceAttr = aggregateMapper.mapNonGrouping(ne);
}
sourceAttr = intermediateInputs.nextInputAttributes(aggregateFunction, grouping);
} else {
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
Expression field = aggregateFunction.field();
Expand All @@ -287,6 +305,7 @@ private void aggregatesToFactory(
}
} else {
// extra dependencies like TS ones (that require a timestamp)
sourceAttr = new ArrayList<>();
for (Expression input : aggregateFunction.aggregateInputReferences(aggregateExec.child()::output)) {
Attribute attr = Expressions.attribute(input);
if (attr == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,43 +23,45 @@
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

/**
* Static class used to convert aggregate expressions to the named expressions that represent their intermediate state.
*/
final class AggregateMapper {

// TODO: Do we need this cache?
/** Cache of aggregates to intermediate expressions. */
private final HashMap<Expression, List<NamedExpression>> cache = new HashMap<>();

public List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) {
/** Maps non-grouping aggregates to the named expressions representing their intermediate state. */
public static List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) {
return doMapping(aggregates, false);
}

public List<NamedExpression> mapNonGrouping(NamedExpression aggregate) {
return map(aggregate, false).toList();
}

public List<NamedExpression> mapGrouping(List<? extends NamedExpression> aggregates) {
/** Maps grouping aggregates to the named expressions representing their intermediate state. */
public static List<NamedExpression> mapGrouping(List<? extends NamedExpression> aggregates) {
return doMapping(aggregates, true);
}

private List<NamedExpression> doMapping(List<? extends NamedExpression> aggregates, boolean grouping) {
private static List<NamedExpression> doMapping(List<? extends NamedExpression> aggregates, boolean grouping) {
Set<Expression> seen = new HashSet<>();
AttributeMap.Builder<NamedExpression> attrToExpressionsBuilder = AttributeMap.builder();
aggregates.stream().flatMap(ne -> map(ne, grouping)).forEach(ne -> attrToExpressionsBuilder.put(ne.toAttribute(), ne));
for (NamedExpression agg : aggregates) {
Expression inner = Alias.unwrap(agg);
if (seen.add(inner)) {
for (var ne : computeEntryForAgg(agg.name(), inner, grouping)) {
attrToExpressionsBuilder.put(ne.toAttribute(), ne);
}
}
}
return attrToExpressionsBuilder.build().values().stream().toList();
}

public List<NamedExpression> mapGrouping(NamedExpression aggregate) {
return map(aggregate, true).toList();
}

private Stream<NamedExpression> map(NamedExpression ne, boolean grouping) {
return cache.computeIfAbsent(Alias.unwrap(ne), aggKey -> computeEntryForAgg(ne.name(), aggKey, grouping)).stream();
/**
 * Returns the intermediate-state layout reported by {@code fn}'s aggregator supplier —
 * grouping or non-grouping flavor depending on {@code grouping}.
 *
 * @throws EsqlIllegalArgumentException if {@code fn} is not a {@link ToAggregator} and thus
 *         declares no intermediate state
 */
public static List<IntermediateStateDesc> intermediateStateDesc(AggregateFunction fn, boolean grouping) {
    if (fn instanceof ToAggregator toAggregator) {
        var supplier = toAggregator.supplier();
        if (grouping) {
            return supplier.groupingIntermediateStateDesc();
        }
        return supplier.nonGroupingIntermediateStateDesc();
    }
    throw new EsqlIllegalArgumentException("Aggregate has no defined intermediate state: " + fn);
}

private static List<NamedExpression> computeEntryForAgg(String aggAlias, Expression aggregate, boolean grouping) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.elasticsearch.compute.operator.topn.TopNEncoder;
import org.elasticsearch.compute.operator.topn.TopNOperator;
import org.elasticsearch.compute.operator.topn.TopNOperator.TopNOperatorFactory;
import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexMode;
Expand Down Expand Up @@ -443,6 +444,13 @@ private PhysicalOperation planExchangeSink(ExchangeSinkExec exchangeSink, LocalE
Objects.requireNonNull(exchangeSinkSupplier, "ExchangeSinkHandler wasn't provided");
var child = exchangeSink.child();
PhysicalOperation source = plan(child, context);
if (Assertions.ENABLED) {
List<Attribute> inputAttributes = exchangeSink.child().output();
for (Attribute attr : inputAttributes) {
assert source.layout.get(attr.id()) != null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe throw ISE instead, so that messing up an agg will not kill the node during a full run of the test suite. That makes for easier-to-triage CI issues opened by the CI bot.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be an assertion to verify the invariant. I opened #135862 to handle the spec tests when the test cluster is broken.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The added check is very reasonable. Maybe it makes sense to add it to the general plan method, as it's an invariant that's not only required for planning exchanges, but for every plan node?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ I tried to add this, but this invariant doesn't hold for ExchangeSinkExec and ExchangeSourceExec. I will address this in a follow-up.

: "input attribute [" + attr + "] does not exist in the source layout [" + source.layout + "]";
}
}
return source.withSink(new ExchangeSinkOperatorFactory(exchangeSinkSupplier), source.layout);
}

Expand Down