diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index a66a302354df2..56ab4652bd482 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -48,6 +48,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -99,16 +100,11 @@ public AggDef withoutExtra() { /** Map of AggDef types to intermediate named expressions. */ private static final Map> MAPPER = AGG_FUNCTIONS.stream() - .flatMap(AggregateMapper::typeAndNames) - .flatMap(AggregateMapper::groupingAndNonGrouping) + .flatMap(AggregateMapper::aggDefs) .collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState)); /** Cache of aggregates to intermediate expressions. */ - private final HashMap> cache; - - AggregateMapper() { - cache = new HashMap<>(); - } + private final HashMap> cache = new HashMap<>(); public List mapNonGrouping(List aggregates) { return doMapping(aggregates, false); @@ -167,7 +163,7 @@ private static List getNonNull(AggDef aggDef) { return l; } - private static Stream, Tuple>> typeAndNames(Class clazz) { + private static Stream aggDefs(Class clazz) { List types; List extraConfigs = List.of(""); if (NumericAggregate.class.isAssignableFrom(clazz)) { @@ -197,32 +193,26 @@ private static Stream, Tuple>> typeAndNames(Class assert false : "unknown aggregate type " + clazz; throw new IllegalArgumentException("unknown aggregate type " + clazz); } - return combine(clazz, types, extraConfigs); - } - - private static Stream, Tuple>> combine(Class clazz, List types, List extraConfigs) { - return combinations(types, extraConfigs).map(combo -> new Tuple<>(clazz, combo)); + return combinations(types, extraConfigs).flatMap(typeAndExtraConfig -> { + var type = typeAndExtraConfig.v1(); + var extra = typeAndExtraConfig.v2(); + + if (clazz.isAssignableFrom(Rate.class)) { + // rate doesn't support non-grouping aggregations + return Stream.of(new AggDef(clazz, type, extra, true)); + } else if (Objects.equals(type, "AggregateMetricDouble")) { + // TODO: support grouping aggregations for aggregate metric double + return Stream.of(new AggDef(clazz, type, extra, false)); + } else { + return Stream.of(new AggDef(clazz, type, extra, true), new AggDef(clazz, type, extra, false)); + } + }); } private static Stream> combinations(List types, List extraConfigs) { return types.stream().flatMap(type -> extraConfigs.stream().map(config -> new Tuple<>(type, config))); } - private static Stream groupingAndNonGrouping(Tuple, Tuple> tuple) { - if (tuple.v1().isAssignableFrom(Rate.class)) { - // rate doesn't support non-grouping aggregations - return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true)); - } else if (tuple.v2().v1().equals("AggregateMetricDouble")) { - // TODO: support grouping aggregations for aggregate metric double - return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false)); - } else { - return Stream.of( - new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true), - new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false) - ); - } - } - /** Retrieves the intermediate state description for a given class, type, and grouping. */ private static List lookupIntermediateState(AggDef aggDef) { try { @@ -264,23 +254,13 @@ private static MethodHandle lookupRetry(Class clazz, String type, String extr /** Determines the engines agg class name, for the given class, type, and grouping. */ private static String determineAggName(Class clazz, String type, String extra, boolean grouping) { - StringBuilder sb = new StringBuilder(); - sb.append(determinePackageName(clazz)).append("."); - sb.append(clazz.getSimpleName()); - sb.append(type); - sb.append(extra); - sb.append(grouping ? "Grouping" : ""); - sb.append("AggregatorFunction"); - return sb.toString(); - } - - /** Determines the engine agg package name, for the given class. */ - private static String determinePackageName(Class clazz) { - if (clazz.getSimpleName().startsWith("Spatial")) { - // All spatial aggs are in the spatial sub-package - return "org.elasticsearch.compute.aggregation.spatial"; - } - return "org.elasticsearch.compute.aggregation"; + return "org.elasticsearch.compute.aggregation." + + (clazz.getSimpleName().startsWith("Spatial") ? "spatial." : "") + + clazz.getSimpleName() + + type + + extra + + (grouping ? "Grouping" : "") + + "AggregatorFunction"; } /** Maps intermediate state description to named expressions. */ @@ -317,19 +297,16 @@ private static String dataTypeToString(DataType type, Class aggClass) { if (aggClass == ToPartial.class || aggClass == FromPartial.class) { return ""; } - if ((aggClass == Max.class || aggClass == Min.class) && type.equals(DataType.IP)) { - return "Ip"; - } - if (aggClass == Top.class && type.equals(DataType.IP)) { + if ((aggClass == Max.class || aggClass == Min.class || aggClass == Top.class) && type.equals(DataType.IP)) { return "Ip"; } return switch (type) { - case DataType.BOOLEAN -> "Boolean"; - case DataType.INTEGER, DataType.COUNTER_INTEGER -> "Int"; - case DataType.LONG, DataType.DATETIME, DataType.COUNTER_LONG, DataType.DATE_NANOS -> "Long"; - case DataType.DOUBLE, DataType.COUNTER_DOUBLE -> "Double"; - case DataType.KEYWORD, DataType.IP, DataType.VERSION, DataType.TEXT, DataType.SEMANTIC_TEXT -> "BytesRef"; + case BOOLEAN -> "Boolean"; + case INTEGER, COUNTER_INTEGER -> "Int"; + case LONG, DATETIME, COUNTER_LONG, DATE_NANOS -> "Long"; + case DOUBLE, COUNTER_DOUBLE -> "Double"; + case KEYWORD, IP, VERSION, TEXT, SEMANTIC_TEXT -> "BytesRef"; case GEO_POINT -> "GeoPoint"; case CARTESIAN_POINT -> "CartesianPoint"; case GEO_SHAPE -> "GeoShape";