Skip to content

Commit 54634a1

Browse files
committed
handle methods in grouping aggregator
1 parent 2bfc745 commit 54634a1

File tree

2 files changed

+50
-22
lines changed

2 files changed

+50
-22
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,16 @@
3636
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
3737
import static org.elasticsearch.compute.gen.Methods.findMethod;
3838
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
39+
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
40+
import static org.elasticsearch.compute.gen.Methods.requireArgs;
41+
import static org.elasticsearch.compute.gen.Methods.requireName;
42+
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
43+
import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
44+
import static org.elasticsearch.compute.gen.Methods.requireType;
45+
import static org.elasticsearch.compute.gen.Methods.requireVoidType;
3946
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
4047
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
48+
import static org.elasticsearch.compute.gen.Types.BLOCK;
4149
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
4250
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
4351
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
@@ -72,8 +80,6 @@ public class GroupingAggregatorImplementer {
7280
private final ExecutableElement init;
7381
private final ExecutableElement combine;
7482
private final ExecutableElement combineStates;
75-
private final ExecutableElement evaluateFinal;
76-
private final ExecutableElement combineIntermediate;
7783
private final List<Parameter> createParameters;
7884
private final ClassName implementation;
7985
private final List<AggregatorImplementer.IntermediateStateDesc> intermediateState;
@@ -92,9 +98,16 @@ public GroupingAggregatorImplementer(
9298
this.declarationType = declarationType;
9399
this.warnExceptions = warnExceptions;
94100

95-
this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true);
101+
this.init = requireStaticMethod(
102+
declarationType,
103+
// This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState
104+
requirePrimitiveOrImplements(elements, Types.RELEASABLE),
105+
requireName("init", "initGrouping"),
106+
requireAnyArgs("<arbitrary init arguments>")
107+
);
96108
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true);
97109

110+
// TODO optional timestamp
98111
this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
99112
if (e.getParameters().size() == 0) {
100113
return false;
@@ -106,8 +119,6 @@ public GroupingAggregatorImplementer(
106119
this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType());
107120

108121
this.combineStates = findMethod(declarationType, "combineStates");
109-
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
110-
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
111122
this.createParameters = init.getParameters()
112123
.stream()
113124
.map(Parameter::from)
@@ -557,6 +568,7 @@ private MethodSpec addIntermediateInput() {
557568
});
558569
builder.endControlFlow();
559570
} else {
571+
// TODO combineIntermediate with optional block parameter
560572
builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType);
561573
}
562574
builder.endControlFlow();
@@ -572,16 +584,6 @@ String intermediateStateRowAccess() {
572584
return rowAccess;
573585
}
574586

575-
private void combineStates(MethodSpec.Builder builder) {
576-
if (combineStates == null) {
577-
builder.beginControlFlow("if (inState.hasValue(position))");
578-
builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
579-
builder.endControlFlow();
580-
return;
581-
}
582-
builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
583-
}
584-
585587
private MethodSpec addIntermediateRowInput() {
586588
MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput");
587589
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
@@ -593,7 +595,24 @@ private MethodSpec addIntermediateRowInput() {
593595
builder.endControlFlow();
594596
builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation);
595597
builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
596-
combineStates(builder);
598+
if (aggState.declaredType().isPrimitive()) {
599+
builder.beginControlFlow("if (inState.hasValue(position))");
600+
builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
601+
builder.endControlFlow();
602+
} else {
603+
requireStaticMethod(
604+
declarationType,
605+
requireVoidType(),
606+
requireName("combineStates"),
607+
requireArgs(
608+
requireType(aggState.declaredType()),
609+
requireType(TypeName.INT),
610+
requireType(aggState.declaredType()),
611+
requireType(TypeName.INT)
612+
)
613+
);
614+
builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
615+
}
597616
return builder.build();
598617
}
599618

@@ -617,9 +636,15 @@ private MethodSpec evaluateFinal() {
617636
.addParameter(INT_VECTOR, "selected")
618637
.addParameter(DRIVER_CONTEXT, "driverContext");
619638

620-
if (evaluateFinal == null) {
639+
if (aggState.declaredType().isPrimitive()) {
621640
builder.addStatement("blocks[offset] = state.toValuesBlock(selected, driverContext)");
622641
} else {
642+
requireStaticMethod(
643+
declarationType,
644+
requireType(BLOCK),
645+
requireName("evaluateFinal"),
646+
requireArgs(requireType(aggState.declaredType()), requireType(INT_VECTOR), requireType(DRIVER_CONTEXT))
647+
);
623648
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)", declarationType);
624649
}
625650
return builder.build();

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import javax.lang.model.element.Modifier;
2323
import javax.lang.model.element.TypeElement;
2424
import javax.lang.model.type.DeclaredType;
25+
import javax.lang.model.type.TypeKind;
2526
import javax.lang.model.type.TypeMirror;
2627
import javax.lang.model.util.ElementFilter;
2728
import javax.lang.model.util.Elements;
@@ -149,11 +150,13 @@ private static boolean isImplementing(Elements elements, TypeName type, TypeName
149150
}
150151

151152
private static Stream<TypeName> allInterfacesOf(Elements elements, TypeName type) {
152-
return elements.getTypeElement(type.toString())
153-
.getInterfaces()
154-
.stream()
155-
.map(TypeName::get)
156-
.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface)));
153+
var typeElement = elements.getTypeElement(type.toString());
154+
var superType = Stream.of(typeElement.getSuperclass()).filter(sType -> sType.getKind() != TypeKind.NONE).map(TypeName::get);
155+
var interfaces = typeElement.getInterfaces().stream().map(TypeName::get);
156+
return Stream.concat(
157+
superType.flatMap(sType -> allInterfacesOf(elements, sType)),
158+
interfaces.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface)))
159+
);
157160
}
158161

159162
private static Stream<TypeElement> typeAndSuperType(TypeElement declarationType) {

0 commit comments

Comments
 (0)