Skip to content

Commit dfb6ae3

Browse files
committed
Update aggs code generation to be more explicit about required methods
1 parent e5ea00a commit dfb6ae3

File tree

3 files changed

+85
-47
lines changed

3 files changed

+85
-47
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
import javax.lang.model.util.Elements;
3535

3636
import static java.util.stream.Collectors.joining;
37-
import static org.elasticsearch.compute.gen.Methods.findMethod;
3837
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
38+
import static org.elasticsearch.compute.gen.Methods.requireMethod;
3939
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
4040
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
4141
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
@@ -78,8 +78,6 @@ public class AggregatorImplementer {
7878
private final List<TypeMirror> warnExceptions;
7979
private final ExecutableElement init;
8080
private final ExecutableElement combine;
81-
private final ExecutableElement combineIntermediate;
82-
private final ExecutableElement evaluateFinal;
8381
private final ClassName implementation;
8482
private final TypeName stateType;
8583
private final boolean stateTypeHasSeen;
@@ -114,8 +112,6 @@ public AggregatorImplementer(
114112
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
115113
return firstParamType.isPrimitive() || firstParamType.toString().equals(stateType.toString());
116114
});
117-
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
118-
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
119115
this.createParameters = init.getParameters()
120116
.stream()
121117
.map(Parameter::from)
@@ -198,9 +194,7 @@ static ClassName valueVectorType(ExecutableElement init, ExecutableElement combi
198194
}
199195

200196
public static String firstUpper(String s) {
201-
String head = s.toString().substring(0, 1).toUpperCase(Locale.ROOT);
202-
String tail = s.toString().substring(1);
203-
return head + tail;
197+
return Character.toUpperCase(s.charAt(0)) + s.substring(1);
204198
}
205199

206200
public JavaFile sourceFile() {
@@ -526,12 +520,7 @@ private MethodSpec addIntermediateInput() {
526520
interState.assignToVariable(builder, i);
527521
builder.addStatement("assert $L.getPositionCount() == 1", interState.name());
528522
}
529-
if (combineIntermediate != null) {
530-
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
531-
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
532-
}
533-
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
534-
} else if (hasPrimitiveState()) {
523+
if (hasPrimitiveState()) {
535524
if (warnExceptions.isEmpty()) {
536525
assert intermediateState.size() == 2;
537526
assert intermediateState.get(1).name().equals("seen");
@@ -547,7 +536,6 @@ private MethodSpec addIntermediateInput() {
547536
}
548537
builder.nextControlFlow("else if (seen.getBoolean(0))");
549538
}
550-
551539
warningsBlock(builder, () -> {
552540
var state = intermediateState.get(0);
553541
var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))";
@@ -556,32 +544,39 @@ private MethodSpec addIntermediateInput() {
556544
});
557545
builder.endControlFlow();
558546
} else {
559-
throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate");
547+
requireMethod(
548+
declarationType,
549+
"combineIntermediate",
550+
"void",
551+
Stream.concat(Stream.of(stateType.toString()), intermediateState.stream().map(intermediateStateDesc -> {
552+
var type = Types.fromString(intermediateStateDesc.elementType());
553+
return intermediateStateDesc.block ? blockType(type).toString() : type.toString();
554+
})).toArray(String[]::new)
555+
);
556+
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
557+
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
558+
}
559+
builder.addStatement(
560+
"$T.combineIntermediate(state, " + intermediateState.stream().map(desc -> desc.access("0")).collect(joining(", ")) + ")",
561+
declarationType
562+
);
560563
}
561564
return builder.build();
562565
}
563566

564-
String intermediateStateRowAccess() {
565-
return intermediateState.stream().map(desc -> desc.access("0")).collect(joining(", "));
566-
}
567-
568567
private String primitiveStateMethod() {
569-
switch (stateType.toString()) {
570-
case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState":
571-
return "booleanValue";
572-
case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState":
573-
return "intValue";
574-
case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState":
575-
return "longValue";
576-
case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState":
577-
return "doubleValue";
578-
case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState":
579-
return "floatValue";
580-
default:
581-
throw new IllegalArgumentException(
582-
"don't know how to fetch primitive values from " + stateType + ". define combineIntermediate."
583-
);
584-
}
568+
return switch (stateType.toString()) {
569+
case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState" ->
570+
"booleanValue";
571+
case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState" -> "intValue";
572+
case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState" ->
573+
"longValue";
574+
case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState" ->
575+
"doubleValue";
576+
case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState" ->
577+
"floatValue";
578+
default -> throw new IllegalArgumentException("don't know how to fetch primitive values from " + stateType + ".");
579+
};
585580
}
586581

587582
private MethodSpec evaluateIntermediate() {
@@ -611,9 +606,15 @@ private MethodSpec evaluateFinal() {
611606
builder.addStatement("return");
612607
builder.endControlFlow();
613608
}
614-
if (evaluateFinal == null) {
609+
if (hasPrimitiveState()) {
615610
primitiveStateToResult(builder);
616611
} else {
612+
requireMethod(
613+
declarationType,
614+
"evaluateFinal",
615+
"org.elasticsearch.compute.data.Block",
616+
new String[] { stateType.toString(), "org.elasticsearch.compute.operator.DriverContext" }
617+
);
617618
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType);
618619
}
619620
return builder.build();

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
import com.squareup.javapoet.TypeName;
1111

1212
import java.util.Arrays;
13+
import java.util.List;
14+
import java.util.Objects;
1315
import java.util.function.Predicate;
16+
import java.util.stream.Stream;
1417

1518
import javax.lang.model.element.Element;
1619
import javax.lang.model.element.ExecutableElement;
@@ -21,6 +24,7 @@
2124
import javax.lang.model.type.TypeMirror;
2225
import javax.lang.model.util.ElementFilter;
2326

27+
import static java.util.stream.Collectors.joining;
2428
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
2529
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER;
2630
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
@@ -49,6 +53,7 @@
4953
* Finds declared methods for the code generator.
5054
*/
5155
public class Methods {
56+
5257
static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
5358
ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
5459
if (result == null) {
@@ -95,14 +100,31 @@ static ExecutableElement findMethod(String[] names, Predicate<ExecutableElement>
95100
return null;
96101
}
97102

98-
/**
99-
* Returns the arguments of a method after applying a filter.
100-
*/
101-
static VariableElement[] findMethodArguments(ExecutableElement method, Predicate<VariableElement> filter) {
102-
if (method.getParameters().isEmpty()) {
103-
return new VariableElement[0];
103+
static void requireMethod(TypeElement element, String name, String returnType, String... parameterTypes) {
104+
var method = findMethod(new String[] { name }, e -> true, element, superClassOf(element));
105+
if (method == null || isNotSame(method.getReturnType(), returnType) || isNotSame(method.getParameters(), parameterTypes)) {
106+
throw new IllegalArgumentException("Requires method " + signature(element, name, returnType, parameterTypes));
107+
}
108+
}
109+
110+
private static boolean isNotSame(TypeMirror type, String required) {
111+
return Objects.equals(type.toString(), required) == false;
112+
}
113+
114+
private static boolean isNotSame(List<? extends VariableElement> types, String[] required) {
115+
if (types.size() != required.length) {
116+
return true;
117+
}
118+
for (int i = 0; i < types.size(); i++) {
119+
if (isNotSame(types.get(i).asType(), required[i])) {
120+
return true;
121+
}
104122
}
105-
return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new);
123+
return false;
124+
}
125+
126+
private static String signature(TypeElement element, String name, String returnType, String[] parameterTypes) {
127+
return "public static " + returnType + " " + element + "#" + name + Stream.of(parameterTypes).collect(joining(", ", "(", ")"));
106128
}
107129

108130
/**

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,22 +138,37 @@ public class Types {
138138
static final ClassName RELEASABLE = ClassName.get("org.elasticsearch.core", "Releasable");
139139
static final ClassName RELEASABLES = ClassName.get("org.elasticsearch.core", "Releasables");
140140

141+
static TypeName fromString(String type) {
142+
return switch (type) {
143+
case "boolean", "BOOLEAN" -> TypeName.BOOLEAN;
144+
case "int", "INT" -> TypeName.INT;
145+
case "long", "LONG" -> TypeName.LONG;
146+
case "float", "FLOAT" -> TypeName.FLOAT;
147+
case "double", "DOUBLE" -> TypeName.DOUBLE;
148+
case "org.apache.lucene.util.BytesRef", "BYTES_REF" -> BYTES_REF;
149+
default -> throw new IllegalArgumentException("unknown type [" + type + "]");
150+
};
151+
}
152+
141153
static ClassName blockType(TypeName elementType) {
142154
if (elementType.equals(TypeName.BOOLEAN)) {
143155
return BOOLEAN_BLOCK;
144156
}
145-
if (elementType.equals(BYTES_REF)) {
146-
return BYTES_REF_BLOCK;
147-
}
148157
if (elementType.equals(TypeName.INT)) {
149158
return INT_BLOCK;
150159
}
151160
if (elementType.equals(TypeName.LONG)) {
152161
return LONG_BLOCK;
153162
}
163+
if (elementType.equals(TypeName.FLOAT)) {
164+
return FLOAT_BLOCK;
165+
}
154166
if (elementType.equals(TypeName.DOUBLE)) {
155167
return DOUBLE_BLOCK;
156168
}
169+
if (elementType.equals(BYTES_REF)) {
170+
return BYTES_REF_BLOCK;
171+
}
157172
throw new IllegalArgumentException("unknown block type for [" + elementType + "]");
158173
}
159174

0 commit comments

Comments
 (0)