|
48 | 48 | import java.util.HashMap; |
49 | 49 | import java.util.List; |
50 | 50 | import java.util.Map; |
| 51 | +import java.util.Objects; |
51 | 52 | import java.util.stream.Collectors; |
52 | 53 | import java.util.stream.Stream; |
53 | 54 |
|
@@ -99,16 +100,11 @@ public AggDef withoutExtra() { |
99 | 100 |
|
100 | 101 | /** Map of AggDef types to intermediate named expressions. */ |
101 | 102 | private static final Map<AggDef, List<IntermediateStateDesc>> MAPPER = AGG_FUNCTIONS.stream() |
102 | | - .flatMap(AggregateMapper::typeAndNames) |
103 | | - .flatMap(AggregateMapper::groupingAndNonGrouping) |
| 103 | + .flatMap(AggregateMapper::aggDefs) |
104 | 104 | .collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState)); |
105 | 105 |
|
106 | 106 | /** Cache of aggregates to intermediate expressions. */ |
107 | | - private final HashMap<Expression, List<NamedExpression>> cache; |
108 | | - |
109 | | - AggregateMapper() { |
110 | | - cache = new HashMap<>(); |
111 | | - } |
| 107 | + private final HashMap<Expression, List<NamedExpression>> cache = new HashMap<>(); |
112 | 108 |
|
113 | 109 | public List<NamedExpression> mapNonGrouping(List<? extends NamedExpression> aggregates) { |
114 | 110 | return doMapping(aggregates, false); |
@@ -167,7 +163,7 @@ private static List<IntermediateStateDesc> getNonNull(AggDef aggDef) { |
167 | 163 | return l; |
168 | 164 | } |
169 | 165 |
|
170 | | - private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class<?> clazz) { |
| 166 | + private static Stream<AggDef> aggDefs(Class<?> clazz) { |
171 | 167 | List<String> types; |
172 | 168 | List<String> extraConfigs = List.of(""); |
173 | 169 | if (NumericAggregate.class.isAssignableFrom(clazz)) { |
@@ -197,32 +193,26 @@ private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class |
197 | 193 | assert false : "unknown aggregate type " + clazz; |
198 | 194 | throw new IllegalArgumentException("unknown aggregate type " + clazz); |
199 | 195 | } |
200 | | - return combine(clazz, types, extraConfigs); |
201 | | - } |
202 | | - |
203 | | - private static Stream<Tuple<Class<?>, Tuple<String, String>>> combine(Class<?> clazz, List<String> types, List<String> extraConfigs) { |
204 | | - return combinations(types, extraConfigs).map(combo -> new Tuple<>(clazz, combo)); |
| 196 | + return combinations(types, extraConfigs).flatMap(typeAndExtraConfig -> { |
| 197 | + var type = typeAndExtraConfig.v1(); |
| 198 | + var extra = typeAndExtraConfig.v2(); |
| 199 | + |
| 200 | + if (clazz.isAssignableFrom(Rate.class)) { |
| 201 | + // rate doesn't support non-grouping aggregations |
| 202 | + return Stream.of(new AggDef(clazz, type, extra, true)); |
| 203 | + } else if (Objects.equals(type, "AggregateMetricDouble")) { |
| 204 | + // TODO: support grouping aggregations for aggregate metric double |
| 205 | + return Stream.of(new AggDef(clazz, type, extra, false)); |
| 206 | + } else { |
| 207 | + return Stream.of(new AggDef(clazz, type, extra, true), new AggDef(clazz, type, extra, false)); |
| 208 | + } |
| 209 | + }); |
205 | 210 | } |
206 | 211 |
|
207 | 212 | private static Stream<Tuple<String, String>> combinations(List<String> types, List<String> extraConfigs) { |
208 | 213 | return types.stream().flatMap(type -> extraConfigs.stream().map(config -> new Tuple<>(type, config))); |
209 | 214 | } |
210 | 215 |
|
211 | | - private static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, Tuple<String, String>> tuple) { |
212 | | - if (tuple.v1().isAssignableFrom(Rate.class)) { |
213 | | - // rate doesn't support non-grouping aggregations |
214 | | - return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true)); |
215 | | - } else if (tuple.v2().v1().equals("AggregateMetricDouble")) { |
216 | | - // TODO: support grouping aggregations for aggregate metric double |
217 | | - return Stream.of(new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false)); |
218 | | - } else { |
219 | | - return Stream.of( |
220 | | - new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), true), |
221 | | - new AggDef(tuple.v1(), tuple.v2().v1(), tuple.v2().v2(), false) |
222 | | - ); |
223 | | - } |
224 | | - } |
225 | | - |
226 | 216 | /** Retrieves the intermediate state description for a given class, type, and grouping. */ |
227 | 217 | private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) { |
228 | 218 | try { |
@@ -264,23 +254,13 @@ private static MethodHandle lookupRetry(Class<?> clazz, String type, String extr |
264 | 254 |
|
265 | 255 | /** Determines the engines agg class name, for the given class, type, and grouping. */ |
266 | 256 | private static String determineAggName(Class<?> clazz, String type, String extra, boolean grouping) { |
267 | | - StringBuilder sb = new StringBuilder(); |
268 | | - sb.append(determinePackageName(clazz)).append("."); |
269 | | - sb.append(clazz.getSimpleName()); |
270 | | - sb.append(type); |
271 | | - sb.append(extra); |
272 | | - sb.append(grouping ? "Grouping" : ""); |
273 | | - sb.append("AggregatorFunction"); |
274 | | - return sb.toString(); |
275 | | - } |
276 | | - |
277 | | - /** Determines the engine agg package name, for the given class. */ |
278 | | - private static String determinePackageName(Class<?> clazz) { |
279 | | - if (clazz.getSimpleName().startsWith("Spatial")) { |
280 | | - // All spatial aggs are in the spatial sub-package |
281 | | - return "org.elasticsearch.compute.aggregation.spatial"; |
282 | | - } |
283 | | - return "org.elasticsearch.compute.aggregation"; |
| 257 | + return "org.elasticsearch.compute.aggregation." |
| 258 | + + (clazz.getSimpleName().startsWith("Spatial") ? "spatial." : "") |
| 259 | + + clazz.getSimpleName() |
| 260 | + + type |
| 261 | + + extra |
| 262 | + + (grouping ? "Grouping" : "") |
| 263 | + + "AggregatorFunction"; |
284 | 264 | } |
285 | 265 |
|
286 | 266 | /** Maps intermediate state description to named expressions. */ |
@@ -317,19 +297,16 @@ private static String dataTypeToString(DataType type, Class<?> aggClass) { |
317 | 297 | if (aggClass == ToPartial.class || aggClass == FromPartial.class) { |
318 | 298 | return ""; |
319 | 299 | } |
320 | | - if ((aggClass == Max.class || aggClass == Min.class) && type.equals(DataType.IP)) { |
321 | | - return "Ip"; |
322 | | - } |
323 | | - if (aggClass == Top.class && type.equals(DataType.IP)) { |
| 300 | + if ((aggClass == Max.class || aggClass == Min.class || aggClass == Top.class) && type.equals(DataType.IP)) { |
324 | 301 | return "Ip"; |
325 | 302 | } |
326 | 303 |
|
327 | 304 | return switch (type) { |
328 | | - case DataType.BOOLEAN -> "Boolean"; |
329 | | - case DataType.INTEGER, DataType.COUNTER_INTEGER -> "Int"; |
330 | | - case DataType.LONG, DataType.DATETIME, DataType.COUNTER_LONG, DataType.DATE_NANOS -> "Long"; |
331 | | - case DataType.DOUBLE, DataType.COUNTER_DOUBLE -> "Double"; |
332 | | - case DataType.KEYWORD, DataType.IP, DataType.VERSION, DataType.TEXT, DataType.SEMANTIC_TEXT -> "BytesRef"; |
| 305 | + case BOOLEAN -> "Boolean"; |
| 306 | + case INTEGER, COUNTER_INTEGER -> "Int"; |
| 307 | + case LONG, DATETIME, COUNTER_LONG, DATE_NANOS -> "Long"; |
| 308 | + case DOUBLE, COUNTER_DOUBLE -> "Double"; |
| 309 | + case KEYWORD, IP, VERSION, TEXT, SEMANTIC_TEXT -> "BytesRef"; |
333 | 310 | case GEO_POINT -> "GeoPoint"; |
334 | 311 | case CARTESIAN_POINT -> "CartesianPoint"; |
335 | 312 | case GEO_SHAPE -> "GeoShape"; |
|
0 commit comments