5050import static org .elasticsearch .compute .gen .Types .BYTES_REF ;
5151import static org .elasticsearch .compute .gen .Types .DRIVER_CONTEXT ;
5252import static org .elasticsearch .compute .gen .Types .ELEMENT_TYPE ;
53+ import static org .elasticsearch .compute .gen .Types .GROUPING_AGGREGATOR_EVALUATOR_CONTEXT ;
5354import static org .elasticsearch .compute .gen .Types .GROUPING_AGGREGATOR_FUNCTION ;
5455import static org .elasticsearch .compute .gen .Types .GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT ;
5556import static org .elasticsearch .compute .gen .Types .INTERMEDIATE_STATE_DESC ;
@@ -82,7 +83,7 @@ public class GroupingAggregatorImplementer {
8283 private final List <Parameter > createParameters ;
8384 private final ClassName implementation ;
8485 private final List <AggregatorImplementer .IntermediateStateDesc > intermediateState ;
85- private final boolean includeTimestampVector ;
86+ private final boolean timseries ;
8687
8788 private final AggregationState aggState ;
8889 private final AggregationParameter aggParam ;
@@ -92,7 +93,7 @@ public GroupingAggregatorImplementer(
9293 TypeElement declarationType ,
9394 IntermediateState [] interStateAnno ,
9495 List <TypeMirror > warnExceptions ,
95- boolean includeTimestampVector
96+ boolean timseries
9697 ) {
9798 this .declarationType = declarationType ;
9899 this .warnExceptions = warnExceptions ;
@@ -109,7 +110,7 @@ public GroupingAggregatorImplementer(
109110 declarationType ,
110111 aggState .declaredType ().isPrimitive () ? requireType (aggState .declaredType ()) : requireVoidType (),
111112 requireName ("combine" ),
112- combineArgs (aggState , includeTimestampVector )
113+ combineArgs (aggState , timseries )
113114 );
114115 // TODO support multiple parameters
115116 this .aggParam = AggregationParameter .create (combine .getParameters ().getLast ().asType ());
@@ -128,7 +129,7 @@ public GroupingAggregatorImplementer(
128129 this .intermediateState = Arrays .stream (interStateAnno )
129130 .map (AggregatorImplementer .IntermediateStateDesc ::newIntermediateStateDesc )
130131 .toList ();
131- this .includeTimestampVector = includeTimestampVector ;
132+ this .timseries = timseries ;
132133 }
133134
134135 private static Methods .ArgumentMatcher combineArgs (AggregationState aggState , boolean includeTimestampVector ) {
@@ -318,7 +319,7 @@ private MethodSpec prepareProcessPage() {
318319
319320 builder .addStatement ("$T valuesBlock = page.getBlock(channels.get(0))" , blockType (aggParam .type ()));
320321 builder .addStatement ("$T valuesVector = valuesBlock.asVector()" , vectorType (aggParam .type ()));
321- if (includeTimestampVector ) {
322+ if (timseries ) {
322323 builder .addStatement ("$T timestampsBlock = page.getBlock(channels.get(1))" , LONG_BLOCK );
323324 builder .addStatement ("$T timestampsVector = timestampsBlock.asVector()" , LONG_VECTOR );
324325
@@ -327,7 +328,7 @@ private MethodSpec prepareProcessPage() {
327328 builder .endControlFlow ();
328329 }
329330 builder .beginControlFlow ("if (valuesVector == null)" );
330- String extra = includeTimestampVector ? ", timestampsVector" : "" ;
331+ String extra = timseries ? ", timestampsVector" : "" ;
331332 {
332333 builder .beginControlFlow ("if (valuesBlock.mayHaveNulls())" );
333334 builder .addStatement ("state.enableGroupIdTracking(seenGroupIds)" );
@@ -373,7 +374,7 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
373374 MethodSpec .Builder builder = MethodSpec .methodBuilder ("addRawInput" );
374375 builder .addModifiers (Modifier .PRIVATE );
375376 builder .addParameter (TypeName .INT , "positionOffset" ).addParameter (groupsType , "groups" ).addParameter (valuesType , "values" );
376- if (includeTimestampVector ) {
377+ if (timseries ) {
377378 builder .addParameter (LONG_VECTOR , "timestamps" );
378379 }
379380 if (aggParam .isBytesRef ()) {
@@ -456,7 +457,7 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
456457
457458 private void combineRawInputForBytesRef (MethodSpec .Builder builder , String blockVariable , String offsetVariable ) {
458459 // scratch is a BytesRef var that must have been defined before the iteration starts
459- if (includeTimestampVector ) {
460+ if (timseries ) {
460461 if (offsetVariable .contains (" + " )) {
461462 builder .addStatement ("var valuePosition = $L" , offsetVariable );
462463 offsetVariable = "valuePosition" ;
@@ -474,7 +475,7 @@ private void combineRawInputForBytesRef(MethodSpec.Builder builder, String block
474475 }
475476
476477 private void combineRawInputForPrimitive (MethodSpec .Builder builder , String blockVariable , String offsetVariable ) {
477- if (includeTimestampVector ) {
478+ if (timseries ) {
478479 if (offsetVariable .contains (" + " )) {
479480 builder .addStatement ("var valuePosition = $L" , offsetVariable );
480481 offsetVariable = "valuePosition" ;
@@ -498,7 +499,7 @@ private void combineRawInputForPrimitive(MethodSpec.Builder builder, String bloc
498499 }
499500
500501 private void combineRawInputForVoid (MethodSpec .Builder builder , String blockVariable , String offsetVariable ) {
501- if (includeTimestampVector ) {
502+ if (timseries ) {
502503 if (offsetVariable .contains (" + " )) {
503504 builder .addStatement ("var valuePosition = $L" , offsetVariable );
504505 offsetVariable = "valuePosition" ;
@@ -683,18 +684,30 @@ private MethodSpec evaluateFinal() {
683684 .addParameter (BLOCK_ARRAY , "blocks" )
684685 .addParameter (TypeName .INT , "offset" )
685686 .addParameter (INT_VECTOR , "selected" )
686- .addParameter (DRIVER_CONTEXT , "driverContext " );
687+ .addParameter (GROUPING_AGGREGATOR_EVALUATOR_CONTEXT , "evaluatorContext " );
687688
688689 if (aggState .declaredType ().isPrimitive ()) {
689- builder .addStatement ("blocks[offset] = state.toValuesBlock(selected, driverContext)" );
690+ builder .addStatement ("blocks[offset] = state.toValuesBlock(selected, evaluatorContext.driverContext())" );
691+ } else if (timseries ) {
692+ requireStaticMethod (
693+ declarationType ,
694+ requireType (BLOCK ),
695+ requireName ("evaluateFinal" ),
696+ requireArgs (
697+ requireType (aggState .declaredType ()),
698+ requireType (INT_VECTOR ),
699+ requireType (GROUPING_AGGREGATOR_EVALUATOR_CONTEXT )
700+ )
701+ );
702+ builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, selected, evaluatorContext)" , declarationType );
690703 } else {
691704 requireStaticMethod (
692705 declarationType ,
693706 requireType (BLOCK ),
694707 requireName ("evaluateFinal" ),
695708 requireArgs (requireType (aggState .declaredType ()), requireType (INT_VECTOR ), requireType (DRIVER_CONTEXT ))
696709 );
697- builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)" , declarationType );
710+ builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, selected, evaluatorContext. driverContext() )" , declarationType );
698711 }
699712 return builder .build ();
700713 }
0 commit comments