Skip to content

Commit 1378b59

Browse files
authored
replace tuples with named parameters in AggregateMapper (#121542)
1 parent 30a706a commit 1378b59

File tree

1 file changed

+31
-54
lines changed

1 file changed

+31
-54
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import java.util.HashMap;
4949
import java.util.List;
5050
import java.util.Map;
51+
import java.util.Objects;
5152
import java.util.stream.Collectors;
5253
import java.util.stream.Stream;
5354

@@ -99,16 +100,11 @@ public AggDef withoutExtra() {
99100

100101
/** Map of AggDef types to intermediate named expressions. */
101102
private static final Map<AggDef, List<IntermediateStateDesc>> MAPPER = AGG_FUNCTIONS.stream()
102-
.flatMap(AggregateMapper::typeAndNames)
103-
.flatMap(AggregateMapper::groupingAndNonGrouping)
103+
.flatMap(AggregateMapper::aggDefs)
104104
.collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));
105105

106106
/** Cache of aggregates to intermediate expressions. */
107-
private final HashMap<Expression, List<NamedExpression>> cache;
108-
109-
AggregateMapper() {
110-
cache = new HashMap<>();
111-
}
107+
private final HashMap<Expression, List<NamedExpression>> cache = new HashMap<>();
112108

113109
public List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) {
114110
return doMapping(aggregates, false);
@@ -167,7 +163,7 @@ private static List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
167163
return l;
168164
}
169165

170-
private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class<?> clazz) {
166+
private static Stream<AggDef> aggDefs(Class<?> clazz) {
171167
List<String> types;
172168
List<String> extraConfigs = List.of("");
173169
if (NumericAggregate.class.isAssignableFrom(clazz)) {
@@ -197,32 +193,26 @@ private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class
197193
assert false : "unknown aggregate type " + clazz;
198194
throw new IllegalArgumentException("unknown aggregate type " + clazz);
199195
}
200-
return combine(clazz, types, extraConfigs);
201-
}
202-
203-
private static Stream<Tuple<Class<?>, Tuple<String, String>>> combine(Class<?> clazz, List<String> types, List<String> extraConfigs) {
204-
return combinations(types, extraConfigs).map(combo -> new Tuple<>(clazz, combo));
196+
return combinations(types, extraConfigs).flatMap(typeAndExtraConfig -> {
197+
var type = typeAndExtraConfig.v1();
198+
var extra = typeAndExtraConfig.v2();
199+
200+
if (clazz.isAssignableFrom(Rate.class)) {
201+
// rate doesn't support non-grouping aggregations
202+
return Stream.of(new AggDef(clazz, type, extra, true));
203+
} else if (Objects.equals(type, "AggregateMetricDouble")) {
204+
// TODO: support grouping aggregations for aggregate metric double
205+
return Stream.of(new AggDef(clazz, type, extra, false));
206+
} else {
207+
return Stream.of(new AggDef(clazz, type, extra, true), new AggDef(clazz, type, extra, false));
208+
}
209+
});
205210
}
206211

207212
private static Stream<Tuple<String, String>> combinations(List<String> types, List<String> extraConfigs) {
208213
return types.stream().flatMap(type -> extraConfigs.stream().map(config -> new Tuple<>(type, config)));
209214
}
210215

211-
private static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, Tuple<String, String>> tuple) {
212-
if (tuple.v1().isAssignableFrom(Rate.class)) {
213-
// rate doesn't support non-grouping aggregations
214-
return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true));
215-
} else if (tuple.v2().v1().equals("AggregateMetricDouble")) {
216-
// TODO: support grouping aggregations for aggregate metric double
217-
return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false));
218-
} else {
219-
return Stream.of(
220-
new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true),
221-
new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false)
222-
);
223-
}
224-
}
225-
226216
/** Retrieves the intermediate state description for a given class, type, and grouping. */
227217
private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
228218
try {
@@ -264,23 +254,13 @@ private static MethodHandle lookupRetry(Class<?> clazz, String type, String extr
264254

265255
/** Determines the engines agg class name, for the given class, type, and grouping. */
266256
private static String determineAggName(Class<?> clazz, String type, String extra, boolean grouping) {
267-
StringBuilder sb = new StringBuilder();
268-
sb.append(determinePackageName(clazz)).append(".");
269-
sb.append(clazz.getSimpleName());
270-
sb.append(type);
271-
sb.append(extra);
272-
sb.append(grouping ? "Grouping" : "");
273-
sb.append("AggregatorFunction");
274-
return sb.toString();
275-
}
276-
277-
/** Determines the engine agg package name, for the given class. */
278-
private static String determinePackageName(Class<?> clazz) {
279-
if (clazz.getSimpleName().startsWith("Spatial")) {
280-
// All spatial aggs are in the spatial sub-package
281-
return "org.elasticsearch.compute.aggregation.spatial";
282-
}
283-
return "org.elasticsearch.compute.aggregation";
257+
return "org.elasticsearch.compute.aggregation."
258+
+ (clazz.getSimpleName().startsWith("Spatial") ? "spatial." : "")
259+
+ clazz.getSimpleName()
260+
+ type
261+
+ extra
262+
+ (grouping ? "Grouping" : "")
263+
+ "AggregatorFunction";
284264
}
285265

286266
/** Maps intermediate state description to named expressions. */
@@ -317,19 +297,16 @@ private static String dataTypeToString(DataType type, Class<?> aggClass) {
317297
if (aggClass == ToPartial.class || aggClass == FromPartial.class) {
318298
return "";
319299
}
320-
if ((aggClass == Max.class || aggClass == Min.class) && type.equals(DataType.IP)) {
321-
return "Ip";
322-
}
323-
if (aggClass == Top.class && type.equals(DataType.IP)) {
300+
if ((aggClass == Max.class || aggClass == Min.class || aggClass == Top.class) && type.equals(DataType.IP)) {
324301
return "Ip";
325302
}
326303

327304
return switch (type) {
328-
case DataType.BOOLEAN -> "Boolean";
329-
case DataType.INTEGER, DataType.COUNTER_INTEGER -> "Int";
330-
case DataType.LONG, DataType.DATETIME, DataType.COUNTER_LONG, DataType.DATE_NANOS -> "Long";
331-
case DataType.DOUBLE, DataType.COUNTER_DOUBLE -> "Double";
332-
case DataType.KEYWORD, DataType.IP, DataType.VERSION, DataType.TEXT, DataType.SEMANTIC_TEXT -> "BytesRef";
305+
case BOOLEAN -> "Boolean";
306+
case INTEGER, COUNTER_INTEGER -> "Int";
307+
case LONG, DATETIME, COUNTER_LONG, DATE_NANOS -> "Long";
308+
case DOUBLE, COUNTER_DOUBLE -> "Double";
309+
case KEYWORD, IP, VERSION, TEXT, SEMANTIC_TEXT -> "BytesRef";
333310
case GEO_POINT -> "GeoPoint";
334311
case CARTESIAN_POINT -> "CartesianPoint";
335312
case GEO_SHAPE -> "GeoShape";

0 commit comments

Comments
 (0)