Skip to content

Commit c2e632a

Browse files
authored
[9.0] backport various aggs code gen improvements (#122360)
1 parent 3202262 commit c2e632a

37 files changed

+754
-692
lines changed

x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@
3737
* are ever collected.
3838
* </p>
3939
* <p>
40-
* The generation code will also look for a method called {@code combineValueCount}
41-
* which is called once per received block with a count of values. NOTE: We may
42-
* not need this after we convert AVG into a composite operation.
43-
* </p>
44-
* <p>
4540
* The generation code also looks for the optional methods {@code combineIntermediate}
4641
* and {@code evaluateFinal} which are used to combine intermediate states and
4742
* produce the final output. If the first is missing then the generated code will
@@ -63,4 +58,8 @@
6358
*/
6459
Class<? extends Exception>[] warnExceptions() default {};
6560

61+
/**
62+
* If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function.
63+
*/
64+
boolean includeTimestamps() default false;
6665
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 232 additions & 230 deletions
Large diffs are not rendered by default.

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,13 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
8787
);
8888
if (aggClass.getAnnotation(Aggregator.class) != null) {
8989
IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value();
90-
implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes);
90+
implementer = new AggregatorImplementer(
91+
env.getElementUtils(),
92+
aggClass,
93+
intermediateState,
94+
warnExceptionsTypes,
95+
aggClass.getAnnotation(Aggregator.class).includeTimestamps()
96+
);
9197
write(aggClass, "aggregator", implementer.sourceFile(), env);
9298
}
9399
GroupingAggregatorImplementer groupingAggregatorImplementer = null;
@@ -96,13 +102,12 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
96102
if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) {
97103
intermediateState = aggClass.getAnnotation(Aggregator.class).value();
98104
}
99-
boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps();
100105
groupingAggregatorImplementer = new GroupingAggregatorImplementer(
101106
env.getElementUtils(),
102107
aggClass,
103108
intermediateState,
104109
warnExceptionsTypes,
105-
includeTimestamps
110+
aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps()
106111
);
107112
write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env);
108113
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 184 additions & 148 deletions
Large diffs are not rendered by default.

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 112 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@
99

1010
import com.squareup.javapoet.TypeName;
1111

12-
import java.util.Arrays;
12+
import java.util.List;
13+
import java.util.Objects;
14+
import java.util.Set;
1315
import java.util.function.Predicate;
16+
import java.util.stream.IntStream;
17+
import java.util.stream.Stream;
1418

15-
import javax.lang.model.element.Element;
1619
import javax.lang.model.element.ExecutableElement;
1720
import javax.lang.model.element.Modifier;
1821
import javax.lang.model.element.TypeElement;
19-
import javax.lang.model.element.VariableElement;
2022
import javax.lang.model.type.DeclaredType;
21-
import javax.lang.model.type.TypeMirror;
23+
import javax.lang.model.type.TypeKind;
2224
import javax.lang.model.util.ElementFilter;
25+
import javax.lang.model.util.Elements;
2326

27+
import static java.util.stream.Collectors.joining;
2428
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
2529
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER;
2630
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
@@ -49,30 +53,116 @@
4953
* Finds declared methods for the code generator.
5054
*/
5155
public class Methods {
52-
static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
53-
ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
54-
if (result == null) {
55-
if (names.length == 1) {
56-
throw new IllegalArgumentException(declarationType + "#" + names[0] + " is required");
57-
}
58-
throw new IllegalArgumentException("one of " + declarationType + "#" + Arrays.toString(names) + " is required");
56+
57+
static ExecutableElement requireStaticMethod(
58+
TypeElement declarationType,
59+
TypeMatcher returnTypeMatcher,
60+
NameMatcher nameMatcher,
61+
ArgumentMatcher argumentMatcher
62+
) {
63+
return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream())
64+
.filter(method -> method.getModifiers().contains(Modifier.STATIC))
65+
.filter(method -> nameMatcher.test(method.getSimpleName().toString()))
66+
.filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType())))
67+
.filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList()))
68+
.findFirst()
69+
.orElseThrow(() -> {
70+
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
71+
var signatures = nameMatcher.names.stream()
72+
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
73+
.collect(joining(" or "));
74+
return new IllegalArgumentException(message + signatures);
75+
});
76+
}
77+
78+
static NameMatcher requireName(String... names) {
79+
return new NameMatcher(Set.of(names));
80+
}
81+
82+
static TypeMatcher requireVoidType() {
83+
return new TypeMatcher(type -> Objects.equals(TypeName.VOID, type), "void");
84+
}
85+
86+
static TypeMatcher requireAnyType(String description) {
87+
return new TypeMatcher(type -> true, description);
88+
}
89+
90+
static TypeMatcher requirePrimitiveOrImplements(Elements elements, TypeName requiredInterface) {
91+
return new TypeMatcher(
92+
type -> type.isPrimitive() || isImplementing(elements, type, requiredInterface),
93+
"[boolean|int|long|float|double|" + requiredInterface + "]"
94+
);
95+
}
96+
97+
static TypeMatcher requireType(TypeName requiredType) {
98+
return new TypeMatcher(type -> Objects.equals(requiredType, type), requiredType.toString());
99+
}
100+
101+
static ArgumentMatcher requireAnyArgs(String description) {
102+
return new ArgumentMatcher(args -> true, description);
103+
}
104+
105+
static ArgumentMatcher requireArgs(TypeMatcher... argTypes) {
106+
return new ArgumentMatcher(
107+
args -> args.size() == argTypes.length && IntStream.range(0, argTypes.length).allMatch(i -> argTypes[i].test(args.get(i))),
108+
Stream.of(argTypes).map(TypeMatcher::toString).collect(joining(", "))
109+
);
110+
}
111+
112+
record NameMatcher(Set<String> names) implements Predicate<String> {
113+
@Override
114+
public boolean test(String name) {
115+
return names.contains(name);
59116
}
60-
return result;
61117
}
62118

63-
static ExecutableElement findMethod(TypeElement declarationType, String name) {
64-
return findMethod(new String[] { name }, e -> true, declarationType, superClassOf(declarationType));
119+
record TypeMatcher(Predicate<TypeName> matcher, String description) implements Predicate<TypeName> {
120+
@Override
121+
public boolean test(TypeName typeName) {
122+
return matcher.test(typeName);
123+
}
124+
125+
@Override
126+
public String toString() {
127+
return description;
128+
}
65129
}
66130

67-
private static TypeElement superClassOf(TypeElement declarationType) {
68-
TypeMirror superclass = declarationType.getSuperclass();
69-
if (superclass instanceof DeclaredType declaredType) {
70-
Element superclassElement = declaredType.asElement();
71-
if (superclassElement instanceof TypeElement) {
72-
return (TypeElement) superclassElement;
73-
}
131+
record ArgumentMatcher(Predicate<List<TypeName>> matcher, String description) implements Predicate<List<TypeName>> {
132+
@Override
133+
public boolean test(List<TypeName> typeName) {
134+
return matcher.test(typeName);
135+
}
136+
137+
@Override
138+
public String toString() {
139+
return description;
140+
}
141+
}
142+
143+
private static boolean isImplementing(Elements elements, TypeName type, TypeName requiredInterface) {
144+
return allInterfacesOf(elements, type).anyMatch(
145+
anInterface -> Objects.equals(anInterface.toString(), requiredInterface.toString())
146+
);
147+
}
148+
149+
private static Stream<TypeName> allInterfacesOf(Elements elements, TypeName type) {
150+
var typeElement = elements.getTypeElement(type.toString());
151+
var superType = Stream.of(typeElement.getSuperclass()).filter(sType -> sType.getKind() != TypeKind.NONE).map(TypeName::get);
152+
var interfaces = typeElement.getInterfaces().stream().map(TypeName::get);
153+
return Stream.concat(
154+
superType.flatMap(sType -> allInterfacesOf(elements, sType)),
155+
interfaces.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface)))
156+
);
157+
}
158+
159+
private static Stream<TypeElement> typeAndSuperType(TypeElement declarationType) {
160+
if (declarationType.getSuperclass() instanceof DeclaredType declaredType
161+
&& declaredType.asElement() instanceof TypeElement superType) {
162+
return Stream.of(declarationType, superType);
163+
} else {
164+
return Stream.of(declarationType);
74165
}
75-
return null;
76166
}
77167

78168
static ExecutableElement findMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
@@ -95,16 +185,6 @@ static ExecutableElement findMethod(String[] names, Predicate<ExecutableElement>
95185
return null;
96186
}
97187

98-
/**
99-
* Returns the arguments of a method after applying a filter.
100-
*/
101-
static VariableElement[] findMethodArguments(ExecutableElement method, Predicate<VariableElement> filter) {
102-
if (method.getParameters().isEmpty()) {
103-
return new VariableElement[0];
104-
}
105-
return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new);
106-
}
107-
108188
/**
109189
* Returns the name of the method used to add {@code valueType} instances
110190
* to vector or block builders.

0 commit comments

Comments
 (0)