Skip to content

Commit be957db

Browse files
committed
handle methods in grouping aggregator
1 parent 54634a1 commit be957db

File tree

2 files changed

+50
-54
lines changed

2 files changed

+50
-54
lines changed

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222

2323
import java.util.Arrays;
2424
import java.util.List;
25-
import java.util.Objects;
2625
import java.util.function.Consumer;
26+
import java.util.function.Function;
2727
import java.util.stream.Collectors;
28+
import java.util.stream.Stream;
2829

2930
import javax.lang.model.element.ExecutableElement;
3031
import javax.lang.model.element.Modifier;
@@ -34,9 +35,8 @@
3435

3536
import static java.util.stream.Collectors.joining;
3637
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
37-
import static org.elasticsearch.compute.gen.Methods.findMethod;
38-
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
3938
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
39+
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
4040
import static org.elasticsearch.compute.gen.Methods.requireArgs;
4141
import static org.elasticsearch.compute.gen.Methods.requireName;
4242
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
@@ -79,7 +79,6 @@ public class GroupingAggregatorImplementer {
7979
private final List<TypeMirror> warnExceptions;
8080
private final ExecutableElement init;
8181
private final ExecutableElement combine;
82-
private final ExecutableElement combineStates;
8382
private final List<Parameter> createParameters;
8483
private final ClassName implementation;
8584
private final List<AggregatorImplementer.IntermediateStateDesc> intermediateState;
@@ -100,25 +99,21 @@ public GroupingAggregatorImplementer(
10099

101100
this.init = requireStaticMethod(
102101
declarationType,
103-
// This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState
102+
// This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState
104103
requirePrimitiveOrImplements(elements, Types.RELEASABLE),
105104
requireName("init", "initGrouping"),
106105
requireAnyArgs("<arbitrary init arguments>")
107106
);
108107
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true);
109108

110-
// TODO optional timestamp
111-
this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
112-
if (e.getParameters().size() == 0) {
113-
return false;
114-
}
115-
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
116-
return Objects.equals(firstParamType.toString(), aggState.declaredType().toString());
117-
});
118-
// TODO support multiple parameters
109+
this.combine = requireStaticMethod(
110+
declarationType,
111+
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
112+
requireName("combine"),
113+
combineArgs(aggState, includeTimestampVector)
114+
);
119115
this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType());
120116

121-
this.combineStates = findMethod(declarationType, "combineStates");
122117
this.createParameters = init.getParameters()
123118
.stream()
124119
.map(Parameter::from)
@@ -136,6 +131,25 @@ public GroupingAggregatorImplementer(
136131
this.includeTimestampVector = includeTimestampVector;
137132
}
138133

134+
private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
135+
if (aggState.declaredType().isPrimitive()) {
136+
return requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"));
137+
} else if (includeTimestampVector) {
138+
return requireArgs(
139+
requireType(aggState.declaredType()),
140+
requireType(TypeName.INT),
141+
requireType(TypeName.LONG), // @timestamp
142+
requireAnyType("<aggregation input column type>")
143+
);
144+
} else {
145+
return requireArgs(
146+
requireType(aggState.declaredType()),
147+
requireType(TypeName.INT),
148+
requireAnyType("<aggregation input column type>")
149+
);
150+
}
151+
}
152+
139153
public ClassName implementation() {
140154
return implementation;
141155
}
@@ -568,22 +582,33 @@ private MethodSpec addIntermediateInput() {
568582
});
569583
builder.endControlFlow();
570584
} else {
571-
// TODO combineIntermediate with optional block parameter
572-
builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType);
585+
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
586+
requireStaticMethod(
587+
declarationType,
588+
requireVoidType(),
589+
requireName("combineIntermediate"),
590+
requireArgs(
591+
Stream.of(
592+
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
593+
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
594+
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
595+
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
596+
)
597+
);
598+
599+
builder.addStatement(
600+
"$T.combineIntermediate(state, groupId, "
601+
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
602+
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
603+
+ ")",
604+
declarationType
605+
);
573606
}
574607
builder.endControlFlow();
575608
}
576609
return builder.build();
577610
}
578611

579-
String intermediateStateRowAccess() {
580-
String rowAccess = intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "));
581-
if (intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block)) {
582-
rowAccess += ", groupPosition + positionOffset";
583-
}
584-
return rowAccess;
585-
}
586-
587612
private MethodSpec addIntermediateRowInput() {
588613
MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput");
589614
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,18 @@
99

1010
import com.squareup.javapoet.TypeName;
1111

12-
import java.util.Arrays;
1312
import java.util.List;
1413
import java.util.Objects;
1514
import java.util.Set;
1615
import java.util.function.Predicate;
1716
import java.util.stream.IntStream;
1817
import java.util.stream.Stream;
1918

20-
import javax.lang.model.element.Element;
2119
import javax.lang.model.element.ExecutableElement;
2220
import javax.lang.model.element.Modifier;
2321
import javax.lang.model.element.TypeElement;
2422
import javax.lang.model.type.DeclaredType;
2523
import javax.lang.model.type.TypeKind;
26-
import javax.lang.model.type.TypeMirror;
2724
import javax.lang.model.util.ElementFilter;
2825
import javax.lang.model.util.Elements;
2926

@@ -168,32 +165,6 @@ private static Stream<TypeElement> typeAndSuperType(TypeElement declarationType)
168165
}
169166
}
170167

171-
static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
172-
ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
173-
if (result == null) {
174-
if (names.length == 1) {
175-
throw new IllegalArgumentException(declarationType + "#" + names[0] + " is required");
176-
}
177-
throw new IllegalArgumentException("one of " + declarationType + "#" + Arrays.toString(names) + " is required");
178-
}
179-
return result;
180-
}
181-
182-
static ExecutableElement findMethod(TypeElement declarationType, String name) {
183-
return findMethod(new String[] { name }, e -> true, declarationType, superClassOf(declarationType));
184-
}
185-
186-
private static TypeElement superClassOf(TypeElement declarationType) {
187-
TypeMirror superclass = declarationType.getSuperclass();
188-
if (superclass instanceof DeclaredType declaredType) {
189-
Element superclassElement = declaredType.asElement();
190-
if (superclassElement instanceof TypeElement) {
191-
return (TypeElement) superclassElement;
192-
}
193-
}
194-
return null;
195-
}
196-
197168
static ExecutableElement findMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
198169
return findMethod(names, filter, declarationType);
199170
}

0 commit comments

Comments
 (0)