2222
2323import  java .util .Arrays ;
2424import  java .util .List ;
25- import  java .util .Objects ;
2625import  java .util .function .Consumer ;
26+ import  java .util .function .Function ;
2727import  java .util .stream .Collectors ;
28+ import  java .util .stream .Stream ;
2829
2930import  javax .lang .model .element .ExecutableElement ;
3031import  javax .lang .model .element .Modifier ;
3435
3536import  static  java .util .stream .Collectors .joining ;
3637import  static  org .elasticsearch .compute .gen .AggregatorImplementer .capitalize ;
37- import  static  org .elasticsearch .compute .gen .Methods .findMethod ;
38- import  static  org .elasticsearch .compute .gen .Methods .findRequiredMethod ;
38+ import  static  org .elasticsearch .compute .gen .Methods .requireAnyArgs ;
39+ import  static  org .elasticsearch .compute .gen .Methods .requireAnyType ;
40+ import  static  org .elasticsearch .compute .gen .Methods .requireArgs ;
41+ import  static  org .elasticsearch .compute .gen .Methods .requireName ;
42+ import  static  org .elasticsearch .compute .gen .Methods .requirePrimitiveOrImplements ;
43+ import  static  org .elasticsearch .compute .gen .Methods .requireStaticMethod ;
44+ import  static  org .elasticsearch .compute .gen .Methods .requireType ;
45+ import  static  org .elasticsearch .compute .gen .Methods .requireVoidType ;
3946import  static  org .elasticsearch .compute .gen .Methods .vectorAccessorName ;
4047import  static  org .elasticsearch .compute .gen .Types .BIG_ARRAYS ;
48+ import  static  org .elasticsearch .compute .gen .Types .BLOCK ;
4149import  static  org .elasticsearch .compute .gen .Types .BLOCK_ARRAY ;
4250import  static  org .elasticsearch .compute .gen .Types .BYTES_REF ;
4351import  static  org .elasticsearch .compute .gen .Types .DRIVER_CONTEXT ;
@@ -71,9 +79,6 @@ public class GroupingAggregatorImplementer {
7179    private  final  List <TypeMirror > warnExceptions ;
7280    private  final  ExecutableElement  init ;
7381    private  final  ExecutableElement  combine ;
74-     private  final  ExecutableElement  combineStates ;
75-     private  final  ExecutableElement  evaluateFinal ;
76-     private  final  ExecutableElement  combineIntermediate ;
7782    private  final  List <Parameter > createParameters ;
7883    private  final  ClassName  implementation ;
7984    private  final  List <AggregatorImplementer .IntermediateStateDesc > intermediateState ;
@@ -92,22 +97,23 @@ public GroupingAggregatorImplementer(
9297        this .declarationType  = declarationType ;
9398        this .warnExceptions  = warnExceptions ;
9499
95-         this .init  = findRequiredMethod (declarationType , new  String [] { "init" , "initGrouping"  }, e  -> true );
100+         this .init  = requireStaticMethod (
101+             declarationType ,
102+             // This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState 
103+             requirePrimitiveOrImplements (elements , Types .RELEASABLE ),
104+             requireName ("init" , "initGrouping" ),
105+             requireAnyArgs ("<arbitrary init arguments>" )
106+         );
96107        this .aggState  = AggregationState .create (elements , init .getReturnType (), warnExceptions .isEmpty () == false , true );
97108
98-         this .combine  = findRequiredMethod (declarationType , new  String [] { "combine"  }, e  -> {
99-             if  (e .getParameters ().size () == 0 ) {
100-                 return  false ;
101-             }
102-             TypeName  firstParamType  = TypeName .get (e .getParameters ().get (0 ).asType ());
103-             return  Objects .equals (firstParamType .toString (), aggState .declaredType ().toString ());
104-         });
105-         // TODO support multiple parameters 
109+         this .combine  = requireStaticMethod (
110+             declarationType ,
111+             aggState .declaredType ().isPrimitive () ? requireType (aggState .declaredType ()) : requireVoidType (),
112+             requireName ("combine" ),
113+             combineArgs (aggState , includeTimestampVector )
114+         );
106115        this .aggParam  = AggregationParameter .create (combine .getParameters ().get (combine .getParameters ().size () - 1 ).asType ());
107116
108-         this .combineStates  = findMethod (declarationType , "combineStates" );
109-         this .combineIntermediate  = findMethod (declarationType , "combineIntermediate" );
110-         this .evaluateFinal  = findMethod (declarationType , "evaluateFinal" );
111117        this .createParameters  = init .getParameters ()
112118            .stream ()
113119            .map (Parameter ::from )
@@ -125,6 +131,25 @@ public GroupingAggregatorImplementer(
125131        this .includeTimestampVector  = includeTimestampVector ;
126132    }
127133
134+     private  static  Methods .ArgumentMatcher  combineArgs (AggregationState  aggState , boolean  includeTimestampVector ) {
135+         if  (aggState .declaredType ().isPrimitive ()) {
136+             return  requireArgs (requireType (aggState .declaredType ()), requireAnyType ("<aggregation input column type>" ));
137+         } else  if  (includeTimestampVector ) {
138+             return  requireArgs (
139+                 requireType (aggState .declaredType ()),
140+                 requireType (TypeName .INT ),
141+                 requireType (TypeName .LONG ), // @timestamp 
142+                 requireAnyType ("<aggregation input column type>" )
143+             );
144+         } else  {
145+             return  requireArgs (
146+                 requireType (aggState .declaredType ()),
147+                 requireType (TypeName .INT ),
148+                 requireAnyType ("<aggregation input column type>" )
149+             );
150+         }
151+     }
152+ 
128153    public  ClassName  implementation () {
129154        return  implementation ;
130155    }
@@ -557,31 +582,33 @@ private MethodSpec addIntermediateInput() {
557582                });
558583                builder .endControlFlow ();
559584            } else  {
560-                 builder .addStatement ("$T.combineIntermediate(state, groupId, "  + intermediateStateRowAccess () + ")" , declarationType );
585+                 var  stateHasBlock  = intermediateState .stream ().anyMatch (AggregatorImplementer .IntermediateStateDesc ::block );
586+                 requireStaticMethod (
587+                     declarationType ,
588+                     requireVoidType (),
589+                     requireName ("combineIntermediate" ),
590+                     requireArgs (
591+                         Stream .of (
592+                             Stream .of (aggState .declaredType (), TypeName .INT ), // aggState and groupId 
593+                             intermediateState .stream ().map (AggregatorImplementer .IntermediateStateDesc ::combineArgType ),
594+                             Stream .of (TypeName .INT ).filter (p  -> stateHasBlock ) // position 
595+                         ).flatMap (Function .identity ()).map (Methods ::requireType ).toArray (Methods .TypeMatcher []::new )
596+                     )
597+                 );
598+ 
599+                 builder .addStatement (
600+                     "$T.combineIntermediate(state, groupId, " 
601+                         + intermediateState .stream ().map (desc  -> desc .access ("groupPosition + positionOffset" )).collect (joining (", " ))
602+                         + (stateHasBlock  ? ", groupPosition + positionOffset"  : "" )
603+                         + ")" ,
604+                     declarationType 
605+                 );
561606            }
562607            builder .endControlFlow ();
563608        }
564609        return  builder .build ();
565610    }
566611
567-     String  intermediateStateRowAccess () {
568-         String  rowAccess  = intermediateState .stream ().map (desc  -> desc .access ("groupPosition + positionOffset" )).collect (joining (", " ));
569-         if  (intermediateState .stream ().anyMatch (AggregatorImplementer .IntermediateStateDesc ::block )) {
570-             rowAccess  += ", groupPosition + positionOffset" ;
571-         }
572-         return  rowAccess ;
573-     }
574- 
575-     private  void  combineStates (MethodSpec .Builder  builder ) {
576-         if  (combineStates  == null ) {
577-             builder .beginControlFlow ("if (inState.hasValue(position))" );
578-             builder .addStatement ("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))" , declarationType );
579-             builder .endControlFlow ();
580-             return ;
581-         }
582-         builder .addStatement ("$T.combineStates(state, groupId, inState, position)" , declarationType );
583-     }
584- 
585612    private  MethodSpec  addIntermediateRowInput () {
586613        MethodSpec .Builder  builder  = MethodSpec .methodBuilder ("addIntermediateRowInput" );
587614        builder .addAnnotation (Override .class ).addModifiers (Modifier .PUBLIC );
@@ -593,7 +620,24 @@ private MethodSpec addIntermediateRowInput() {
593620        builder .endControlFlow ();
594621        builder .addStatement ("$T inState = (($T) input).state" , aggState .type (), implementation );
595622        builder .addStatement ("state.enableGroupIdTracking(new $T.Empty())" , SEEN_GROUP_IDS );
596-         combineStates (builder );
623+         if  (aggState .declaredType ().isPrimitive ()) {
624+             builder .beginControlFlow ("if (inState.hasValue(position))" );
625+             builder .addStatement ("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))" , declarationType );
626+             builder .endControlFlow ();
627+         } else  {
628+             requireStaticMethod (
629+                 declarationType ,
630+                 requireVoidType (),
631+                 requireName ("combineStates" ),
632+                 requireArgs (
633+                     requireType (aggState .declaredType ()),
634+                     requireType (TypeName .INT ),
635+                     requireType (aggState .declaredType ()),
636+                     requireType (TypeName .INT )
637+                 )
638+             );
639+             builder .addStatement ("$T.combineStates(state, groupId, inState, position)" , declarationType );
640+         }
597641        return  builder .build ();
598642    }
599643
@@ -617,9 +661,15 @@ private MethodSpec evaluateFinal() {
617661            .addParameter (INT_VECTOR , "selected" )
618662            .addParameter (DRIVER_CONTEXT , "driverContext" );
619663
620-         if  (evaluateFinal  ==  null ) {
664+         if  (aggState . declaredType (). isPrimitive () ) {
621665            builder .addStatement ("blocks[offset] = state.toValuesBlock(selected, driverContext)" );
622666        } else  {
667+             requireStaticMethod (
668+                 declarationType ,
669+                 requireType (BLOCK ),
670+                 requireName ("evaluateFinal" ),
671+                 requireArgs (requireType (aggState .declaredType ()), requireType (INT_VECTOR ), requireType (DRIVER_CONTEXT ))
672+             );
623673            builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)" , declarationType );
624674        }
625675        return  builder .build ();
0 commit comments