Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -99,16 +100,11 @@ public AggDef withoutExtra() {

/** Map of AggDef types to intermediate named expressions. */
private static final Map<AggDef, List<IntermediateStateDesc>> MAPPER = AGG_FUNCTIONS.stream()
.flatMap(AggregateMapper::typeAndNames)
.flatMap(AggregateMapper::groupingAndNonGrouping)
.flatMap(AggregateMapper::aggDefs)
.collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));

/** Cache of aggregates to intermediate expressions. */
private final HashMap<Expression, List<NamedExpression>> cache;

AggregateMapper() {
cache = new HashMap<>();
}
private final HashMap<Expression, List<NamedExpression>> cache = new HashMap<>();

public List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) {
return doMapping(aggregates, false);
Expand Down Expand Up @@ -167,7 +163,7 @@ private static List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
return l;
}

private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class<?> clazz) {
private static Stream<AggDef> aggDefs(Class<?> clazz) {
List<String> types;
List<String> extraConfigs = List.of("");
if (NumericAggregate.class.isAssignableFrom(clazz)) {
Expand Down Expand Up @@ -197,32 +193,26 @@ private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class
assert false : "unknown aggregate type " + clazz;
throw new IllegalArgumentException("unknown aggregate type " + clazz);
}
return combine(clazz, types, extraConfigs);
}

private static Stream<Tuple<Class<?>, Tuple<String, String>>> combine(Class<?> clazz, List<String> types, List<String> extraConfigs) {
return combinations(types, extraConfigs).map(combo -> new Tuple<>(clazz, combo));
return combinations(types, extraConfigs).flatMap(typeAndExtraConfig -> {
var type = typeAndExtraConfig.v1();
var extra = typeAndExtraConfig.v2();

if (clazz.isAssignableFrom(Rate.class)) {
// rate doesn't support non-grouping aggregations
return Stream.of(new AggDef(clazz, type, extra, true));
} else if (Objects.equals(type, "AggregateMetricDouble")) {
// TODO: support grouping aggregations for aggregate metric double
return Stream.of(new AggDef(clazz, type, extra, false));
} else {
return Stream.of(new AggDef(clazz, type, extra, true), new AggDef(clazz, type, extra, false));
}
});
}

private static Stream<Tuple<String, String>> combinations(List<String> types, List<String> extraConfigs) {
return types.stream().flatMap(type -> extraConfigs.stream().map(config -> new Tuple<>(type, config)));
}

private static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, Tuple<String, String>> tuple) {
if (tuple.v1().isAssignableFrom(Rate.class)) {
// rate doesn't support non-grouping aggregations
return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true));
} else if (tuple.v2().v1().equals("AggregateMetricDouble")) {
// TODO: support grouping aggregations for aggregate metric double
return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false));
} else {
return Stream.of(
new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true),
new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false)
);
}
}

/** Retrieves the intermediate state description for a given class, type, and grouping. */
private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
try {
Expand Down Expand Up @@ -264,23 +254,13 @@ private static MethodHandle lookupRetry(Class<?> clazz, String type, String extr

/** Determines the engines agg class name, for the given class, type, and grouping. */
private static String determineAggName(Class<?> clazz, String type, String extra, boolean grouping) {
StringBuilder sb = new StringBuilder();
sb.append(determinePackageName(clazz)).append(".");
sb.append(clazz.getSimpleName());
sb.append(type);
sb.append(extra);
sb.append(grouping ? "Grouping" : "");
sb.append("AggregatorFunction");
return sb.toString();
}

/** Determines the engine agg package name, for the given class. */
private static String determinePackageName(Class<?> clazz) {
if (clazz.getSimpleName().startsWith("Spatial")) {
// All spatial aggs are in the spatial sub-package
return "org.elasticsearch.compute.aggregation.spatial";
}
return "org.elasticsearch.compute.aggregation";
return "org.elasticsearch.compute.aggregation."
+ (clazz.getSimpleName().startsWith("Spatial") ? "spatial." : "")
+ clazz.getSimpleName()
+ type
+ extra
+ (grouping ? "Grouping" : "")
+ "AggregatorFunction";
}

/** Maps intermediate state description to named expressions. */
Expand Down Expand Up @@ -317,19 +297,16 @@ private static String dataTypeToString(DataType type, Class<?> aggClass) {
if (aggClass == ToPartial.class || aggClass == FromPartial.class) {
return "";
}
if ((aggClass == Max.class || aggClass == Min.class) && type.equals(DataType.IP)) {
return "Ip";
}
if (aggClass == Top.class && type.equals(DataType.IP)) {
if ((aggClass == Max.class || aggClass == Min.class || aggClass == Top.class) && type.equals(DataType.IP)) {
return "Ip";
}

return switch (type) {
case DataType.BOOLEAN -> "Boolean";
case DataType.INTEGER, DataType.COUNTER_INTEGER -> "Int";
case DataType.LONG, DataType.DATETIME, DataType.COUNTER_LONG, DataType.DATE_NANOS -> "Long";
case DataType.DOUBLE, DataType.COUNTER_DOUBLE -> "Double";
case DataType.KEYWORD, DataType.IP, DataType.VERSION, DataType.TEXT, DataType.SEMANTIC_TEXT -> "BytesRef";
case BOOLEAN -> "Boolean";
case INTEGER, COUNTER_INTEGER -> "Int";
case LONG, DATETIME, COUNTER_LONG, DATE_NANOS -> "Long";
case DOUBLE, COUNTER_DOUBLE -> "Double";
case KEYWORD, IP, VERSION, TEXT, SEMANTIC_TEXT -> "BytesRef";
case GEO_POINT -> "GeoPoint";
case CARTESIAN_POINT -> "CartesianPoint";
case GEO_SHAPE -> "GeoShape";
Expand Down