3434import javax .lang .model .util .Elements ;
3535
3636import static java .util .stream .Collectors .joining ;
37- import static org .elasticsearch .compute .gen .Methods .findMethod ;
3837import static org .elasticsearch .compute .gen .Methods .findRequiredMethod ;
38+ import static org .elasticsearch .compute .gen .Methods .requireMethod ;
3939import static org .elasticsearch .compute .gen .Methods .vectorAccessorName ;
4040import static org .elasticsearch .compute .gen .Types .AGGREGATOR_FUNCTION ;
4141import static org .elasticsearch .compute .gen .Types .BIG_ARRAYS ;
@@ -78,8 +78,6 @@ public class AggregatorImplementer {
7878 private final List <TypeMirror > warnExceptions ;
7979 private final ExecutableElement init ;
8080 private final ExecutableElement combine ;
81- private final ExecutableElement combineIntermediate ;
82- private final ExecutableElement evaluateFinal ;
8381 private final ClassName implementation ;
8482 private final TypeName stateType ;
8583 private final boolean stateTypeHasSeen ;
@@ -114,8 +112,6 @@ public AggregatorImplementer(
114112 TypeName firstParamType = TypeName .get (e .getParameters ().get (0 ).asType ());
115113 return firstParamType .isPrimitive () || firstParamType .toString ().equals (stateType .toString ());
116114 });
117- this .combineIntermediate = findMethod (declarationType , "combineIntermediate" );
118- this .evaluateFinal = findMethod (declarationType , "evaluateFinal" );
119115 this .createParameters = init .getParameters ()
120116 .stream ()
121117 .map (Parameter ::from )
@@ -198,9 +194,7 @@ static ClassName valueVectorType(ExecutableElement init, ExecutableElement combi
198194 }
199195
200196 public static String firstUpper (String s ) {
201- String head = s .toString ().substring (0 , 1 ).toUpperCase (Locale .ROOT );
202- String tail = s .toString ().substring (1 );
203- return head + tail ;
197+ return Character .toUpperCase (s .charAt (0 )) + s .substring (1 );
204198 }
205199
206200 public JavaFile sourceFile () {
@@ -526,12 +520,7 @@ private MethodSpec addIntermediateInput() {
526520 interState .assignToVariable (builder , i );
527521 builder .addStatement ("assert $L.getPositionCount() == 1" , interState .name ());
528522 }
529- if (combineIntermediate != null ) {
530- if (intermediateState .stream ().map (IntermediateStateDesc ::elementType ).anyMatch (n -> n .equals ("BYTES_REF" ))) {
531- builder .addStatement ("$T scratch = new $T()" , BYTES_REF , BYTES_REF );
532- }
533- builder .addStatement ("$T.combineIntermediate(state, " + intermediateStateRowAccess () + ")" , declarationType );
534- } else if (hasPrimitiveState ()) {
523+ if (hasPrimitiveState ()) {
535524 if (warnExceptions .isEmpty ()) {
536525 assert intermediateState .size () == 2 ;
537526 assert intermediateState .get (1 ).name ().equals ("seen" );
@@ -547,7 +536,6 @@ private MethodSpec addIntermediateInput() {
547536 }
548537 builder .nextControlFlow ("else if (seen.getBoolean(0))" );
549538 }
550-
551539 warningsBlock (builder , () -> {
552540 var state = intermediateState .get (0 );
553541 var s = "state.$L($T.combine(state.$L(), " + state .name () + "." + vectorAccessorName (state .elementType ()) + "(0)))" ;
@@ -556,32 +544,39 @@ private MethodSpec addIntermediateInput() {
556544 });
557545 builder .endControlFlow ();
558546 } else {
559- throw new IllegalArgumentException ("Don't know how to combine intermediate input. Define combineIntermediate" );
547+ requireMethod (
548+ declarationType ,
549+ "combineIntermediate" ,
550+ "void" ,
551+ Stream .concat (Stream .of (stateType .toString ()), intermediateState .stream ().map (intermediateStateDesc -> {
552+ var type = Types .fromString (intermediateStateDesc .elementType ());
553+ return intermediateStateDesc .block ? blockType (type ).toString () : type .toString ();
554+ })).toArray (String []::new )
555+ );
556+ if (intermediateState .stream ().map (IntermediateStateDesc ::elementType ).anyMatch (n -> n .equals ("BYTES_REF" ))) {
557+ builder .addStatement ("$T scratch = new $T()" , BYTES_REF , BYTES_REF );
558+ }
559+ builder .addStatement (
560+ "$T.combineIntermediate(state, " + intermediateState .stream ().map (desc -> desc .access ("0" )).collect (joining (", " )) + ")" ,
561+ declarationType
562+ );
560563 }
561564 return builder .build ();
562565 }
563566
564- String intermediateStateRowAccess () {
565- return intermediateState .stream ().map (desc -> desc .access ("0" )).collect (joining (", " ));
566- }
567-
568567 private String primitiveStateMethod () {
569- switch (stateType .toString ()) {
570- case "org.elasticsearch.compute.aggregation.BooleanState" , "org.elasticsearch.compute.aggregation.BooleanFallibleState" :
571- return "booleanValue" ;
572- case "org.elasticsearch.compute.aggregation.IntState" , "org.elasticsearch.compute.aggregation.IntFallibleState" :
573- return "intValue" ;
574- case "org.elasticsearch.compute.aggregation.LongState" , "org.elasticsearch.compute.aggregation.LongFallibleState" :
575- return "longValue" ;
576- case "org.elasticsearch.compute.aggregation.DoubleState" , "org.elasticsearch.compute.aggregation.DoubleFallibleState" :
577- return "doubleValue" ;
578- case "org.elasticsearch.compute.aggregation.FloatState" , "org.elasticsearch.compute.aggregation.FloatFallibleState" :
579- return "floatValue" ;
580- default :
581- throw new IllegalArgumentException (
582- "don't know how to fetch primitive values from " + stateType + ". define combineIntermediate."
583- );
584- }
568+ return switch (stateType .toString ()) {
569+ case "org.elasticsearch.compute.aggregation.BooleanState" , "org.elasticsearch.compute.aggregation.BooleanFallibleState" ->
570+ "booleanValue" ;
571+ case "org.elasticsearch.compute.aggregation.IntState" , "org.elasticsearch.compute.aggregation.IntFallibleState" -> "intValue" ;
572+ case "org.elasticsearch.compute.aggregation.LongState" , "org.elasticsearch.compute.aggregation.LongFallibleState" ->
573+ "longValue" ;
574+ case "org.elasticsearch.compute.aggregation.DoubleState" , "org.elasticsearch.compute.aggregation.DoubleFallibleState" ->
575+ "doubleValue" ;
576+ case "org.elasticsearch.compute.aggregation.FloatState" , "org.elasticsearch.compute.aggregation.FloatFallibleState" ->
577+ "floatValue" ;
578+ default -> throw new IllegalArgumentException ("don't know how to fetch primitive values from " + stateType + "." );
579+ };
585580 }
586581
587582 private MethodSpec evaluateIntermediate () {
@@ -611,9 +606,15 @@ private MethodSpec evaluateFinal() {
611606 builder .addStatement ("return" );
612607 builder .endControlFlow ();
613608 }
614- if (evaluateFinal == null ) {
609+ if (hasPrimitiveState () ) {
615610 primitiveStateToResult (builder );
616611 } else {
612+ requireMethod (
613+ declarationType ,
614+ "evaluateFinal" ,
615+ "org.elasticsearch.compute.data.Block" ,
616+ new String [] { stateType .toString (), "org.elasticsearch.compute.operator.DriverContext" }
617+ );
617618 builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, driverContext)" , declarationType );
618619 }
619620 return builder .build ();
0 commit comments