1010import org .elasticsearch .common .Strings ;
1111import org .elasticsearch .compute .aggregation .IntermediateStateDesc ;
1212import org .elasticsearch .compute .data .ElementType ;
13- import org .elasticsearch .core .Tuple ;
1413import org .elasticsearch .xpack .esql .EsqlIllegalArgumentException ;
1514import org .elasticsearch .xpack .esql .core .expression .Alias ;
1615import org .elasticsearch .xpack .esql .core .expression .Attribute ;
3029import org .elasticsearch .xpack .esql .expression .function .aggregate .Max ;
3130import org .elasticsearch .xpack .esql .expression .function .aggregate .MedianAbsoluteDeviation ;
3231import org .elasticsearch .xpack .esql .expression .function .aggregate .Min ;
33- import org .elasticsearch .xpack .esql .expression .function .aggregate .NumericAggregate ;
3432import org .elasticsearch .xpack .esql .expression .function .aggregate .Percentile ;
3533import org .elasticsearch .xpack .esql .expression .function .aggregate .Rate ;
36- import org .elasticsearch .xpack .esql .expression .function .aggregate .SpatialAggregateFunction ;
3734import org .elasticsearch .xpack .esql .expression .function .aggregate .SpatialCentroid ;
3835import org .elasticsearch .xpack .esql .expression .function .aggregate .SpatialExtent ;
3936import org .elasticsearch .xpack .esql .expression .function .aggregate .StdDev ;
4239import org .elasticsearch .xpack .esql .expression .function .aggregate .Top ;
4340import org .elasticsearch .xpack .esql .expression .function .aggregate .Values ;
4441
45- import java .lang .invoke .MethodHandle ;
46- import java .lang .invoke .MethodHandles ;
47- import java .lang .invoke .MethodType ;
4842import java .util .HashMap ;
4943import java .util .List ;
50- import java .util .Map ;
51- import java .util .Objects ;
52- import java .util .stream .Collectors ;
5344import java .util .stream .Stream ;
5445
5546/**
6758 */
6859final class AggregateMapper {
6960
70- private static final List <String > NUMERIC = List .of ("Int" , "Long" , "Double" );
71- private static final List <String > SPATIAL_EXTRA_CONFIGS = List .of ("SourceValues" , "DocValues" );
72-
7361 /** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). */
7462 private static final List <? extends Class <? extends Function >> AGG_FUNCTIONS = List .of (
7563 Count .class ,
@@ -91,18 +79,6 @@ final class AggregateMapper {
9179 ToPartial .class
9280 );
9381
94- /** Record of agg Class, type, and grouping (or non-grouping). */
95- private record AggDef (Class <?> aggClazz , String type , String extra , boolean grouping ) {
96- public AggDef withoutExtra () {
97- return new AggDef (aggClazz , type , "" , grouping );
98- }
99- }
100-
101- /** Map of AggDef types to intermediate named expressions. */
102- private static final Map <AggDef , List <IntermediateStateDesc >> MAPPER = AGG_FUNCTIONS .stream ()
103- .flatMap (AggregateMapper ::aggDefs )
104- .collect (Collectors .toUnmodifiableMap (aggDef -> aggDef , AggregateMapper ::lookupIntermediateState ));
105-
10682 /** Cache of aggregates to intermediate expressions. */
10783 private final HashMap <Expression , List <NamedExpression >> cache = new HashMap <>();
10884
@@ -144,127 +120,21 @@ private static List<NamedExpression> computeEntryForAgg(String aggAlias, Express
144120 }
145121
146122 private static List <NamedExpression > entryForAgg (String aggAlias , AggregateFunction aggregateFunction , boolean grouping ) {
147- var aggDef = new AggDef (
148- aggregateFunction .getClass (),
149- dataTypeToString (aggregateFunction .field ().dataType (), aggregateFunction .getClass ()),
150- aggregateFunction instanceof SpatialAggregateFunction ? "SourceValues" : "" ,
151- grouping
152- );
153- var is = getNonNull (aggDef );
154- return isToNE (is , aggAlias ).toList ();
155- }
156-
157- /** Gets the agg from the mapper - wrapper around map::get for more informative failure.*/
158- private static List <IntermediateStateDesc > getNonNull (AggDef aggDef ) {
159- var l = MAPPER .getOrDefault (aggDef , MAPPER .get (aggDef .withoutExtra ()));
160- if (l == null ) {
161- throw new EsqlIllegalArgumentException ("Cannot find intermediate state for: " + aggDef );
162- }
163- return l ;
164- }
165-
166- private static Stream <AggDef > aggDefs (Class <?> clazz ) {
167- List <String > types ;
168- List <String > extraConfigs = List .of ("" );
169- if (NumericAggregate .class .isAssignableFrom (clazz )) {
170- types = NUMERIC ;
171- } else if (Max .class .isAssignableFrom (clazz ) || Min .class .isAssignableFrom (clazz )) {
172- types = List .of ("Boolean" , "Int" , "Long" , "Double" , "Ip" , "BytesRef" );
173- } else if (clazz == Count .class ) {
174- types = List .of ("" ); // no extra type distinction
175- } else if (clazz == SpatialCentroid .class ) {
176- types = List .of ("GeoPoint" , "CartesianPoint" );
177- extraConfigs = SPATIAL_EXTRA_CONFIGS ;
178- } else if (clazz == SpatialExtent .class ) {
179- types = List .of ("GeoPoint" , "CartesianPoint" , "GeoShape" , "CartesianShape" );
180- extraConfigs = SPATIAL_EXTRA_CONFIGS ;
181- } else if (Values .class .isAssignableFrom (clazz )) {
182- // TODO can't we figure this out from the function itself?
183- types = List .of ("Int" , "Long" , "Double" , "Boolean" , "BytesRef" );
184- } else if (Top .class .isAssignableFrom (clazz )) {
185- types = List .of ("Boolean" , "Int" , "Long" , "Double" , "Ip" , "BytesRef" );
186- } else if (Rate .class .isAssignableFrom (clazz ) || StdDev .class .isAssignableFrom (clazz )) {
187- types = List .of ("Int" , "Long" , "Double" );
188- } else if (FromPartial .class .isAssignableFrom (clazz ) || ToPartial .class .isAssignableFrom (clazz )) {
189- types = List .of ("" ); // no type
190- } else if (CountDistinct .class .isAssignableFrom (clazz )) {
191- types = Stream .concat (NUMERIC .stream (), Stream .of ("Boolean" , "BytesRef" )).toList ();
123+ List <IntermediateStateDesc > intermediateState ;
124+ if (aggregateFunction instanceof ToAggregator toAggregator ) {
125+ var supplier = toAggregator .supplier ();
126+ intermediateState = grouping ? supplier .groupingIntermediateStateDesc () : supplier .nonGroupingIntermediateStateDesc ();
192127 } else {
193- assert false : "unknown aggregate type " + clazz ;
194- throw new IllegalArgumentException ("unknown aggregate type " + clazz );
195- }
196- return combinations (types , extraConfigs ).flatMap (typeAndExtraConfig -> {
197- var type = typeAndExtraConfig .v1 ();
198- var extra = typeAndExtraConfig .v2 ();
199-
200- if (clazz .isAssignableFrom (Rate .class )) {
201- // rate doesn't support non-grouping aggregations
202- return Stream .of (new AggDef (clazz , type , extra , true ));
203- } else if (Objects .equals (type , "AggregateMetricDouble" )) {
204- // TODO: support grouping aggregations for aggregate metric double
205- return Stream .of (new AggDef (clazz , type , extra , false ));
206- } else {
207- return Stream .of (new AggDef (clazz , type , extra , true ), new AggDef (clazz , type , extra , false ));
208- }
209- });
210- }
211-
212- private static Stream <Tuple <String , String >> combinations (List <String > types , List <String > extraConfigs ) {
213- return types .stream ().flatMap (type -> extraConfigs .stream ().map (config -> new Tuple <>(type , config )));
214- }
215-
216- /** Retrieves the intermediate state description for a given class, type, and grouping. */
217- private static List <IntermediateStateDesc > lookupIntermediateState (AggDef aggDef ) {
218- try {
219- return (List <IntermediateStateDesc >) lookup (aggDef .aggClazz (), aggDef .type (), aggDef .extra (), aggDef .grouping ()).invokeExact ();
220- } catch (Throwable t ) {
221- // invokeExact forces us to handle any Throwable thrown by lookup.
222- throw new EsqlIllegalArgumentException (t );
223- }
224- }
225-
226- /** Looks up the intermediate state method for a given class, type, and grouping. */
227- private static MethodHandle lookup (Class <?> clazz , String type , String extra , boolean grouping ) {
228- try {
229- return lookupRetry (clazz , type , extra , grouping );
230- } catch (IllegalAccessException | NoSuchMethodException | ClassNotFoundException e ) {
231- throw new EsqlIllegalArgumentException (e );
128+ throw new EsqlIllegalArgumentException ("Aggregate has no defined intermediate state: " + aggregateFunction );
232129 }
233- }
234-
235- private static MethodHandle lookupRetry (Class <?> clazz , String type , String extra , boolean grouping ) throws IllegalAccessException ,
236- NoSuchMethodException , ClassNotFoundException {
237- try {
238- return MethodHandles .lookup ()
239- .findStatic (
240- Class .forName (determineAggName (clazz , type , extra , grouping )),
241- "intermediateStateDesc" ,
242- MethodType .methodType (List .class )
243- );
244- } catch (NoSuchMethodException ignore ) {
245- // Retry without the extra information.
246- return MethodHandles .lookup ()
247- .findStatic (
248- Class .forName (determineAggName (clazz , type , "" , grouping )),
249- "intermediateStateDesc" ,
250- MethodType .methodType (List .class )
251- );
252- }
253- }
254-
255- /** Determines the engines agg class name, for the given class, type, and grouping. */
256- private static String determineAggName (Class <?> clazz , String type , String extra , boolean grouping ) {
257- return "org.elasticsearch.compute.aggregation."
258- + (clazz .getSimpleName ().startsWith ("Spatial" ) ? "spatial." : "" )
259- + clazz .getSimpleName ()
260- + type
261- + extra
262- + (grouping ? "Grouping" : "" )
263- + "AggregatorFunction" ;
130+ return intermediateStateToNamedExpressions (intermediateState , aggAlias ).toList ();
264131 }
265132
266133 /** Maps intermediate state description to named expressions. */
267- private static Stream <NamedExpression > isToNE (List <IntermediateStateDesc > intermediateStateDescs , String aggAlias ) {
134+ private static Stream <NamedExpression > intermediateStateToNamedExpressions (
135+ List <IntermediateStateDesc > intermediateStateDescs ,
136+ String aggAlias
137+ ) {
268138 return intermediateStateDescs .stream ().map (is -> {
269139 final DataType dataType ;
270140 if (Strings .isEmpty (is .dataType ())) {
@@ -288,34 +158,4 @@ private static DataType toDataType(ElementType elementType) {
288158 case FLOAT , NULL , DOC , COMPOSITE , UNKNOWN -> throw new EsqlIllegalArgumentException ("unsupported agg type: " + elementType );
289159 };
290160 }
291-
292- /** Returns the string representation for the data type. This reflects the engine's aggs naming structure. */
293- private static String dataTypeToString (DataType type , Class <?> aggClass ) {
294- if (aggClass == Count .class ) {
295- return "" ; // no type distinction
296- }
297- if (aggClass == ToPartial .class || aggClass == FromPartial .class ) {
298- return "" ;
299- }
300- if ((aggClass == Max .class || aggClass == Min .class || aggClass == Top .class ) && type .equals (DataType .IP )) {
301- return "Ip" ;
302- }
303-
304- return switch (type ) {
305- case BOOLEAN -> "Boolean" ;
306- case INTEGER , COUNTER_INTEGER -> "Int" ;
307- case LONG , DATETIME , COUNTER_LONG , DATE_NANOS -> "Long" ;
308- case DOUBLE , COUNTER_DOUBLE -> "Double" ;
309- case KEYWORD , IP , VERSION , TEXT , SEMANTIC_TEXT -> "BytesRef" ;
310- case GEO_POINT -> "GeoPoint" ;
311- case CARTESIAN_POINT -> "CartesianPoint" ;
312- case GEO_SHAPE -> "GeoShape" ;
313- case CARTESIAN_SHAPE -> "CartesianShape" ;
314- case AGGREGATE_METRIC_DOUBLE -> "AggregateMetricDouble" ;
315- case UNSUPPORTED , NULL , UNSIGNED_LONG , SHORT , BYTE , FLOAT , HALF_FLOAT , SCALED_FLOAT , OBJECT , SOURCE , DATE_PERIOD , TIME_DURATION ,
316- DOC_DATA_TYPE , TSID_DATA_TYPE , PARTIAL_AGG -> throw new EsqlIllegalArgumentException (
317- "illegal agg type: " + type .typeName ()
318- );
319- };
320- }
321161}
0 commit comments