Skip to content

Commit 70095cf

Browse files
committed
Fully implemented TopNAggregate, optimizations and rules to be reviewed
1 parent 92311cd commit 70095cf

File tree

11 files changed

+313
-308
lines changed

11 files changed

+313
-308
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/ProjectAwayColumns.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
1818
import org.elasticsearch.xpack.esql.plan.logical.Eval;
1919
import org.elasticsearch.xpack.esql.plan.logical.Project;
20+
import org.elasticsearch.xpack.esql.plan.logical.TopN;
21+
import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate;
2022
import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
2123
import org.elasticsearch.xpack.esql.plan.physical.FragmentExec;
2224
import org.elasticsearch.xpack.esql.plan.physical.MergeExec;
@@ -61,7 +63,7 @@ public PhysicalPlan apply(PhysicalPlan plan) {
6163
var logicalFragment = fragmentExec.fragment();
6264

6365
// no need for projection when dealing with aggs
64-
if (logicalFragment instanceof Aggregate == false) {
66+
if (logicalFragment instanceof Aggregate == false && logicalFragment instanceof TopNAggregate == false) {
6567
List<Attribute> output = new ArrayList<>(requiredAttrBuilder.build());
6668
// if all the fields are filtered out, it's only the count that matters
6769
// however until a proper fix (see https://github.com/elastic/elasticsearch/issues/98703)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.plan.physical;
9+
10+
import org.elasticsearch.TransportVersions;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.compute.aggregation.AggregatorMode;
14+
import org.elasticsearch.xpack.esql.core.expression.Attribute;
15+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
16+
import org.elasticsearch.xpack.esql.core.expression.Expression;
17+
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
18+
import org.elasticsearch.xpack.esql.core.tree.Source;
19+
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
20+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
21+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
22+
23+
import java.io.IOException;
24+
import java.util.ArrayList;
25+
import java.util.HashSet;
26+
import java.util.List;
27+
import java.util.Objects;
28+
29+
/**
30+
* Base class for aggregate nodes.
31+
*/
32+
public abstract class AbstractAggregateExec extends UnaryExec implements EstimatesRowSize {
33+
protected final List<? extends Expression> groupings;
34+
protected final List<? extends NamedExpression> aggregates;
35+
/**
36+
* The output attributes of {@link AggregatorMode#INITIAL} and {@link AggregatorMode#INTERMEDIATE} aggregations, resp.
37+
* the input attributes of {@link AggregatorMode#FINAL} and {@link AggregatorMode#INTERMEDIATE} aggregations.
38+
*/
39+
protected final List<Attribute> intermediateAttributes;
40+
41+
protected final AggregatorMode mode;
42+
43+
/**
44+
* Estimate of the number of bytes that'll be loaded per position before
45+
* the stream of pages is consumed.
46+
*/
47+
protected final Integer estimatedRowSize;
48+
49+
protected AbstractAggregateExec(
50+
Source source,
51+
PhysicalPlan child,
52+
List<? extends Expression> groupings,
53+
List<? extends NamedExpression> aggregates,
54+
AggregatorMode mode,
55+
List<Attribute> intermediateAttributes,
56+
Integer estimatedRowSize
57+
) {
58+
super(source, child);
59+
this.groupings = groupings;
60+
this.aggregates = aggregates;
61+
this.mode = mode;
62+
this.intermediateAttributes = intermediateAttributes;
63+
this.estimatedRowSize = estimatedRowSize;
64+
}
65+
66+
protected AbstractAggregateExec(StreamInput in) throws IOException {
67+
// This is only deserialized as part of node level reduction, which is turned off until at least 8.16.
68+
// So, we do not have to consider previous transport versions here, because old nodes will not send AggregateExecs to new nodes.
69+
super(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(PhysicalPlan.class));
70+
this.groupings = in.readNamedWriteableCollectionAsList(Expression.class);
71+
this.aggregates = in.readNamedWriteableCollectionAsList(NamedExpression.class);
72+
this.mode = in.readEnum(AggregatorMode.class);
73+
this.intermediateAttributes = in.readNamedWriteableCollectionAsList(Attribute.class);
74+
this.estimatedRowSize = in.readOptionalVInt();
75+
}
76+
77+
@Override
78+
public void writeTo(StreamOutput out) throws IOException {
79+
Source.EMPTY.writeTo(out);
80+
out.writeNamedWriteable(child());
81+
out.writeNamedWriteableCollection(groupings());
82+
out.writeNamedWriteableCollection(aggregates());
83+
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
84+
out.writeEnum(getMode());
85+
out.writeNamedWriteableCollection(intermediateAttributes());
86+
} else {
87+
out.writeEnum(Mode.fromAggregatorMode(getMode()));
88+
}
89+
out.writeOptionalVInt(estimatedRowSize());
90+
}
91+
92+
public List<? extends Expression> groupings() {
93+
return groupings;
94+
}
95+
96+
public List<? extends NamedExpression> aggregates() {
97+
return aggregates;
98+
}
99+
100+
/**
101+
* Estimate of the number of bytes that'll be loaded per position before
102+
* the stream of pages is consumed.
103+
*/
104+
public Integer estimatedRowSize() {
105+
return estimatedRowSize;
106+
}
107+
108+
@Override
109+
public PhysicalPlan estimateRowSize(State state) {
110+
state.add(false, aggregates); // The groupings are contained within the aggregates
111+
int size = state.consumeAllFields(true);
112+
size = Math.max(size, 1);
113+
return Objects.equals(this.estimatedRowSize, size) ? this : withEstimatedSize(size);
114+
}
115+
116+
protected abstract AbstractAggregateExec withEstimatedSize(int estimatedRowSize);
117+
118+
public AggregatorMode getMode() {
119+
return mode;
120+
}
121+
122+
/**
123+
* Used only for bwc when de-/serializing.
124+
*/
125+
@Deprecated
126+
private enum Mode {
127+
SINGLE,
128+
PARTIAL, // maps raw inputs to intermediate outputs
129+
FINAL; // maps intermediate inputs to final outputs
130+
131+
static Mode fromAggregatorMode(AggregatorMode aggregatorMode) {
132+
return switch (aggregatorMode) {
133+
case SINGLE -> SINGLE;
134+
case INITIAL -> PARTIAL;
135+
case FINAL -> FINAL;
136+
// If needed, we could have this return an PARTIAL instead; that's how intermediate aggs were encoded in the past for
137+
// data node level reduction.
138+
case INTERMEDIATE -> throw new UnsupportedOperationException(
139+
"cannot turn intermediate aggregation into single, partial or final."
140+
);
141+
};
142+
}
143+
}
144+
145+
/**
146+
* Aggregations are usually performed in two steps, first partial (e.g. locally on a data node) then final (on the coordinator node).
147+
* These are the intermediate attributes output by a partial aggregation or consumed by a final one.
148+
* C.f. {@link org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders#intermediateAttributes}.
149+
*/
150+
public List<Attribute> intermediateAttributes() {
151+
return intermediateAttributes;
152+
}
153+
154+
@Override
155+
public List<Attribute> output() {
156+
return mode.isOutputPartial() ? intermediateAttributes : Aggregate.output(aggregates);
157+
}
158+
159+
@Override
160+
protected AttributeSet computeReferences() {
161+
return mode.isInputPartial()
162+
? AttributeSet.of(intermediateAttributes)
163+
: Aggregate.computeReferences(aggregates, groupings).subtract(AttributeSet.of(ordinalAttributes()));
164+
}
165+
166+
/** Returns the attributes that can be loaded from ordinals -- no explicit extraction is needed */
167+
public List<Attribute> ordinalAttributes() {
168+
List<Attribute> orginalAttributs = new ArrayList<>(groupings.size());
169+
// Ordinals can be leveraged just for a single grouping. If there are multiple groupings, fields need to be laoded for the
170+
// hash aggregator.
171+
// CATEGORIZE requires the standard hash aggregator as well.
172+
if (groupings().size() == 1 && groupings.get(0).anyMatch(e -> e instanceof Categorize) == false) {
173+
var leaves = new HashSet<>();
174+
aggregates.stream().filter(a -> groupings.contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves()));
175+
groupings.forEach(g -> {
176+
if (leaves.contains(g) == false) {
177+
orginalAttributs.add((Attribute) g);
178+
}
179+
});
180+
}
181+
return orginalAttributs;
182+
}
183+
}

0 commit comments

Comments
 (0)