Skip to content

Commit 2c7a952

Browse files
committed
Remove reflection from AggregateMapper
1 parent 8aba64f commit 2c7a952

File tree

1 file changed

+10
-170
lines changed

1 file changed

+10
-170
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java

Lines changed: 10 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import org.elasticsearch.common.Strings;
1111
import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
1212
import org.elasticsearch.compute.data.ElementType;
13-
import org.elasticsearch.core.Tuple;
1413
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
1514
import org.elasticsearch.xpack.esql.core.expression.Alias;
1615
import org.elasticsearch.xpack.esql.core.expression.Attribute;
@@ -30,10 +29,8 @@
3029
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
3130
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
3231
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
33-
import org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate;
3432
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
3533
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
36-
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction;
3734
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
3835
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent;
3936
import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev;
@@ -42,14 +39,8 @@
4239
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
4340
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
4441

45-
import java.lang.invoke.MethodHandle;
46-
import java.lang.invoke.MethodHandles;
47-
import java.lang.invoke.MethodType;
4842
import java.util.HashMap;
4943
import java.util.List;
50-
import java.util.Map;
51-
import java.util.Objects;
52-
import java.util.stream.Collectors;
5344
import java.util.stream.Stream;
5445

5546
/**
@@ -67,9 +58,6 @@
6758
*/
6859
final class AggregateMapper {
6960

70-
private static final List<String> NUMERIC = List.of("Int", "Long", "Double");
71-
private static final List<String> SPATIAL_EXTRA_CONFIGS = List.of("SourceValues", "DocValues");
72-
7361
/** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). */
7462
private static final List<? extends Class<? extends Function>> AGG_FUNCTIONS = List.of(
7563
Count.class,
@@ -91,18 +79,6 @@ final class AggregateMapper {
9179
ToPartial.class
9280
);
9381

94-
/** Record of agg Class, type, and grouping (or non-grouping). */
95-
private record AggDef(Class<?> aggClazz, String type, String extra, boolean grouping) {
96-
public AggDef withoutExtra() {
97-
return new AggDef(aggClazz, type, "", grouping);
98-
}
99-
}
100-
101-
/** Map of AggDef types to intermediate named expressions. */
102-
private static final Map<AggDef, List<IntermediateStateDesc>> MAPPER = AGG_FUNCTIONS.stream()
103-
.flatMap(AggregateMapper::aggDefs)
104-
.collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));
105-
10682
/** Cache of aggregates to intermediate expressions. */
10783
private final HashMap<Expression, List<NamedExpression>> cache = new HashMap<>();
10884

@@ -144,127 +120,21 @@ private static List<NamedExpression> computeEntryForAgg(String aggAlias, Express
144120
}
145121

146122
private static List<NamedExpression> entryForAgg(String aggAlias, AggregateFunction aggregateFunction, boolean grouping) {
147-
var aggDef = new AggDef(
148-
aggregateFunction.getClass(),
149-
dataTypeToString(aggregateFunction.field().dataType(), aggregateFunction.getClass()),
150-
aggregateFunction instanceof SpatialAggregateFunction ? "SourceValues" : "",
151-
grouping
152-
);
153-
var is = getNonNull(aggDef);
154-
return isToNE(is, aggAlias).toList();
155-
}
156-
157-
/** Gets the agg from the mapper - wrapper around map::get for more informative failure.*/
158-
private static List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
159-
var l = MAPPER.getOrDefault(aggDef, MAPPER.get(aggDef.withoutExtra()));
160-
if (l == null) {
161-
throw new EsqlIllegalArgumentException("Cannot find intermediate state for: " + aggDef);
162-
}
163-
return l;
164-
}
165-
166-
private static Stream<AggDef> aggDefs(Class<?> clazz) {
167-
List<String> types;
168-
List<String> extraConfigs = List.of("");
169-
if (NumericAggregate.class.isAssignableFrom(clazz)) {
170-
types = NUMERIC;
171-
} else if (Max.class.isAssignableFrom(clazz) || Min.class.isAssignableFrom(clazz)) {
172-
types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef");
173-
} else if (clazz == Count.class) {
174-
types = List.of(""); // no extra type distinction
175-
} else if (clazz == SpatialCentroid.class) {
176-
types = List.of("GeoPoint", "CartesianPoint");
177-
extraConfigs = SPATIAL_EXTRA_CONFIGS;
178-
} else if (clazz == SpatialExtent.class) {
179-
types = List.of("GeoPoint", "CartesianPoint", "GeoShape", "CartesianShape");
180-
extraConfigs = SPATIAL_EXTRA_CONFIGS;
181-
} else if (Values.class.isAssignableFrom(clazz)) {
182-
// TODO can't we figure this out from the function itself?
183-
types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
184-
} else if (Top.class.isAssignableFrom(clazz)) {
185-
types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef");
186-
} else if (Rate.class.isAssignableFrom(clazz) || StdDev.class.isAssignableFrom(clazz)) {
187-
types = List.of("Int", "Long", "Double");
188-
} else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) {
189-
types = List.of(""); // no type
190-
} else if (CountDistinct.class.isAssignableFrom(clazz)) {
191-
types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();
123+
List<IntermediateStateDesc> intermediateState;
124+
if (aggregateFunction instanceof ToAggregator toAggregator) {
125+
var supplier = toAggregator.supplier();
126+
intermediateState = grouping ? supplier.groupingIntermediateStateDesc() : supplier.nonGroupingIntermediateStateDesc();
192127
} else {
193-
assert false : "unknown aggregate type " + clazz;
194-
throw new IllegalArgumentException("unknown aggregate type " + clazz);
195-
}
196-
return combinations(types, extraConfigs).flatMap(typeAndExtraConfig -> {
197-
var type = typeAndExtraConfig.v1();
198-
var extra = typeAndExtraConfig.v2();
199-
200-
if (clazz.isAssignableFrom(Rate.class)) {
201-
// rate doesn't support non-grouping aggregations
202-
return Stream.of(new AggDef(clazz, type, extra, true));
203-
} else if (Objects.equals(type, "AggregateMetricDouble")) {
204-
// TODO: support grouping aggregations for aggregate metric double
205-
return Stream.of(new AggDef(clazz, type, extra, false));
206-
} else {
207-
return Stream.of(new AggDef(clazz, type, extra, true), new AggDef(clazz, type, extra, false));
208-
}
209-
});
210-
}
211-
212-
private static Stream<Tuple<String, String>> combinations(List<String> types, List<String> extraConfigs) {
213-
return types.stream().flatMap(type -> extraConfigs.stream().map(config -> new Tuple<>(type, config)));
214-
}
215-
216-
/** Retrieves the intermediate state description for a given class, type, and grouping. */
217-
private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
218-
try {
219-
return (List<IntermediateStateDesc>) lookup(aggDef.aggClazz(), aggDef.type(), aggDef.extra(), aggDef.grouping()).invokeExact();
220-
} catch (Throwable t) {
221-
// invokeExact forces us to handle any Throwable thrown by lookup.
222-
throw new EsqlIllegalArgumentException(t);
223-
}
224-
}
225-
226-
/** Looks up the intermediate state method for a given class, type, and grouping. */
227-
private static MethodHandle lookup(Class<?> clazz, String type, String extra, boolean grouping) {
228-
try {
229-
return lookupRetry(clazz, type, extra, grouping);
230-
} catch (IllegalAccessException | NoSuchMethodException | ClassNotFoundException e) {
231-
throw new EsqlIllegalArgumentException(e);
128+
throw new EsqlIllegalArgumentException("Aggregate has no defined intermediate state: " + aggregateFunction);
232129
}
233-
}
234-
235-
private static MethodHandle lookupRetry(Class<?> clazz, String type, String extra, boolean grouping) throws IllegalAccessException,
236-
NoSuchMethodException, ClassNotFoundException {
237-
try {
238-
return MethodHandles.lookup()
239-
.findStatic(
240-
Class.forName(determineAggName(clazz, type, extra, grouping)),
241-
"intermediateStateDesc",
242-
MethodType.methodType(List.class)
243-
);
244-
} catch (NoSuchMethodException ignore) {
245-
// Retry without the extra information.
246-
return MethodHandles.lookup()
247-
.findStatic(
248-
Class.forName(determineAggName(clazz, type, "", grouping)),
249-
"intermediateStateDesc",
250-
MethodType.methodType(List.class)
251-
);
252-
}
253-
}
254-
255-
/** Determines the engines agg class name, for the given class, type, and grouping. */
256-
private static String determineAggName(Class<?> clazz, String type, String extra, boolean grouping) {
257-
return "org.elasticsearch.compute.aggregation."
258-
+ (clazz.getSimpleName().startsWith("Spatial") ? "spatial." : "")
259-
+ clazz.getSimpleName()
260-
+ type
261-
+ extra
262-
+ (grouping ? "Grouping" : "")
263-
+ "AggregatorFunction";
130+
return intermediateStateToNamedExpressions(intermediateState, aggAlias).toList();
264131
}
265132

266133
/** Maps intermediate state description to named expressions. */
267-
private static Stream<NamedExpression> isToNE(List<IntermediateStateDesc> intermediateStateDescs, String aggAlias) {
134+
private static Stream<NamedExpression> intermediateStateToNamedExpressions(
135+
List<IntermediateStateDesc> intermediateStateDescs,
136+
String aggAlias
137+
) {
268138
return intermediateStateDescs.stream().map(is -> {
269139
final DataType dataType;
270140
if (Strings.isEmpty(is.dataType())) {
@@ -288,34 +158,4 @@ private static DataType toDataType(ElementType elementType) {
288158
case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new EsqlIllegalArgumentException("unsupported agg type: " + elementType);
289159
};
290160
}
291-
292-
/** Returns the string representation for the data type. This reflects the engine's aggs naming structure. */
293-
private static String dataTypeToString(DataType type, Class<?> aggClass) {
294-
if (aggClass == Count.class) {
295-
return ""; // no type distinction
296-
}
297-
if (aggClass == ToPartial.class || aggClass == FromPartial.class) {
298-
return "";
299-
}
300-
if ((aggClass == Max.class || aggClass == Min.class || aggClass == Top.class) && type.equals(DataType.IP)) {
301-
return "Ip";
302-
}
303-
304-
return switch (type) {
305-
case BOOLEAN -> "Boolean";
306-
case INTEGER, COUNTER_INTEGER -> "Int";
307-
case LONG, DATETIME, COUNTER_LONG, DATE_NANOS -> "Long";
308-
case DOUBLE, COUNTER_DOUBLE -> "Double";
309-
case KEYWORD, IP, VERSION, TEXT, SEMANTIC_TEXT -> "BytesRef";
310-
case GEO_POINT -> "GeoPoint";
311-
case CARTESIAN_POINT -> "CartesianPoint";
312-
case GEO_SHAPE -> "GeoShape";
313-
case CARTESIAN_SHAPE -> "CartesianShape";
314-
case AGGREGATE_METRIC_DOUBLE -> "AggregateMetricDouble";
315-
case UNSUPPORTED, NULL, UNSIGNED_LONG, SHORT, BYTE, FLOAT, HALF_FLOAT, SCALED_FLOAT, OBJECT, SOURCE, DATE_PERIOD, TIME_DURATION,
316-
DOC_DATA_TYPE, TSID_DATA_TYPE, PARTIAL_AGG -> throw new EsqlIllegalArgumentException(
317-
"illegal agg type: " + type.typeName()
318-
);
319-
};
320-
}
321161
}

0 commit comments

Comments
 (0)