Skip to content

Commit 025d992

Browse files
committed
Fix remaining classes
1 parent 2c7a952 commit 025d992

File tree

3 files changed

+17
-62
lines changed

3 files changed

+17
-62
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ private static Operator operator(DriverContext driverContext, String grouping, S
155155

156156
if (grouping.equals("none")) {
157157
return new AggregationOperator(
158-
List.of(supplier(op, dataType, filter, 0).aggregatorFactory(AggregatorMode.SINGLE).apply(driverContext)),
158+
List.of(supplier(op, dataType, filter).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
159159
driverContext
160160
);
161161
}
@@ -182,33 +182,33 @@ private static Operator operator(DriverContext driverContext, String grouping, S
182182
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
183183
};
184184
return new HashAggregationOperator(
185-
List.of(supplier(op, dataType, filter, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
185+
List.of(supplier(op, dataType, filter).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(groups.size()))),
186186
() -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
187187
driverContext
188188
);
189189
}
190190

191-
private static AggregatorFunctionSupplier supplier(String op, String dataType, String filter, int dataChannel) {
191+
private static AggregatorFunctionSupplier supplier(String op, String dataType, String filter) {
192192
return filtered(switch (op) {
193-
case COUNT -> CountAggregatorFunction.supplier(List.of(dataChannel));
193+
case COUNT -> CountAggregatorFunction.supplier();
194194
case COUNT_DISTINCT -> switch (dataType) {
195-
case LONGS -> new CountDistinctLongAggregatorFunctionSupplier(List.of(dataChannel), 3000);
196-
case DOUBLES -> new CountDistinctDoubleAggregatorFunctionSupplier(List.of(dataChannel), 3000);
195+
case LONGS -> new CountDistinctLongAggregatorFunctionSupplier(3000);
196+
case DOUBLES -> new CountDistinctDoubleAggregatorFunctionSupplier(3000);
197197
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
198198
};
199199
case MAX -> switch (dataType) {
200-
case LONGS -> new MaxLongAggregatorFunctionSupplier(List.of(dataChannel));
201-
case DOUBLES -> new MaxDoubleAggregatorFunctionSupplier(List.of(dataChannel));
200+
case LONGS -> new MaxLongAggregatorFunctionSupplier();
201+
case DOUBLES -> new MaxDoubleAggregatorFunctionSupplier();
202202
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
203203
};
204204
case MIN -> switch (dataType) {
205-
case LONGS -> new MinLongAggregatorFunctionSupplier(List.of(dataChannel));
206-
case DOUBLES -> new MinDoubleAggregatorFunctionSupplier(List.of(dataChannel));
205+
case LONGS -> new MinLongAggregatorFunctionSupplier();
206+
case DOUBLES -> new MinDoubleAggregatorFunctionSupplier();
207207
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
208208
};
209209
case SUM -> switch (dataType) {
210-
case LONGS -> new SumLongAggregatorFunctionSupplier(List.of(dataChannel));
211-
case DOUBLES -> new SumDoubleAggregatorFunctionSupplier(List.of(dataChannel));
210+
case LONGS -> new SumLongAggregatorFunctionSupplier();
211+
case DOUBLES -> new SumDoubleAggregatorFunctionSupplier();
212212
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
213213
};
214214
default -> throw new IllegalArgumentException("unsupported op [" + op + "]");

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
import static org.elasticsearch.compute.lucene.TimeSeriesSortedSourceOperatorTests.createTimeSeriesSourceOperator;
4343
import static org.elasticsearch.compute.lucene.TimeSeriesSortedSourceOperatorTests.writeTS;
44+
import static org.elasticsearch.compute.operator.TimeSeriesAggregationOperatorFactories.SupplierWithChannels;
4445
import static org.hamcrest.Matchers.equalTo;
4546

4647
public class TimeSeriesAggregationOperatorTests extends ComputeTestCase {
@@ -269,7 +270,7 @@ public void close() {
269270
1,
270271
3,
271272
IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(),
272-
List.of(new RateLongAggregatorFunctionSupplier(List.of(4, 2), unitInMillis)),
273+
List.of(new SupplierWithChannels(new RateLongAggregatorFunctionSupplier(unitInMillis), List.of(4, 2))),
273274
List.of(),
274275
between(1, 100)
275276
).get(ctx);
@@ -279,7 +280,7 @@ public void close() {
279280
0,
280281
1,
281282
IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(),
282-
List.of(new RateLongAggregatorFunctionSupplier(List.of(2, 3, 4), unitInMillis)),
283+
List.of(new SupplierWithChannels(new RateLongAggregatorFunctionSupplier(unitInMillis), List.of(2, 3, 4))),
283284
List.of(),
284285
between(1, 100)
285286
).get(ctx);
@@ -295,7 +296,7 @@ public void close() {
295296
}
296297
Operator finalAgg = new TimeSeriesAggregationOperatorFactories.Final(
297298
finalGroups,
298-
List.of(new SumDoubleAggregatorFunctionSupplier(List.of(2))),
299+
List.of(new SupplierWithChannels(new SumDoubleAggregatorFunctionSupplier(), List.of(2))),
299300
List.of(),
300301
between(1, 100)
301302
).get(ctx);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,66 +19,20 @@
1919
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
2020
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
2121
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
22-
import org.elasticsearch.xpack.esql.core.expression.function.Function;
2322
import org.elasticsearch.xpack.esql.core.tree.Source;
2423
import org.elasticsearch.xpack.esql.core.type.DataType;
2524
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
26-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
27-
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
28-
import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial;
29-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
30-
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
31-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
32-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
33-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
34-
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
35-
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent;
36-
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
37-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
38-
import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
39-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
40-
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
4125

4226
import java.util.HashMap;
4327
import java.util.List;
4428
import java.util.stream.Stream;
4529

4630
/**
4731
* Static class used to convert aggregate expressions to the named expressions that represent their intermediate state.
48-
* <p>
49-
* At class load time, the mapper is populated with all supported aggregate functions and their intermediate state.
50-
* </p>
51-
* <p>
52-
* Reflection is used to call the {@code intermediateStateDesc()}` static method of the aggregate functions,
53-
* but the function classes are found based on the exising information within this class.
54-
* </p>
55-
* <p>
56-
* This class must be updated when aggregations are created or updated, by adding the new aggs or types to the corresponding methods.
57-
* </p>
5832
*/
5933
final class AggregateMapper {
6034

61-
/** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). */
62-
private static final List<? extends Class<? extends Function>> AGG_FUNCTIONS = List.of(
63-
Count.class,
64-
CountDistinct.class,
65-
Max.class,
66-
MedianAbsoluteDeviation.class,
67-
Min.class,
68-
Percentile.class,
69-
SpatialCentroid.class,
70-
SpatialExtent.class,
71-
StdDev.class,
72-
Sum.class,
73-
Values.class,
74-
Top.class,
75-
Rate.class,
76-
77-
// internal function
78-
FromPartial.class,
79-
ToPartial.class
80-
);
81-
35+
// TODO: Do we need this cache?
8236
/** Cache of aggregates to intermediate expressions. */
8337
private final HashMap<Expression, List<NamedExpression>> cache = new HashMap<>();
8438

0 commit comments

Comments
 (0)