2222
2323import java .util .Arrays ;
2424import java .util .List ;
25- import java .util .Objects ;
2625import java .util .function .Consumer ;
26+ import java .util .function .Function ;
2727import java .util .stream .Collectors ;
28+ import java .util .stream .Stream ;
2829
2930import javax .lang .model .element .ExecutableElement ;
3031import javax .lang .model .element .Modifier ;
3435
3536import static java .util .stream .Collectors .joining ;
3637import static org .elasticsearch .compute .gen .AggregatorImplementer .capitalize ;
37- import static org .elasticsearch .compute .gen .Methods .findMethod ;
38- import static org .elasticsearch .compute .gen .Methods .findRequiredMethod ;
3938import static org .elasticsearch .compute .gen .Methods .requireAnyArgs ;
39+ import static org .elasticsearch .compute .gen .Methods .requireAnyType ;
4040import static org .elasticsearch .compute .gen .Methods .requireArgs ;
4141import static org .elasticsearch .compute .gen .Methods .requireName ;
4242import static org .elasticsearch .compute .gen .Methods .requirePrimitiveOrImplements ;
@@ -79,7 +79,6 @@ public class GroupingAggregatorImplementer {
7979 private final List <TypeMirror > warnExceptions ;
8080 private final ExecutableElement init ;
8181 private final ExecutableElement combine ;
82- private final ExecutableElement combineStates ;
8382 private final List <Parameter > createParameters ;
8483 private final ClassName implementation ;
8584 private final List <AggregatorImplementer .IntermediateStateDesc > intermediateState ;
@@ -100,25 +99,21 @@ public GroupingAggregatorImplementer(
10099
101100 this .init = requireStaticMethod (
102101 declarationType ,
103- // This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState
102+ // This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState
104103 requirePrimitiveOrImplements (elements , Types .RELEASABLE ),
105104 requireName ("init" , "initGrouping" ),
106105 requireAnyArgs ("<arbitrary init arguments>" )
107106 );
108107 this .aggState = AggregationState .create (elements , init .getReturnType (), warnExceptions .isEmpty () == false , true );
109108
110- // TODO optional timestamp
111- this .combine = findRequiredMethod (declarationType , new String [] { "combine" }, e -> {
112- if (e .getParameters ().size () == 0 ) {
113- return false ;
114- }
115- TypeName firstParamType = TypeName .get (e .getParameters ().get (0 ).asType ());
116- return Objects .equals (firstParamType .toString (), aggState .declaredType ().toString ());
117- });
118- // TODO support multiple parameters
109+ this .combine = requireStaticMethod (
110+ declarationType ,
111+ aggState .declaredType ().isPrimitive () ? requireType (aggState .declaredType ()) : requireVoidType (),
112+ requireName ("combine" ),
113+ combineArgs (aggState , includeTimestampVector )
114+ );
119115 this .aggParam = AggregationParameter .create (combine .getParameters ().get (combine .getParameters ().size () - 1 ).asType ());
120116
121- this .combineStates = findMethod (declarationType , "combineStates" );
122117 this .createParameters = init .getParameters ()
123118 .stream ()
124119 .map (Parameter ::from )
@@ -136,6 +131,25 @@ public GroupingAggregatorImplementer(
136131 this .includeTimestampVector = includeTimestampVector ;
137132 }
138133
134+ private static Methods .ArgumentMatcher combineArgs (AggregationState aggState , boolean includeTimestampVector ) {
135+ if (aggState .declaredType ().isPrimitive ()) {
136+ return requireArgs (requireType (aggState .declaredType ()), requireAnyType ("<aggregation input column type>" ));
137+ } else if (includeTimestampVector ) {
138+ return requireArgs (
139+ requireType (aggState .declaredType ()),
140+ requireType (TypeName .INT ),
141+ requireType (TypeName .LONG ), // @timestamp
142+ requireAnyType ("<aggregation input column type>" )
143+ );
144+ } else {
145+ return requireArgs (
146+ requireType (aggState .declaredType ()),
147+ requireType (TypeName .INT ),
148+ requireAnyType ("<aggregation input column type>" )
149+ );
150+ }
151+ }
152+
139153 public ClassName implementation () {
140154 return implementation ;
141155 }
@@ -568,22 +582,33 @@ private MethodSpec addIntermediateInput() {
568582 });
569583 builder .endControlFlow ();
570584 } else {
571- // TODO combineIntermediate with optional block parameter
572- builder .addStatement ("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess () + ")" , declarationType );
585+ var stateHasBlock = intermediateState .stream ().anyMatch (AggregatorImplementer .IntermediateStateDesc ::block );
586+ requireStaticMethod (
587+ declarationType ,
588+ requireVoidType (),
589+ requireName ("combineIntermediate" ),
590+ requireArgs (
591+ Stream .of (
592+ Stream .of (aggState .declaredType (), TypeName .INT ), // aggState and groupId
593+ intermediateState .stream ().map (AggregatorImplementer .IntermediateStateDesc ::combineArgType ),
594+ Stream .of (TypeName .INT ).filter (p -> stateHasBlock ) // position
595+ ).flatMap (Function .identity ()).map (Methods ::requireType ).toArray (Methods .TypeMatcher []::new )
596+ )
597+ );
598+
599+ builder .addStatement (
600+ "$T.combineIntermediate(state, groupId, "
601+ + intermediateState .stream ().map (desc -> desc .access ("groupPosition + positionOffset" )).collect (joining (", " ))
602+ + (stateHasBlock ? ", groupPosition + positionOffset" : "" )
603+ + ")" ,
604+ declarationType
605+ );
573606 }
574607 builder .endControlFlow ();
575608 }
576609 return builder .build ();
577610 }
578611
579- String intermediateStateRowAccess () {
580- String rowAccess = intermediateState .stream ().map (desc -> desc .access ("groupPosition + positionOffset" )).collect (joining (", " ));
581- if (intermediateState .stream ().anyMatch (AggregatorImplementer .IntermediateStateDesc ::block )) {
582- rowAccess += ", groupPosition + positionOffset" ;
583- }
584- return rowAccess ;
585- }
586-
587612 private MethodSpec addIntermediateRowInput () {
588613 MethodSpec .Builder builder = MethodSpec .methodBuilder ("addIntermediateRowInput" );
589614 builder .addAnnotation (Override .class ).addModifiers (Modifier .PUBLIC );
0 commit comments