diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 96ecb47709f39..a07796882d46e 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -17,6 +17,7 @@ import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.gen.Methods.TypeMatcher; import java.util.Arrays; import java.util.List; @@ -33,8 +34,14 @@ import javax.lang.model.util.Elements; import static java.util.stream.Collectors.joining; -import static org.elasticsearch.compute.gen.Methods.findMethod; -import static org.elasticsearch.compute.gen.Methods.findRequiredMethod; +import static org.elasticsearch.compute.gen.Methods.requireAnyArgs; +import static org.elasticsearch.compute.gen.Methods.requireAnyType; +import static org.elasticsearch.compute.gen.Methods.requireArgs; +import static org.elasticsearch.compute.gen.Methods.requireName; +import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements; +import static org.elasticsearch.compute.gen.Methods.requireStaticMethod; +import static org.elasticsearch.compute.gen.Methods.requireType; +import static org.elasticsearch.compute.gen.Methods.requireVoidType; import static org.elasticsearch.compute.gen.Methods.vectorAccessorName; import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION; import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS; @@ -66,8 +73,6 @@ public class AggregatorImplementer { private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; - private final ExecutableElement combineIntermediate; - private final ExecutableElement evaluateFinal; private final ClassName implementation; private final List intermediateState; private final List createParameters; @@ -84,21 +89,24 @@ public AggregatorImplementer( this.declarationType = declarationType; this.warnExceptions = warnExceptions; - this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true); + this.init = requireStaticMethod( + declarationType, + // This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState + requirePrimitiveOrImplements(elements, Types.RELEASABLE), + requireName("init", "initSingle"), + requireAnyArgs("") + ); this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, false); - this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> { - if (e.getParameters().size() == 0) { - return false; - } - TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType()); - return Objects.equals(firstParamType.toString(), aggState.declaredType().toString()); - }); + this.combine = requireStaticMethod( + declarationType, + aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(), + requireName("combine"), + requireArgs(requireType(aggState.declaredType()), requireAnyType("")) + ); // TODO support multiple parameters this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType()); - this.combineIntermediate = findMethod(declarationType, "combineIntermediate"); - this.evaluateFinal = findMethod(declarationType, "evaluateFinal"); this.createParameters = init.getParameters() .stream() .map(Parameter::from) @@ -447,12 +455,7 @@ private MethodSpec addIntermediateInput() { interState.assignToVariable(builder, i); builder.addStatement("assert $L.getPositionCount() == 1", interState.name()); } - if (combineIntermediate != null) { - if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { - builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); - } - builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType); - } else if (aggState.declaredType().isPrimitive()) { + if (aggState.declaredType().isPrimitive()) { if (warnExceptions.isEmpty()) { assert intermediateState.size() == 2; assert intermediateState.get(1).name().equals("seen"); @@ -485,7 +488,21 @@ private MethodSpec addIntermediateInput() { }); builder.endControlFlow(); } else { - throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate"); + requireStaticMethod( + declarationType, + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.concat( + Stream.of(aggState.declaredType()), // aggState + intermediateState.stream().map(IntermediateStateDesc::combineArgType) // intermediate state + ).map(Methods::requireType).toArray(TypeMatcher[]::new) + ) + ); + if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { + builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); + } + builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType); } return builder.build(); } @@ -524,7 +541,7 @@ private MethodSpec evaluateFinal() { builder.addStatement("return"); builder.endControlFlow(); } - if (evaluateFinal == null) { + if (aggState.declaredType().isPrimitive()) { builder.addStatement(switch (aggState.declaredType().toString()) { case "boolean" -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)"; case "int" -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)"; @@ -534,6 +551,12 @@ private MethodSpec evaluateFinal() { default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]"); }); } else { + requireStaticMethod( + declarationType, + requireType(BLOCK), + requireName("evaluateFinal"), + requireArgs(requireType(aggState.declaredType()), requireType(DRIVER_CONTEXT)) + ); builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType); } return builder.build(); @@ -593,6 +616,11 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) { builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast"); } } + + public TypeName combineArgType() { + var type = Types.fromString(elementType); + return block ? blockType(type) : type; + } } /** diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 7f2a8f366caea..180589fc64373 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -22,9 +22,10 @@ import java.util.Arrays; import java.util.List; -import java.util.Objects; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; @@ -34,10 +35,17 @@ import static java.util.stream.Collectors.joining; import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize; -import static org.elasticsearch.compute.gen.Methods.findMethod; -import static org.elasticsearch.compute.gen.Methods.findRequiredMethod; +import static org.elasticsearch.compute.gen.Methods.requireAnyArgs; +import static org.elasticsearch.compute.gen.Methods.requireAnyType; +import static org.elasticsearch.compute.gen.Methods.requireArgs; +import static org.elasticsearch.compute.gen.Methods.requireName; +import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements; +import static org.elasticsearch.compute.gen.Methods.requireStaticMethod; +import static org.elasticsearch.compute.gen.Methods.requireType; +import static org.elasticsearch.compute.gen.Methods.requireVoidType; import static org.elasticsearch.compute.gen.Methods.vectorAccessorName; import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS; +import static org.elasticsearch.compute.gen.Types.BLOCK; import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY; import static org.elasticsearch.compute.gen.Types.BYTES_REF; import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; @@ -71,9 +79,6 @@ public class GroupingAggregatorImplementer { private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; - private final ExecutableElement combineStates; - private final ExecutableElement evaluateFinal; - private final ExecutableElement combineIntermediate; private final List createParameters; private final ClassName implementation; private final List intermediateState; @@ -92,22 +97,23 @@ public GroupingAggregatorImplementer( this.declarationType = declarationType; this.warnExceptions = warnExceptions; - this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true); + this.init = requireStaticMethod( + declarationType, + // This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState + requirePrimitiveOrImplements(elements, Types.RELEASABLE), + requireName("init", "initGrouping"), + requireAnyArgs("") + ); this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true); - this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> { - if (e.getParameters().size() == 0) { - return false; - } - TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType()); - return Objects.equals(firstParamType.toString(), aggState.declaredType().toString()); - }); - // TODO support multiple parameters + this.combine = requireStaticMethod( + declarationType, + aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(), + requireName("combine"), + combineArgs(aggState, includeTimestampVector) + ); this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType()); - this.combineStates = findMethod(declarationType, "combineStates"); - this.combineIntermediate = findMethod(declarationType, "combineIntermediate"); - this.evaluateFinal = findMethod(declarationType, "evaluateFinal"); this.createParameters = init.getParameters() .stream() .map(Parameter::from) @@ -125,6 +131,25 @@ public GroupingAggregatorImplementer( this.includeTimestampVector = includeTimestampVector; } + private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) { + if (aggState.declaredType().isPrimitive()) { + return requireArgs(requireType(aggState.declaredType()), requireAnyType("")); + } else if (includeTimestampVector) { + return requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.INT), + requireType(TypeName.LONG), // @timestamp + requireAnyType("") + ); + } else { + return requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.INT), + requireAnyType("") + ); + } + } + public ClassName implementation() { return implementation; } @@ -557,31 +582,33 @@ private MethodSpec addIntermediateInput() { }); builder.endControlFlow(); } else { - builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType); + var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); + requireStaticMethod( + declarationType, + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.of( + Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), + Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ) + ); + + builder.addStatement( + "$T.combineIntermediate(state, groupId, " + + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) + + (stateHasBlock ? ", groupPosition + positionOffset" : "") + + ")", + declarationType + ); } builder.endControlFlow(); } return builder.build(); } - String intermediateStateRowAccess() { - String rowAccess = intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")); - if (intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block)) { - rowAccess += ", groupPosition + positionOffset"; - } - return rowAccess; - } - - private void combineStates(MethodSpec.Builder builder) { - if (combineStates == null) { - builder.beginControlFlow("if (inState.hasValue(position))"); - builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType); - builder.endControlFlow(); - return; - } - builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType); - } - private MethodSpec addIntermediateRowInput() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); @@ -593,7 +620,24 @@ private MethodSpec addIntermediateRowInput() { builder.endControlFlow(); builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation); builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS); - combineStates(builder); + if (aggState.declaredType().isPrimitive()) { + builder.beginControlFlow("if (inState.hasValue(position))"); + builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType); + builder.endControlFlow(); + } else { + requireStaticMethod( + declarationType, + requireVoidType(), + requireName("combineStates"), + requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.INT), + requireType(aggState.declaredType()), + requireType(TypeName.INT) + ) + ); + builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType); + } return builder.build(); } @@ -617,9 +661,15 @@ private MethodSpec evaluateFinal() { .addParameter(INT_VECTOR, "selected") .addParameter(DRIVER_CONTEXT, "driverContext"); - if (evaluateFinal == null) { + if (aggState.declaredType().isPrimitive()) { builder.addStatement("blocks[offset] = state.toValuesBlock(selected, driverContext)"); } else { + requireStaticMethod( + declarationType, + requireType(BLOCK), + requireName("evaluateFinal"), + requireArgs(requireType(aggState.declaredType()), requireType(INT_VECTOR), requireType(DRIVER_CONTEXT)) + ); builder.addStatement("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)", declarationType); } return builder.build(); diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java index 6f98f1f797ab0..f2fa7b8084448 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java @@ -9,18 +9,22 @@ import com.squareup.javapoet.TypeName; -import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Set; import java.util.function.Predicate; +import java.util.stream.IntStream; +import java.util.stream.Stream; -import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; -import javax.lang.model.element.VariableElement; import javax.lang.model.type.DeclaredType; -import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeKind; import javax.lang.model.util.ElementFilter; +import javax.lang.model.util.Elements; +import static java.util.stream.Collectors.joining; import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK; import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER; import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR; @@ -49,30 +53,116 @@ * Finds declared methods for the code generator. */ public class Methods { - static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate filter) { - ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType)); - if (result == null) { - if (names.length == 1) { - throw new IllegalArgumentException(declarationType + "#" + names[0] + " is required"); - } - throw new IllegalArgumentException("one of " + declarationType + "#" + Arrays.toString(names) + " is required"); + + static ExecutableElement requireStaticMethod( + TypeElement declarationType, + TypeMatcher returnTypeMatcher, + NameMatcher nameMatcher, + ArgumentMatcher argumentMatcher + ) { + return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream()) + .filter(method -> method.getModifiers().contains(Modifier.STATIC)) + .filter(method -> nameMatcher.test(method.getSimpleName().toString())) + .filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType()))) + .filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList())) + .findFirst() + .orElseThrow(() -> { + var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: "; + var signatures = nameMatcher.names.stream() + .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")") + .collect(joining(" or ")); + return new IllegalArgumentException(message + signatures); + }); + } + + static NameMatcher requireName(String... names) { + return new NameMatcher(Set.of(names)); + } + + static TypeMatcher requireVoidType() { + return new TypeMatcher(type -> Objects.equals(TypeName.VOID, type), "void"); + } + + static TypeMatcher requireAnyType(String description) { + return new TypeMatcher(type -> true, description); + } + + static TypeMatcher requirePrimitiveOrImplements(Elements elements, TypeName requiredInterface) { + return new TypeMatcher( + type -> type.isPrimitive() || isImplementing(elements, type, requiredInterface), + "[boolean|int|long|float|double|" + requiredInterface + "]" + ); + } + + static TypeMatcher requireType(TypeName requiredType) { + return new TypeMatcher(type -> Objects.equals(requiredType, type), requiredType.toString()); + } + + static ArgumentMatcher requireAnyArgs(String description) { + return new ArgumentMatcher(args -> true, description); + } + + static ArgumentMatcher requireArgs(TypeMatcher... argTypes) { + return new ArgumentMatcher( + args -> args.size() == argTypes.length && IntStream.range(0, argTypes.length).allMatch(i -> argTypes[i].test(args.get(i))), + Stream.of(argTypes).map(TypeMatcher::toString).collect(joining(", ")) + ); + } + + record NameMatcher(Set names) implements Predicate { + @Override + public boolean test(String name) { + return names.contains(name); } - return result; } - static ExecutableElement findMethod(TypeElement declarationType, String name) { - return findMethod(new String[] { name }, e -> true, declarationType, superClassOf(declarationType)); + record TypeMatcher(Predicate matcher, String description) implements Predicate { + @Override + public boolean test(TypeName typeName) { + return matcher.test(typeName); + } + + @Override + public String toString() { + return description; + } } - private static TypeElement superClassOf(TypeElement declarationType) { - TypeMirror superclass = declarationType.getSuperclass(); - if (superclass instanceof DeclaredType declaredType) { - Element superclassElement = declaredType.asElement(); - if (superclassElement instanceof TypeElement) { - return (TypeElement) superclassElement; - } + record ArgumentMatcher(Predicate> matcher, String description) implements Predicate> { + @Override + public boolean test(List typeName) { + return matcher.test(typeName); + } + + @Override + public String toString() { + return description; + } + } + + private static boolean isImplementing(Elements elements, TypeName type, TypeName requiredInterface) { + return allInterfacesOf(elements, type).anyMatch( + anInterface -> Objects.equals(anInterface.toString(), requiredInterface.toString()) + ); + } + + private static Stream allInterfacesOf(Elements elements, TypeName type) { + var typeElement = elements.getTypeElement(type.toString()); + var superType = Stream.of(typeElement.getSuperclass()).filter(sType -> sType.getKind() != TypeKind.NONE).map(TypeName::get); + var interfaces = typeElement.getInterfaces().stream().map(TypeName::get); + return Stream.concat( + superType.flatMap(sType -> allInterfacesOf(elements, sType)), + interfaces.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface))) + ); + } + + private static Stream typeAndSuperType(TypeElement declarationType) { + if (declarationType.getSuperclass() instanceof DeclaredType declaredType + && declaredType.asElement() instanceof TypeElement superType) { + return Stream.of(declarationType, superType); + } else { + return Stream.of(declarationType); } - return null; } static ExecutableElement findMethod(TypeElement declarationType, String[] names, Predicate filter) { @@ -95,16 +185,6 @@ static ExecutableElement findMethod(String[] names, Predicate return null; } - /** - * Returns the arguments of a method after applying a filter. - */ - static VariableElement[] findMethodArguments(ExecutableElement method, Predicate filter) { - if (method.getParameters().isEmpty()) { - return new VariableElement[0]; - } - return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new); - } - /** * Returns the name of the method used to add {@code valueType} instances * to vector or block builders.