3636import static org .elasticsearch .compute .gen .AggregatorImplementer .capitalize ;
3737import static org .elasticsearch .compute .gen .Methods .findMethod ;
3838import static org .elasticsearch .compute .gen .Methods .findRequiredMethod ;
39+ import static org .elasticsearch .compute .gen .Methods .requireAnyArgs ;
40+ import static org .elasticsearch .compute .gen .Methods .requireArgs ;
41+ import static org .elasticsearch .compute .gen .Methods .requireName ;
42+ import static org .elasticsearch .compute .gen .Methods .requirePrimitiveOrImplements ;
43+ import static org .elasticsearch .compute .gen .Methods .requireStaticMethod ;
44+ import static org .elasticsearch .compute .gen .Methods .requireType ;
45+ import static org .elasticsearch .compute .gen .Methods .requireVoidType ;
3946import static org .elasticsearch .compute .gen .Methods .vectorAccessorName ;
4047import static org .elasticsearch .compute .gen .Types .BIG_ARRAYS ;
48+ import static org .elasticsearch .compute .gen .Types .BLOCK ;
4149import static org .elasticsearch .compute .gen .Types .BLOCK_ARRAY ;
4250import static org .elasticsearch .compute .gen .Types .BYTES_REF ;
4351import static org .elasticsearch .compute .gen .Types .DRIVER_CONTEXT ;
@@ -72,8 +80,6 @@ public class GroupingAggregatorImplementer {
7280 private final ExecutableElement init ;
7381 private final ExecutableElement combine ;
7482 private final ExecutableElement combineStates ;
75- private final ExecutableElement evaluateFinal ;
76- private final ExecutableElement combineIntermediate ;
7783 private final List <Parameter > createParameters ;
7884 private final ClassName implementation ;
7985 private final List <AggregatorImplementer .IntermediateStateDesc > intermediateState ;
@@ -92,9 +98,16 @@ public GroupingAggregatorImplementer(
9298 this .declarationType = declarationType ;
9399 this .warnExceptions = warnExceptions ;
94100
95- this .init = findRequiredMethod (declarationType , new String [] { "init" , "initGrouping" }, e -> true );
101+ this .init = requireStaticMethod (
102+ declarationType ,
103+ // This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState
104+ requirePrimitiveOrImplements (elements , Types .RELEASABLE ),
105+ requireName ("init" , "initGrouping" ),
106+ requireAnyArgs ("<arbitrary init arguments>" )
107+ );
96108 this .aggState = AggregationState .create (elements , init .getReturnType (), warnExceptions .isEmpty () == false , true );
97109
110+ // TODO optional timestamp
98111 this .combine = findRequiredMethod (declarationType , new String [] { "combine" }, e -> {
99112 if (e .getParameters ().size () == 0 ) {
100113 return false ;
@@ -106,8 +119,6 @@ public GroupingAggregatorImplementer(
106119 this .aggParam = AggregationParameter .create (combine .getParameters ().get (combine .getParameters ().size () - 1 ).asType ());
107120
108121 this .combineStates = findMethod (declarationType , "combineStates" );
109- this .combineIntermediate = findMethod (declarationType , "combineIntermediate" );
110- this .evaluateFinal = findMethod (declarationType , "evaluateFinal" );
111122 this .createParameters = init .getParameters ()
112123 .stream ()
113124 .map (Parameter ::from )
@@ -557,6 +568,7 @@ private MethodSpec addIntermediateInput() {
557568 });
558569 builder .endControlFlow ();
559570 } else {
571+ // TODO combineIntermediate with optional block parameter
560572 builder .addStatement ("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess () + ")" , declarationType );
561573 }
562574 builder .endControlFlow ();
@@ -572,16 +584,6 @@ String intermediateStateRowAccess() {
572584 return rowAccess ;
573585 }
574586
575- private void combineStates (MethodSpec .Builder builder ) {
576- if (combineStates == null ) {
577- builder .beginControlFlow ("if (inState.hasValue(position))" );
578- builder .addStatement ("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))" , declarationType );
579- builder .endControlFlow ();
580- return ;
581- }
582- builder .addStatement ("$T.combineStates(state, groupId, inState, position)" , declarationType );
583- }
584-
585587 private MethodSpec addIntermediateRowInput () {
586588 MethodSpec .Builder builder = MethodSpec .methodBuilder ("addIntermediateRowInput" );
587589 builder .addAnnotation (Override .class ).addModifiers (Modifier .PUBLIC );
@@ -593,7 +595,24 @@ private MethodSpec addIntermediateRowInput() {
593595 builder .endControlFlow ();
594596 builder .addStatement ("$T inState = (($T) input).state" , aggState .type (), implementation );
595597 builder .addStatement ("state.enableGroupIdTracking(new $T.Empty())" , SEEN_GROUP_IDS );
596- combineStates (builder );
598+ if (aggState .declaredType ().isPrimitive ()) {
599+ builder .beginControlFlow ("if (inState.hasValue(position))" );
600+ builder .addStatement ("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))" , declarationType );
601+ builder .endControlFlow ();
602+ } else {
603+ requireStaticMethod (
604+ declarationType ,
605+ requireVoidType (),
606+ requireName ("combineStates" ),
607+ requireArgs (
608+ requireType (aggState .declaredType ()),
609+ requireType (TypeName .INT ),
610+ requireType (aggState .declaredType ()),
611+ requireType (TypeName .INT )
612+ )
613+ );
614+ builder .addStatement ("$T.combineStates(state, groupId, inState, position)" , declarationType );
615+ }
597616 return builder .build ();
598617 }
599618
@@ -617,9 +636,15 @@ private MethodSpec evaluateFinal() {
617636 .addParameter (INT_VECTOR , "selected" )
618637 .addParameter (DRIVER_CONTEXT , "driverContext" );
619638
620- if (evaluateFinal == null ) {
639+ if (aggState . declaredType (). isPrimitive () ) {
621640 builder .addStatement ("blocks[offset] = state.toValuesBlock(selected, driverContext)" );
622641 } else {
642+ requireStaticMethod (
643+ declarationType ,
644+ requireType (BLOCK ),
645+ requireName ("evaluateFinal" ),
646+ requireArgs (requireType (aggState .declaredType ()), requireType (INT_VECTOR ), requireType (DRIVER_CONTEXT ))
647+ );
623648 builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)" , declarationType );
624649 }
625650 return builder .build ();
0 commit comments