Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,20 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Methods.findMethod;
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
import static org.elasticsearch.compute.gen.Methods.requireMethod;
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.BYTES_REF_BLOCK;
import static org.elasticsearch.compute.gen.Types.BYTES_REF_VECTOR;
import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK;
import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.FLOAT_BLOCK;
import static org.elasticsearch.compute.gen.Types.FLOAT_VECTOR;
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
import static org.elasticsearch.compute.gen.Types.INT_BLOCK;
import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.WARNINGS;
import static org.elasticsearch.compute.gen.Types.blockType;
Expand All @@ -78,8 +67,6 @@ public class AggregatorImplementer {
private final List<TypeMirror> warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
private final ExecutableElement combineIntermediate;
private final ExecutableElement evaluateFinal;
private final ClassName implementation;
private final TypeName stateType;
private final boolean stateTypeHasSeen;
Expand Down Expand Up @@ -114,8 +101,6 @@ public AggregatorImplementer(
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
return firstParamType.isPrimitive() || firstParamType.toString().equals(stateType.toString());
});
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
this.createParameters = init.getParameters()
.stream()
.map(Parameter::from)
Expand Down Expand Up @@ -144,7 +129,7 @@ private TypeName choseStateType() {
if (false == initReturn.isPrimitive()) {
return initReturn;
}
String simpleName = firstUpper(initReturn.toString());
String simpleName = capitalize(initReturn.toString());
if (warnExceptions.isEmpty()) {
return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "State");
}
Expand All @@ -157,50 +142,22 @@ static String valueType(ExecutableElement init, ExecutableElement combine) {
return combine.getParameters().get(combine.getParameters().size() - 1).asType().toString();
}
String initReturn = init.getReturnType().toString();
switch (initReturn) {
case "double":
return "double";
case "float":
return "float";
case "long":
return "long";
case "int":
return "int";
case "boolean":
return "boolean";
default:
throw new IllegalArgumentException("unknown primitive type for " + initReturn);
if (Types.isPrimitive(initReturn)) {
return initReturn;
}
throw new IllegalArgumentException("unknown primitive type for " + initReturn);
}

static ClassName valueBlockType(ExecutableElement init, ExecutableElement combine) {
return switch (valueType(init, combine)) {
case "boolean" -> BOOLEAN_BLOCK;
case "double" -> DOUBLE_BLOCK;
case "float" -> FLOAT_BLOCK;
case "long" -> LONG_BLOCK;
case "int", "int[]" -> INT_BLOCK;
case "org.apache.lucene.util.BytesRef" -> BYTES_REF_BLOCK;
default -> throw new IllegalArgumentException("unknown block type for " + valueType(init, combine));
};
return Types.blockType(valueType(init, combine));
}

static ClassName valueVectorType(ExecutableElement init, ExecutableElement combine) {
return switch (valueType(init, combine)) {
case "boolean" -> BOOLEAN_VECTOR;
case "double" -> DOUBLE_VECTOR;
case "float" -> FLOAT_VECTOR;
case "long" -> LONG_VECTOR;
case "int", "int[]" -> INT_VECTOR;
case "org.apache.lucene.util.BytesRef" -> BYTES_REF_VECTOR;
default -> throw new IllegalArgumentException("unknown vector type for " + valueType(init, combine));
};
return Types.vectorType(valueType(init, combine));
}

public static String firstUpper(String s) {
String head = s.toString().substring(0, 1).toUpperCase(Locale.ROOT);
String tail = s.toString().substring(1);
return head + tail;
public static String capitalize(String s) {
return Character.toUpperCase(s.charAt(0)) + s.substring(1);
}

public JavaFile sourceFile() {
Expand Down Expand Up @@ -444,7 +401,7 @@ private MethodSpec addRawBlock(boolean masked) {
String arrayType = valueTypeString();
builder.addStatement("$L[] valuesArray = new $L[end - start]", arrayType, arrayType);
builder.beginControlFlow("for (int i = start; i < end; i++)");
builder.addStatement("valuesArray[i-start] = $L.get$L(i)", "block", firstUpper(arrayType));
builder.addStatement("valuesArray[i-start] = $L.get$L(i)", "block", capitalize(arrayType));
builder.endControlFlow();
combineRawInputForArray(builder, "valuesArray");
} else {
Expand Down Expand Up @@ -479,7 +436,7 @@ private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder
declarationType,
returnType,
blockVariable,
firstUpper(combine.getParameters().get(1).asType().toString())
capitalize(combine.getParameters().get(1).asType().toString())
);
}

Expand All @@ -492,7 +449,7 @@ private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVari
"$T.combine(state, $L.get$L(i))",
declarationType,
blockVariable,
firstUpper(combine.getParameters().get(1).asType().toString())
capitalize(combine.getParameters().get(1).asType().toString())
);
}

Expand Down Expand Up @@ -526,12 +483,7 @@ private MethodSpec addIntermediateInput() {
interState.assignToVariable(builder, i);
builder.addStatement("assert $L.getPositionCount() == 1", interState.name());
}
if (combineIntermediate != null) {
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
} else if (hasPrimitiveState()) {
if (hasPrimitiveState()) {
if (warnExceptions.isEmpty()) {
assert intermediateState.size() == 2;
assert intermediateState.get(1).name().equals("seen");
Expand All @@ -547,7 +499,6 @@ private MethodSpec addIntermediateInput() {
}
builder.nextControlFlow("else if (seen.getBoolean(0))");
}

warningsBlock(builder, () -> {
var state = intermediateState.get(0);
var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))";
Expand All @@ -556,32 +507,39 @@ private MethodSpec addIntermediateInput() {
});
builder.endControlFlow();
} else {
throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate");
requireMethod(
declarationType,
"combineIntermediate",
"void",
Stream.concat(Stream.of(stateType.toString()), intermediateState.stream().map(intermediateStateDesc -> {
var type = Types.fromString(intermediateStateDesc.elementType());
return intermediateStateDesc.block ? blockType(type).toString() : type.toString();
})).toArray(String[]::new)
);
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
builder.addStatement(
"$T.combineIntermediate(state, " + intermediateState.stream().map(desc -> desc.access("0")).collect(joining(", ")) + ")",
declarationType
);
}
return builder.build();
}

String intermediateStateRowAccess() {
return intermediateState.stream().map(desc -> desc.access("0")).collect(joining(", "));
}

private String primitiveStateMethod() {
switch (stateType.toString()) {
case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState":
return "booleanValue";
case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState":
return "intValue";
case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState":
return "longValue";
case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState":
return "doubleValue";
case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState":
return "floatValue";
default:
throw new IllegalArgumentException(
"don't know how to fetch primitive values from " + stateType + ". define combineIntermediate."
);
}
return switch (stateType.toString()) {
case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState" ->
"booleanValue";
case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState" -> "intValue";
case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState" ->
"longValue";
case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState" ->
"doubleValue";
case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState" ->
"floatValue";
default -> throw new IllegalArgumentException("don't know how to fetch primitive values from " + stateType + ".");
};
}

private MethodSpec evaluateIntermediate() {
Expand Down Expand Up @@ -611,9 +569,15 @@ private MethodSpec evaluateFinal() {
builder.addStatement("return");
builder.endControlFlow();
}
if (evaluateFinal == null) {
if (hasPrimitiveState()) {
primitiveStateToResult(builder);
} else {
requireMethod(
declarationType,
"evaluateFinal",
"org.elasticsearch.compute.data.Block",
new String[] { stateType.toString(), "org.elasticsearch.compute.operator.DriverContext" }
);
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType);
}
return builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.AggregatorImplementer.firstUpper;
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
import static org.elasticsearch.compute.gen.AggregatorImplementer.valueBlockType;
import static org.elasticsearch.compute.gen.AggregatorImplementer.valueVectorType;
import static org.elasticsearch.compute.gen.Methods.findMethod;
Expand Down Expand Up @@ -136,7 +136,7 @@ private TypeName choseStateType() {
if (false == initReturn.isPrimitive()) {
return initReturn;
}
String simpleName = firstUpper(initReturn.toString());
String simpleName = capitalize(initReturn.toString());
if (warnExceptions.isEmpty()) {
return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "ArrayState");
}
Expand Down Expand Up @@ -401,7 +401,7 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
String arrayType = valueTypeString();
builder.addStatement("$L[] valuesArray = new $L[valuesEnd - valuesStart]", arrayType, arrayType);
builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
builder.addStatement("valuesArray[v-valuesStart] = $L.get$L(v)", "values", firstUpper(arrayType));
builder.addStatement("valuesArray[v-valuesStart] = $L.get$L(v)", "values", capitalize(arrayType));
builder.endControlFlow();
combineRawInputForArray(builder, "valuesArray");
} else {
Expand Down Expand Up @@ -447,7 +447,7 @@ private void combineRawInputForPrimitive(MethodSpec.Builder builder, String bloc
"state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
declarationType,
blockVariable,
firstUpper(valueTypeName().toString()),
capitalize(valueTypeName().toString()),
offsetVariable
);
}
Expand All @@ -461,13 +461,13 @@ private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVari
"$T.combine(state, groupId, $L.get$L($L))",
declarationType,
blockVariable,
firstUpper(valueTypeName().toString()),
capitalize(valueTypeName().toString()),
offsetVariable
);
}

private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) {
String blockType = firstUpper(valueTypeName().toString());
String blockType = capitalize(valueTypeName().toString());
if (offsetVariable.contains(" + ")) {
builder.addStatement("var valuePosition = $L", offsetVariable);
offsetVariable = "valuePosition";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import com.squareup.javapoet.TypeName;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Stream;

import javax.lang.model.element.Element;
import javax.lang.model.element.ExecutableElement;
Expand All @@ -21,6 +24,7 @@
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.ElementFilter;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
Expand Down Expand Up @@ -49,6 +53,7 @@
* Finds declared methods for the code generator.
*/
public class Methods {

static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
if (result == null) {
Expand Down Expand Up @@ -95,14 +100,31 @@ static ExecutableElement findMethod(String[] names, Predicate<ExecutableElement>
return null;
}

/**
* Returns the arguments of a method after applying a filter.
*/
static VariableElement[] findMethodArguments(ExecutableElement method, Predicate<VariableElement> filter) {
if (method.getParameters().isEmpty()) {
return new VariableElement[0];
static void requireMethod(TypeElement element, String name, String returnType, String... parameterTypes) {
var method = findMethod(new String[] { name }, e -> true, element, superClassOf(element));
if (method == null || isNotSame(method.getReturnType(), returnType) || isNotSame(method.getParameters(), parameterTypes)) {
throw new IllegalArgumentException("Requires method " + signature(element, name, returnType, parameterTypes));
}
}

private static boolean isNotSame(TypeMirror type, String required) {
return Objects.equals(type.toString(), required) == false;
}

private static boolean isNotSame(List<? extends VariableElement> types, String[] required) {
if (types.size() != required.length) {
return true;
}
for (int i = 0; i < types.size(); i++) {
if (isNotSame(types.get(i).asType(), required[i])) {
return true;
}
}
return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new);
return false;
}

private static String signature(TypeElement element, String name, String returnType, String[] parameterTypes) {
return "public static " + returnType + " " + element + "#" + name + Stream.of(parameterTypes).collect(joining(", ", "(", ")"));
}

/**
Expand Down
Loading