Skip to content

Commit 30a2ccd

Browse files
committed
merge with refactoring
1 parent 556a3de commit 30a2ccd

File tree

2 files changed

+161
-50
lines changed

2 files changed

+161
-50
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import org.elasticsearch.compute.ann.Aggregator;
1919
import org.elasticsearch.compute.ann.IntermediateState;
20+
import org.elasticsearch.compute.gen.Methods.TypeMatcher;
2021

2122
import java.util.Arrays;
2223
import java.util.List;
@@ -33,8 +34,14 @@
3334
import javax.lang.model.util.Elements;
3435

3536
import static java.util.stream.Collectors.joining;
36-
import static org.elasticsearch.compute.gen.Methods.findMethod;
37-
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
37+
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
38+
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
39+
import static org.elasticsearch.compute.gen.Methods.requireArgs;
40+
import static org.elasticsearch.compute.gen.Methods.requireName;
41+
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
42+
import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
43+
import static org.elasticsearch.compute.gen.Methods.requireType;
44+
import static org.elasticsearch.compute.gen.Methods.requireVoidType;
3845
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
3946
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
4047
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
@@ -66,8 +73,6 @@ public class AggregatorImplementer {
6673
private final List<TypeMirror> warnExceptions;
6774
private final ExecutableElement init;
6875
private final ExecutableElement combine;
69-
private final ExecutableElement combineIntermediate;
70-
private final ExecutableElement evaluateFinal;
7176
private final ClassName implementation;
7277
private final List<IntermediateStateDesc> intermediateState;
7378
private final List<Parameter> createParameters;
@@ -84,21 +89,24 @@ public AggregatorImplementer(
8489
this.declarationType = declarationType;
8590
this.warnExceptions = warnExceptions;
8691

87-
this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true);
92+
this.init = requireStaticMethod(
93+
declarationType,
94+
requirePrimitiveOrImplements(elements, Types.RELEASABLE),// This should be more restrictive
95+
// org.elasticsearch.compute.aggregation.AggregatorState
96+
requireName("init", "initSingle"),
97+
requireAnyArgs("<arbitrary init arguments>")
98+
);
8899
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, false);
89100

90-
this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
91-
if (e.getParameters().size() == 0) {
92-
return false;
93-
}
94-
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
95-
return Objects.equals(firstParamType.toString(), aggState.declaredType().toString());
96-
});
101+
this.combine = requireStaticMethod(
102+
declarationType,
103+
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
104+
requireName("combine"),
105+
requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
106+
);
97107
// TODO support multiple parameters
98108
this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType());
99109

100-
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
101-
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
102110
this.createParameters = init.getParameters()
103111
.stream()
104112
.map(Parameter::from)
@@ -447,12 +455,7 @@ private MethodSpec addIntermediateInput() {
447455
interState.assignToVariable(builder, i);
448456
builder.addStatement("assert $L.getPositionCount() == 1", interState.name());
449457
}
450-
if (combineIntermediate != null) {
451-
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
452-
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
453-
}
454-
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
455-
} else if (aggState.declaredType().isPrimitive()) {
458+
if (aggState.declaredType().isPrimitive()) {
456459
if (warnExceptions.isEmpty()) {
457460
assert intermediateState.size() == 2;
458461
assert intermediateState.get(1).name().equals("seen");
@@ -485,7 +488,20 @@ private MethodSpec addIntermediateInput() {
485488
});
486489
builder.endControlFlow();
487490
} else {
488-
throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate");
491+
requireStaticMethod(
492+
declarationType,
493+
requireVoidType(),
494+
requireName("combineIntermediate"),
495+
requireArgs(
496+
Stream.concat(Stream.of(aggState.declaredType()), intermediateState.stream().map(IntermediateStateDesc::combineArgType))
497+
.map(Methods::requireType)
498+
.toArray(TypeMatcher[]::new)
499+
)
500+
);
501+
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
502+
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
503+
}
504+
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
489505
}
490506
return builder.build();
491507
}
@@ -524,7 +540,7 @@ private MethodSpec evaluateFinal() {
524540
builder.addStatement("return");
525541
builder.endControlFlow();
526542
}
527-
if (evaluateFinal == null) {
543+
if (aggState.declaredType().isPrimitive()) {
528544
builder.addStatement(switch (aggState.declaredType().toString()) {
529545
case "boolean" -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)";
530546
case "int" -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)";
@@ -534,6 +550,12 @@ private MethodSpec evaluateFinal() {
534550
default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]");
535551
});
536552
} else {
553+
requireStaticMethod(
554+
declarationType,
555+
requireType(BLOCK),
556+
requireName("evaluateFinal"),
557+
requireArgs(requireType(aggState.declaredType()), requireType(DRIVER_CONTEXT))
558+
);
537559
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType);
538560
}
539561
return builder.build();
@@ -593,6 +615,11 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
593615
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
594616
}
595617
}
618+
619+
public TypeName combineArgType() {
620+
var type = Types.fromString(elementType);
621+
return block ? blockType(type) : type;
622+
}
596623
}
597624

598625
/**

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 112 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,19 @@
1212
import java.util.Arrays;
1313
import java.util.List;
1414
import java.util.Objects;
15+
import java.util.Set;
1516
import java.util.function.Predicate;
17+
import java.util.stream.IntStream;
1618
import java.util.stream.Stream;
1719

1820
import javax.lang.model.element.Element;
1921
import javax.lang.model.element.ExecutableElement;
2022
import javax.lang.model.element.Modifier;
2123
import javax.lang.model.element.TypeElement;
22-
import javax.lang.model.element.VariableElement;
2324
import javax.lang.model.type.DeclaredType;
2425
import javax.lang.model.type.TypeMirror;
2526
import javax.lang.model.util.ElementFilter;
27+
import javax.lang.model.util.Elements;
2628

2729
import static java.util.stream.Collectors.joining;
2830
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
@@ -54,6 +56,115 @@
5456
*/
5557
public class Methods {
5658

59+
static ExecutableElement requireStaticMethod(
60+
TypeElement declarationType,
61+
TypeMatcher returnTypeMatcher,
62+
NameMatcher nameMatcher,
63+
ArgumentMatcher argumentMatcher
64+
) {
65+
return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream())
66+
.filter(method -> method.getModifiers().contains(Modifier.STATIC))
67+
.filter(method -> nameMatcher.test(method.getSimpleName().toString()))
68+
.filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType())))
69+
.filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList()))
70+
.findFirst()
71+
.orElseThrow(() -> {
72+
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
73+
var signatures = nameMatcher.names.stream()
74+
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
75+
.collect(joining(" or "));
76+
return new IllegalArgumentException(message + signatures);
77+
});
78+
}
79+
80+
static NameMatcher requireName(String... names) {
81+
return new NameMatcher(Set.of(names));
82+
}
83+
84+
static TypeMatcher requireVoidType() {
85+
return new TypeMatcher(type -> Objects.equals(TypeName.VOID, type), "void");
86+
}
87+
88+
static TypeMatcher requireAnyType(String description) {
89+
return new TypeMatcher(type -> true, description);
90+
}
91+
92+
static TypeMatcher requirePrimitiveOrImplements(Elements elements, TypeName requiredInterface) {
93+
return new TypeMatcher(
94+
type -> type.isPrimitive() || isImplementing(elements, type, requiredInterface),
95+
"[boolean|int|long|float|double|" + requiredInterface + "]"
96+
);
97+
}
98+
99+
static TypeMatcher requireType(TypeName requiredType) {
100+
return new TypeMatcher(type -> Objects.equals(requiredType, type), requiredType.toString());
101+
}
102+
103+
static ArgumentMatcher requireAnyArgs(String description) {
104+
return new ArgumentMatcher(args -> true, description);
105+
}
106+
107+
static ArgumentMatcher requireArgs(TypeMatcher... argTypes) {
108+
return new ArgumentMatcher(
109+
args -> args.size() == argTypes.length && IntStream.range(0, argTypes.length).allMatch(i -> argTypes[i].test(args.get(i))),
110+
Stream.of(argTypes).map(TypeMatcher::toString).collect(joining(", "))
111+
);
112+
}
113+
114+
record NameMatcher(Set<String> names) implements Predicate<String> {
115+
@Override
116+
public boolean test(String name) {
117+
return names.contains(name);
118+
}
119+
}
120+
121+
record TypeMatcher(Predicate<TypeName> matcher, String description) implements Predicate<TypeName> {
122+
@Override
123+
public boolean test(TypeName typeName) {
124+
return matcher.test(typeName);
125+
}
126+
127+
@Override
128+
public String toString() {
129+
return description;
130+
}
131+
}
132+
133+
record ArgumentMatcher(Predicate<List<TypeName>> matcher, String description) implements Predicate<List<TypeName>> {
134+
@Override
135+
public boolean test(List<TypeName> typeName) {
136+
return matcher.test(typeName);
137+
}
138+
139+
@Override
140+
public String toString() {
141+
return description;
142+
}
143+
}
144+
145+
private static boolean isImplementing(Elements elements, TypeName type, TypeName requiredInterface) {
146+
return allInterfacesOf(elements, type).anyMatch(
147+
anInterface -> Objects.equals(anInterface.toString(), requiredInterface.toString())
148+
);
149+
}
150+
151+
private static Stream<TypeName> allInterfacesOf(Elements elements, TypeName type) {
152+
return elements.getTypeElement(type.toString())
153+
.getInterfaces()
154+
.stream()
155+
.map(TypeName::get)
156+
.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface)));
157+
}
158+
159+
private static Stream<TypeElement> typeAndSuperType(TypeElement declarationType) {
160+
if (declarationType.getSuperclass() instanceof DeclaredType declaredType
161+
&& declaredType.asElement() instanceof TypeElement superType) {
162+
return Stream.of(declarationType, superType);
163+
} else {
164+
return Stream.of(declarationType);
165+
}
166+
}
167+
57168
static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
58169
ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
59170
if (result == null) {
@@ -100,33 +211,6 @@ static ExecutableElement findMethod(String[] names, Predicate<ExecutableElement>
100211
return null;
101212
}
102213

103-
static void requireMethod(TypeElement element, String name, String returnType, String... parameterTypes) {
104-
var method = findMethod(new String[] { name }, e -> true, element, superClassOf(element));
105-
if (method == null || isNotSame(method.getReturnType(), returnType) || isNotSame(method.getParameters(), parameterTypes)) {
106-
throw new IllegalArgumentException("Requires method " + signature(element, name, returnType, parameterTypes));
107-
}
108-
}
109-
110-
private static boolean isNotSame(TypeMirror type, String required) {
111-
return Objects.equals(type.toString(), required) == false;
112-
}
113-
114-
private static boolean isNotSame(List<? extends VariableElement> types, String[] required) {
115-
if (types.size() != required.length) {
116-
return true;
117-
}
118-
for (int i = 0; i < types.size(); i++) {
119-
if (isNotSame(types.get(i).asType(), required[i])) {
120-
return true;
121-
}
122-
}
123-
return false;
124-
}
125-
126-
private static String signature(TypeElement element, String name, String returnType, String[] parameterTypes) {
127-
return "public static " + returnType + " " + element + "#" + name + Stream.of(parameterTypes).collect(joining(", ", "(", ")"));
128-
}
129-
130214
/**
131215
* Returns the name of the method used to add {@code valueType} instances
132216
* to vector or block builders.

0 commit comments

Comments
 (0)