Skip to content

Commit 7deaacd

Browse files
authored
Update aggs code generation to be more explicit about required methods (#121749)
1 parent 98c7570 commit 7deaacd

File tree

3 files changed

+251
-93
lines changed

3 files changed

+251
-93
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import org.elasticsearch.compute.ann.Aggregator;
1919
import org.elasticsearch.compute.ann.IntermediateState;
20+
import org.elasticsearch.compute.gen.Methods.TypeMatcher;
2021

2122
import java.util.Arrays;
2223
import java.util.List;
@@ -33,8 +34,14 @@
3334
import javax.lang.model.util.Elements;
3435

3536
import static java.util.stream.Collectors.joining;
36-
import static org.elasticsearch.compute.gen.Methods.findMethod;
37-
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
37+
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
38+
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
39+
import static org.elasticsearch.compute.gen.Methods.requireArgs;
40+
import static org.elasticsearch.compute.gen.Methods.requireName;
41+
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
42+
import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
43+
import static org.elasticsearch.compute.gen.Methods.requireType;
44+
import static org.elasticsearch.compute.gen.Methods.requireVoidType;
3845
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
3946
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
4047
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
@@ -66,8 +73,6 @@ public class AggregatorImplementer {
6673
private final List<TypeMirror> warnExceptions;
6774
private final ExecutableElement init;
6875
private final ExecutableElement combine;
69-
private final ExecutableElement combineIntermediate;
70-
private final ExecutableElement evaluateFinal;
7176
private final ClassName implementation;
7277
private final List<IntermediateStateDesc> intermediateState;
7378
private final List<Parameter> createParameters;
@@ -84,21 +89,24 @@ public AggregatorImplementer(
8489
this.declarationType = declarationType;
8590
this.warnExceptions = warnExceptions;
8691

87-
this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true);
92+
this.init = requireStaticMethod(
93+
declarationType,
94+
// This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState
95+
requirePrimitiveOrImplements(elements, Types.RELEASABLE),
96+
requireName("init", "initSingle"),
97+
requireAnyArgs("<arbitrary init arguments>")
98+
);
8899
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, false);
89100

90-
this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
91-
if (e.getParameters().size() == 0) {
92-
return false;
93-
}
94-
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
95-
return Objects.equals(firstParamType.toString(), aggState.declaredType().toString());
96-
});
101+
this.combine = requireStaticMethod(
102+
declarationType,
103+
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
104+
requireName("combine"),
105+
requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
106+
);
97107
// TODO support multiple parameters
98108
this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType());
99109

100-
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
101-
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
102110
this.createParameters = init.getParameters()
103111
.stream()
104112
.map(Parameter::from)
@@ -447,12 +455,7 @@ private MethodSpec addIntermediateInput() {
447455
interState.assignToVariable(builder, i);
448456
builder.addStatement("assert $L.getPositionCount() == 1", interState.name());
449457
}
450-
if (combineIntermediate != null) {
451-
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
452-
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
453-
}
454-
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
455-
} else if (aggState.declaredType().isPrimitive()) {
458+
if (aggState.declaredType().isPrimitive()) {
456459
if (warnExceptions.isEmpty()) {
457460
assert intermediateState.size() == 2;
458461
assert intermediateState.get(1).name().equals("seen");
@@ -485,7 +488,21 @@ private MethodSpec addIntermediateInput() {
485488
});
486489
builder.endControlFlow();
487490
} else {
488-
throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate");
491+
requireStaticMethod(
492+
declarationType,
493+
requireVoidType(),
494+
requireName("combineIntermediate"),
495+
requireArgs(
496+
Stream.concat(
497+
Stream.of(aggState.declaredType()), // aggState
498+
intermediateState.stream().map(IntermediateStateDesc::combineArgType) // intermediate state
499+
).map(Methods::requireType).toArray(TypeMatcher[]::new)
500+
)
501+
);
502+
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
503+
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
504+
}
505+
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
489506
}
490507
return builder.build();
491508
}
@@ -524,7 +541,7 @@ private MethodSpec evaluateFinal() {
524541
builder.addStatement("return");
525542
builder.endControlFlow();
526543
}
527-
if (evaluateFinal == null) {
544+
if (aggState.declaredType().isPrimitive()) {
528545
builder.addStatement(switch (aggState.declaredType().toString()) {
529546
case "boolean" -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)";
530547
case "int" -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)";
@@ -534,6 +551,12 @@ private MethodSpec evaluateFinal() {
534551
default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]");
535552
});
536553
} else {
554+
requireStaticMethod(
555+
declarationType,
556+
requireType(BLOCK),
557+
requireName("evaluateFinal"),
558+
requireArgs(requireType(aggState.declaredType()), requireType(DRIVER_CONTEXT))
559+
);
537560
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType);
538561
}
539562
return builder.build();
@@ -593,6 +616,11 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
593616
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
594617
}
595618
}
619+
620+
public TypeName combineArgType() {
621+
var type = Types.fromString(elementType);
622+
return block ? blockType(type) : type;
623+
}
596624
}
597625

598626
/**

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 89 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222

2323
import java.util.Arrays;
2424
import java.util.List;
25-
import java.util.Objects;
2625
import java.util.function.Consumer;
26+
import java.util.function.Function;
2727
import java.util.stream.Collectors;
28+
import java.util.stream.Stream;
2829

2930
import javax.lang.model.element.ExecutableElement;
3031
import javax.lang.model.element.Modifier;
@@ -34,10 +35,17 @@
3435

3536
import static java.util.stream.Collectors.joining;
3637
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
37-
import static org.elasticsearch.compute.gen.Methods.findMethod;
38-
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
38+
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
39+
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
40+
import static org.elasticsearch.compute.gen.Methods.requireArgs;
41+
import static org.elasticsearch.compute.gen.Methods.requireName;
42+
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
43+
import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
44+
import static org.elasticsearch.compute.gen.Methods.requireType;
45+
import static org.elasticsearch.compute.gen.Methods.requireVoidType;
3946
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
4047
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
48+
import static org.elasticsearch.compute.gen.Types.BLOCK;
4149
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
4250
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
4351
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
@@ -71,9 +79,6 @@ public class GroupingAggregatorImplementer {
7179
private final List<TypeMirror> warnExceptions;
7280
private final ExecutableElement init;
7381
private final ExecutableElement combine;
74-
private final ExecutableElement combineStates;
75-
private final ExecutableElement evaluateFinal;
76-
private final ExecutableElement combineIntermediate;
7782
private final List<Parameter> createParameters;
7883
private final ClassName implementation;
7984
private final List<AggregatorImplementer.IntermediateStateDesc> intermediateState;
@@ -92,22 +97,23 @@ public GroupingAggregatorImplementer(
9297
this.declarationType = declarationType;
9398
this.warnExceptions = warnExceptions;
9499

95-
this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true);
100+
this.init = requireStaticMethod(
101+
declarationType,
102+
// This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState
103+
requirePrimitiveOrImplements(elements, Types.RELEASABLE),
104+
requireName("init", "initGrouping"),
105+
requireAnyArgs("<arbitrary init arguments>")
106+
);
96107
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true);
97108

98-
this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
99-
if (e.getParameters().size() == 0) {
100-
return false;
101-
}
102-
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
103-
return Objects.equals(firstParamType.toString(), aggState.declaredType().toString());
104-
});
105-
// TODO support multiple parameters
109+
this.combine = requireStaticMethod(
110+
declarationType,
111+
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
112+
requireName("combine"),
113+
combineArgs(aggState, includeTimestampVector)
114+
);
106115
this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType());
107116

108-
this.combineStates = findMethod(declarationType, "combineStates");
109-
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
110-
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
111117
this.createParameters = init.getParameters()
112118
.stream()
113119
.map(Parameter::from)
@@ -125,6 +131,25 @@ public GroupingAggregatorImplementer(
125131
this.includeTimestampVector = includeTimestampVector;
126132
}
127133

134+
private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
135+
if (aggState.declaredType().isPrimitive()) {
136+
return requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"));
137+
} else if (includeTimestampVector) {
138+
return requireArgs(
139+
requireType(aggState.declaredType()),
140+
requireType(TypeName.INT),
141+
requireType(TypeName.LONG), // @timestamp
142+
requireAnyType("<aggregation input column type>")
143+
);
144+
} else {
145+
return requireArgs(
146+
requireType(aggState.declaredType()),
147+
requireType(TypeName.INT),
148+
requireAnyType("<aggregation input column type>")
149+
);
150+
}
151+
}
152+
128153
public ClassName implementation() {
129154
return implementation;
130155
}
@@ -557,31 +582,33 @@ private MethodSpec addIntermediateInput() {
557582
});
558583
builder.endControlFlow();
559584
} else {
560-
builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType);
585+
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
586+
requireStaticMethod(
587+
declarationType,
588+
requireVoidType(),
589+
requireName("combineIntermediate"),
590+
requireArgs(
591+
Stream.of(
592+
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
593+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
594+
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
595+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
596+
)
597+
);
598+
599+
builder.addStatement(
600+
"$T.combineIntermediate(state, groupId, "
601+
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
602+
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
603+
+ ")",
604+
declarationType
605+
);
561606
}
562607
builder.endControlFlow();
563608
}
564609
return builder.build();
565610
}
566611

567-
String intermediateStateRowAccess() {
568-
String rowAccess = intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "));
569-
if (intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block)) {
570-
rowAccess += ", groupPosition + positionOffset";
571-
}
572-
return rowAccess;
573-
}
574-
575-
private void combineStates(MethodSpec.Builder builder) {
576-
if (combineStates == null) {
577-
builder.beginControlFlow("if (inState.hasValue(position))");
578-
builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
579-
builder.endControlFlow();
580-
return;
581-
}
582-
builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
583-
}
584-
585612
private MethodSpec addIntermediateRowInput() {
586613
MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput");
587614
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
@@ -593,7 +620,24 @@ private MethodSpec addIntermediateRowInput() {
593620
builder.endControlFlow();
594621
builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation);
595622
builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
596-
combineStates(builder);
623+
if (aggState.declaredType().isPrimitive()) {
624+
builder.beginControlFlow("if (inState.hasValue(position))");
625+
builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
626+
builder.endControlFlow();
627+
} else {
628+
requireStaticMethod(
629+
declarationType,
630+
requireVoidType(),
631+
requireName("combineStates"),
632+
requireArgs(
633+
requireType(aggState.declaredType()),
634+
requireType(TypeName.INT),
635+
requireType(aggState.declaredType()),
636+
requireType(TypeName.INT)
637+
)
638+
);
639+
builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
640+
}
597641
return builder.build();
598642
}
599643

@@ -617,9 +661,15 @@ private MethodSpec evaluateFinal() {
617661
.addParameter(INT_VECTOR, "selected")
618662
.addParameter(DRIVER_CONTEXT, "driverContext");
619663

620-
if (evaluateFinal == null) {
664+
if (aggState.declaredType().isPrimitive()) {
621665
builder.addStatement("blocks[offset] = state.toValuesBlock(selected, driverContext)");
622666
} else {
667+
requireStaticMethod(
668+
declarationType,
669+
requireType(BLOCK),
670+
requireName("evaluateFinal"),
671+
requireArgs(requireType(aggState.declaredType()), requireType(INT_VECTOR), requireType(DRIVER_CONTEXT))
672+
);
623673
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)", declarationType);
624674
}
625675
return builder.build();

0 commit comments

Comments
 (0)