1717
1818import  org .elasticsearch .compute .ann .Aggregator ;
1919import  org .elasticsearch .compute .ann .IntermediateState ;
20+ import  org .elasticsearch .compute .gen .Methods .TypeMatcher ;
2021
2122import  java .util .Arrays ;
2223import  java .util .List ;
3334import  javax .lang .model .util .Elements ;
3435
3536import  static  java .util .stream .Collectors .joining ;
36- import  static  org .elasticsearch .compute .gen .Methods .findMethod ;
37- import  static  org .elasticsearch .compute .gen .Methods .findRequiredMethod ;
37+ import  static  org .elasticsearch .compute .gen .Methods .requireAnyArgs ;
38+ import  static  org .elasticsearch .compute .gen .Methods .requireAnyType ;
39+ import  static  org .elasticsearch .compute .gen .Methods .requireArgs ;
40+ import  static  org .elasticsearch .compute .gen .Methods .requireName ;
41+ import  static  org .elasticsearch .compute .gen .Methods .requirePrimitiveOrImplements ;
42+ import  static  org .elasticsearch .compute .gen .Methods .requireStaticMethod ;
43+ import  static  org .elasticsearch .compute .gen .Methods .requireType ;
44+ import  static  org .elasticsearch .compute .gen .Methods .requireVoidType ;
3845import  static  org .elasticsearch .compute .gen .Methods .vectorAccessorName ;
3946import  static  org .elasticsearch .compute .gen .Types .AGGREGATOR_FUNCTION ;
4047import  static  org .elasticsearch .compute .gen .Types .BIG_ARRAYS ;
@@ -66,8 +73,6 @@ public class AggregatorImplementer {
6673    private  final  List <TypeMirror > warnExceptions ;
6774    private  final  ExecutableElement  init ;
6875    private  final  ExecutableElement  combine ;
69-     private  final  ExecutableElement  combineIntermediate ;
70-     private  final  ExecutableElement  evaluateFinal ;
7176    private  final  ClassName  implementation ;
7277    private  final  List <IntermediateStateDesc > intermediateState ;
7378    private  final  List <Parameter > createParameters ;
@@ -84,21 +89,24 @@ public AggregatorImplementer(
8489        this .declarationType  = declarationType ;
8590        this .warnExceptions  = warnExceptions ;
8691
87-         this .init  = findRequiredMethod (declarationType , new  String [] { "init" , "initSingle"  }, e  -> true );
92+         this .init  = requireStaticMethod (
93+             declarationType ,
94+             requirePrimitiveOrImplements (elements , Types .RELEASABLE ),// This should be more restrictive 
95+                                                                      // org.elasticsearch.compute.aggregation.AggregatorState 
96+             requireName ("init" , "initSingle" ),
97+             requireAnyArgs ("<arbitrary init arguments>" )
98+         );
8899        this .aggState  = AggregationState .create (elements , init .getReturnType (), warnExceptions .isEmpty () == false , false );
89100
90-         this .combine  = findRequiredMethod (declarationType , new  String [] { "combine"  }, e  -> {
91-             if  (e .getParameters ().size () == 0 ) {
92-                 return  false ;
93-             }
94-             TypeName  firstParamType  = TypeName .get (e .getParameters ().get (0 ).asType ());
95-             return  Objects .equals (firstParamType .toString (), aggState .declaredType ().toString ());
96-         });
101+         this .combine  = requireStaticMethod (
102+             declarationType ,
103+             aggState .declaredType ().isPrimitive () ? requireType (aggState .declaredType ()) : requireVoidType (),
104+             requireName ("combine" ),
105+             requireArgs (requireType (aggState .declaredType ()), requireAnyType ("<aggregation input column type>" ))
106+         );
97107        // TODO support multiple parameters 
98108        this .aggParam  = AggregationParameter .create (combine .getParameters ().get (1 ).asType ());
99109
100-         this .combineIntermediate  = findMethod (declarationType , "combineIntermediate" );
101-         this .evaluateFinal  = findMethod (declarationType , "evaluateFinal" );
102110        this .createParameters  = init .getParameters ()
103111            .stream ()
104112            .map (Parameter ::from )
@@ -447,12 +455,7 @@ private MethodSpec addIntermediateInput() {
447455            interState .assignToVariable (builder , i );
448456            builder .addStatement ("assert $L.getPositionCount() == 1" , interState .name ());
449457        }
450-         if  (combineIntermediate  != null ) {
451-             if  (intermediateState .stream ().map (IntermediateStateDesc ::elementType ).anyMatch (n  -> n .equals ("BYTES_REF" ))) {
452-                 builder .addStatement ("$T scratch = new $T()" , BYTES_REF , BYTES_REF );
453-             }
454-             builder .addStatement ("$T.combineIntermediate(state, "  + intermediateStateRowAccess () + ")" , declarationType );
455-         } else  if  (aggState .declaredType ().isPrimitive ()) {
458+         if  (aggState .declaredType ().isPrimitive ()) {
456459            if  (warnExceptions .isEmpty ()) {
457460                assert  intermediateState .size () == 2 ;
458461                assert  intermediateState .get (1 ).name ().equals ("seen" );
@@ -485,7 +488,20 @@ private MethodSpec addIntermediateInput() {
485488            });
486489            builder .endControlFlow ();
487490        } else  {
488-             throw  new  IllegalArgumentException ("Don't know how to combine intermediate input. Define combineIntermediate" );
491+             requireStaticMethod (
492+                 declarationType ,
493+                 requireVoidType (),
494+                 requireName ("combineIntermediate" ),
495+                 requireArgs (
496+                     Stream .concat (Stream .of (aggState .declaredType ()), intermediateState .stream ().map (IntermediateStateDesc ::combineArgType ))
497+                         .map (Methods ::requireType )
498+                         .toArray (TypeMatcher []::new )
499+                 )
500+             );
501+             if  (intermediateState .stream ().map (IntermediateStateDesc ::elementType ).anyMatch (n  -> n .equals ("BYTES_REF" ))) {
502+                 builder .addStatement ("$T scratch = new $T()" , BYTES_REF , BYTES_REF );
503+             }
504+             builder .addStatement ("$T.combineIntermediate(state, "  + intermediateStateRowAccess () + ")" , declarationType );
489505        }
490506        return  builder .build ();
491507    }
@@ -524,7 +540,7 @@ private MethodSpec evaluateFinal() {
524540            builder .addStatement ("return" );
525541            builder .endControlFlow ();
526542        }
527-         if  (evaluateFinal  ==  null ) {
543+         if  (aggState . declaredType (). isPrimitive () ) {
528544            builder .addStatement (switch  (aggState .declaredType ().toString ()) {
529545                case  "boolean"  -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)" ;
530546                case  "int"  -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)" ;
@@ -534,6 +550,12 @@ private MethodSpec evaluateFinal() {
534550                default  -> throw  new  IllegalArgumentException ("Unexpected primitive type: ["  + aggState .declaredType () + "]" );
535551            });
536552        } else  {
553+             requireStaticMethod (
554+                 declarationType ,
555+                 requireType (BLOCK ),
556+                 requireName ("evaluateFinal" ),
557+                 requireArgs (requireType (aggState .declaredType ()), requireType (DRIVER_CONTEXT ))
558+             );
537559            builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, driverContext)" , declarationType );
538560        }
539561        return  builder .build ();
@@ -593,6 +615,11 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
593615                builder .addStatement ("$T $L = (($T) $L).asVector()" , vectorType (elementType ), name , blockType , name  + "Uncast" );
594616            }
595617        }
618+ 
619+         public  TypeName  combineArgType () {
620+             var  type  = Types .fromString (elementType );
621+             return  block  ? blockType (type ) : type ;
622+         }
596623    }
597624
598625    /** 
0 commit comments