22
22
23
23
import java .util .Arrays ;
24
24
import java .util .List ;
25
- import java .util .Objects ;
26
25
import java .util .function .Consumer ;
26
+ import java .util .function .Function ;
27
27
import java .util .stream .Collectors ;
28
+ import java .util .stream .Stream ;
28
29
29
30
import javax .lang .model .element .ExecutableElement ;
30
31
import javax .lang .model .element .Modifier ;
34
35
35
36
import static java .util .stream .Collectors .joining ;
36
37
import static org .elasticsearch .compute .gen .AggregatorImplementer .capitalize ;
37
- import static org .elasticsearch .compute .gen .Methods .findMethod ;
38
- import static org .elasticsearch .compute .gen .Methods .findRequiredMethod ;
38
+ import static org .elasticsearch .compute .gen .Methods .requireAnyArgs ;
39
+ import static org .elasticsearch .compute .gen .Methods .requireAnyType ;
40
+ import static org .elasticsearch .compute .gen .Methods .requireArgs ;
41
+ import static org .elasticsearch .compute .gen .Methods .requireName ;
42
+ import static org .elasticsearch .compute .gen .Methods .requirePrimitiveOrImplements ;
43
+ import static org .elasticsearch .compute .gen .Methods .requireStaticMethod ;
44
+ import static org .elasticsearch .compute .gen .Methods .requireType ;
45
+ import static org .elasticsearch .compute .gen .Methods .requireVoidType ;
39
46
import static org .elasticsearch .compute .gen .Methods .vectorAccessorName ;
40
47
import static org .elasticsearch .compute .gen .Types .BIG_ARRAYS ;
48
+ import static org .elasticsearch .compute .gen .Types .BLOCK ;
41
49
import static org .elasticsearch .compute .gen .Types .BLOCK_ARRAY ;
42
50
import static org .elasticsearch .compute .gen .Types .BYTES_REF ;
43
51
import static org .elasticsearch .compute .gen .Types .DRIVER_CONTEXT ;
@@ -71,9 +79,6 @@ public class GroupingAggregatorImplementer {
71
79
private final List <TypeMirror > warnExceptions ;
72
80
private final ExecutableElement init ;
73
81
private final ExecutableElement combine ;
74
- private final ExecutableElement combineStates ;
75
- private final ExecutableElement evaluateFinal ;
76
- private final ExecutableElement combineIntermediate ;
77
82
private final List <Parameter > createParameters ;
78
83
private final ClassName implementation ;
79
84
private final List <AggregatorImplementer .IntermediateStateDesc > intermediateState ;
@@ -92,22 +97,23 @@ public GroupingAggregatorImplementer(
92
97
this .declarationType = declarationType ;
93
98
this .warnExceptions = warnExceptions ;
94
99
95
- this .init = findRequiredMethod (declarationType , new String [] { "init" , "initGrouping" }, e -> true );
100
+ this .init = requireStaticMethod (
101
+ declarationType ,
102
+ // This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState
103
+ requirePrimitiveOrImplements (elements , Types .RELEASABLE ),
104
+ requireName ("init" , "initGrouping" ),
105
+ requireAnyArgs ("<arbitrary init arguments>" )
106
+ );
96
107
this .aggState = AggregationState .create (elements , init .getReturnType (), warnExceptions .isEmpty () == false , true );
97
108
98
- this .combine = findRequiredMethod (declarationType , new String [] { "combine" }, e -> {
99
- if (e .getParameters ().size () == 0 ) {
100
- return false ;
101
- }
102
- TypeName firstParamType = TypeName .get (e .getParameters ().get (0 ).asType ());
103
- return Objects .equals (firstParamType .toString (), aggState .declaredType ().toString ());
104
- });
105
- // TODO support multiple parameters
109
+ this .combine = requireStaticMethod (
110
+ declarationType ,
111
+ aggState .declaredType ().isPrimitive () ? requireType (aggState .declaredType ()) : requireVoidType (),
112
+ requireName ("combine" ),
113
+ combineArgs (aggState , includeTimestampVector )
114
+ );
106
115
this .aggParam = AggregationParameter .create (combine .getParameters ().get (combine .getParameters ().size () - 1 ).asType ());
107
116
108
- this .combineStates = findMethod (declarationType , "combineStates" );
109
- this .combineIntermediate = findMethod (declarationType , "combineIntermediate" );
110
- this .evaluateFinal = findMethod (declarationType , "evaluateFinal" );
111
117
this .createParameters = init .getParameters ()
112
118
.stream ()
113
119
.map (Parameter ::from )
@@ -125,6 +131,25 @@ public GroupingAggregatorImplementer(
125
131
this .includeTimestampVector = includeTimestampVector ;
126
132
}
127
133
134
+ private static Methods .ArgumentMatcher combineArgs (AggregationState aggState , boolean includeTimestampVector ) {
135
+ if (aggState .declaredType ().isPrimitive ()) {
136
+ return requireArgs (requireType (aggState .declaredType ()), requireAnyType ("<aggregation input column type>" ));
137
+ } else if (includeTimestampVector ) {
138
+ return requireArgs (
139
+ requireType (aggState .declaredType ()),
140
+ requireType (TypeName .INT ),
141
+ requireType (TypeName .LONG ), // @timestamp
142
+ requireAnyType ("<aggregation input column type>" )
143
+ );
144
+ } else {
145
+ return requireArgs (
146
+ requireType (aggState .declaredType ()),
147
+ requireType (TypeName .INT ),
148
+ requireAnyType ("<aggregation input column type>" )
149
+ );
150
+ }
151
+ }
152
+
128
153
public ClassName implementation () {
129
154
return implementation ;
130
155
}
@@ -557,31 +582,33 @@ private MethodSpec addIntermediateInput() {
557
582
});
558
583
builder .endControlFlow ();
559
584
} else {
560
- builder .addStatement ("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess () + ")" , declarationType );
585
+ var stateHasBlock = intermediateState .stream ().anyMatch (AggregatorImplementer .IntermediateStateDesc ::block );
586
+ requireStaticMethod (
587
+ declarationType ,
588
+ requireVoidType (),
589
+ requireName ("combineIntermediate" ),
590
+ requireArgs (
591
+ Stream .of (
592
+ Stream .of (aggState .declaredType (), TypeName .INT ), // aggState and groupId
593
+ intermediateState .stream ().map (AggregatorImplementer .IntermediateStateDesc ::combineArgType ),
594
+ Stream .of (TypeName .INT ).filter (p -> stateHasBlock ) // position
595
+ ).flatMap (Function .identity ()).map (Methods ::requireType ).toArray (Methods .TypeMatcher []::new )
596
+ )
597
+ );
598
+
599
+ builder .addStatement (
600
+ "$T.combineIntermediate(state, groupId, "
601
+ + intermediateState .stream ().map (desc -> desc .access ("groupPosition + positionOffset" )).collect (joining (", " ))
602
+ + (stateHasBlock ? ", groupPosition + positionOffset" : "" )
603
+ + ")" ,
604
+ declarationType
605
+ );
561
606
}
562
607
builder .endControlFlow ();
563
608
}
564
609
return builder .build ();
565
610
}
566
611
567
- String intermediateStateRowAccess () {
568
- String rowAccess = intermediateState .stream ().map (desc -> desc .access ("groupPosition + positionOffset" )).collect (joining (", " ));
569
- if (intermediateState .stream ().anyMatch (AggregatorImplementer .IntermediateStateDesc ::block )) {
570
- rowAccess += ", groupPosition + positionOffset" ;
571
- }
572
- return rowAccess ;
573
- }
574
-
575
- private void combineStates (MethodSpec .Builder builder ) {
576
- if (combineStates == null ) {
577
- builder .beginControlFlow ("if (inState.hasValue(position))" );
578
- builder .addStatement ("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))" , declarationType );
579
- builder .endControlFlow ();
580
- return ;
581
- }
582
- builder .addStatement ("$T.combineStates(state, groupId, inState, position)" , declarationType );
583
- }
584
-
585
612
private MethodSpec addIntermediateRowInput () {
586
613
MethodSpec .Builder builder = MethodSpec .methodBuilder ("addIntermediateRowInput" );
587
614
builder .addAnnotation (Override .class ).addModifiers (Modifier .PUBLIC );
@@ -593,7 +620,24 @@ private MethodSpec addIntermediateRowInput() {
593
620
builder .endControlFlow ();
594
621
builder .addStatement ("$T inState = (($T) input).state" , aggState .type (), implementation );
595
622
builder .addStatement ("state.enableGroupIdTracking(new $T.Empty())" , SEEN_GROUP_IDS );
596
- combineStates (builder );
623
+ if (aggState .declaredType ().isPrimitive ()) {
624
+ builder .beginControlFlow ("if (inState.hasValue(position))" );
625
+ builder .addStatement ("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))" , declarationType );
626
+ builder .endControlFlow ();
627
+ } else {
628
+ requireStaticMethod (
629
+ declarationType ,
630
+ requireVoidType (),
631
+ requireName ("combineStates" ),
632
+ requireArgs (
633
+ requireType (aggState .declaredType ()),
634
+ requireType (TypeName .INT ),
635
+ requireType (aggState .declaredType ()),
636
+ requireType (TypeName .INT )
637
+ )
638
+ );
639
+ builder .addStatement ("$T.combineStates(state, groupId, inState, position)" , declarationType );
640
+ }
597
641
return builder .build ();
598
642
}
599
643
@@ -617,9 +661,15 @@ private MethodSpec evaluateFinal() {
617
661
.addParameter (INT_VECTOR , "selected" )
618
662
.addParameter (DRIVER_CONTEXT , "driverContext" );
619
663
620
- if (evaluateFinal == null ) {
664
+ if (aggState . declaredType (). isPrimitive () ) {
621
665
builder .addStatement ("blocks[offset] = state.toValuesBlock(selected, driverContext)" );
622
666
} else {
667
+ requireStaticMethod (
668
+ declarationType ,
669
+ requireType (BLOCK ),
670
+ requireName ("evaluateFinal" ),
671
+ requireArgs (requireType (aggState .declaredType ()), requireType (INT_VECTOR ), requireType (DRIVER_CONTEXT ))
672
+ );
623
673
builder .addStatement ("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)" , declarationType );
624
674
}
625
675
return builder .build ();
0 commit comments