Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.IntermediateState;
import org.elasticsearch.compute.gen.Methods.TypeMatcher;

import java.util.Arrays;
import java.util.List;
Expand All @@ -33,8 +34,14 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Methods.findMethod;
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
import static org.elasticsearch.compute.gen.Methods.requireArgs;
import static org.elasticsearch.compute.gen.Methods.requireName;
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireType;
import static org.elasticsearch.compute.gen.Methods.requireVoidType;
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
Expand Down Expand Up @@ -66,8 +73,6 @@ public class AggregatorImplementer {
private final List<TypeMirror> warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
private final ExecutableElement combineIntermediate;
private final ExecutableElement evaluateFinal;
private final ClassName implementation;
private final List<IntermediateStateDesc> intermediateState;
private final List<Parameter> createParameters;
Expand All @@ -84,21 +89,24 @@ public AggregatorImplementer(
this.declarationType = declarationType;
this.warnExceptions = warnExceptions;

this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true);
this.init = requireStaticMethod(
declarationType,
// This should be more restrictive and require org.elasticsearch.compute.aggregation.AggregatorState
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is enforced in 57d34e4 please let me know if that should be merged as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That commit looks good to me. It exists already right? So enforcing it looks fine

requirePrimitiveOrImplements(elements, Types.RELEASABLE),
requireName("init", "initSingle"),
requireAnyArgs("<arbitrary init arguments>")
);
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, false);

this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
if (e.getParameters().size() == 0) {
return false;
}
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
return Objects.equals(firstParamType.toString(), aggState.declaredType().toString());
});
this.combine = requireStaticMethod(
declarationType,
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
requireName("combine"),
requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
);
// TODO support multiple parameters
this.aggParam = AggregationParameter.create(combine.getParameters().get(1).asType());

this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
this.createParameters = init.getParameters()
.stream()
.map(Parameter::from)
Expand Down Expand Up @@ -447,12 +455,7 @@ private MethodSpec addIntermediateInput() {
interState.assignToVariable(builder, i);
builder.addStatement("assert $L.getPositionCount() == 1", interState.name());
}
if (combineIntermediate != null) {
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
} else if (aggState.declaredType().isPrimitive()) {
if (aggState.declaredType().isPrimitive()) {
if (warnExceptions.isEmpty()) {
assert intermediateState.size() == 2;
assert intermediateState.get(1).name().equals("seen");
Expand Down Expand Up @@ -485,7 +488,21 @@ private MethodSpec addIntermediateInput() {
});
builder.endControlFlow();
} else {
throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate");
requireStaticMethod(
declarationType,
requireVoidType(),
requireName("combineIntermediate"),
requireArgs(
Stream.concat(
Stream.of(aggState.declaredType()), // aggState
intermediateState.stream().map(IntermediateStateDesc::combineArgType) // intermediate state
).map(Methods::requireType).toArray(TypeMatcher[]::new)
)
);
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
}
return builder.build();
}
Expand Down Expand Up @@ -524,7 +541,7 @@ private MethodSpec evaluateFinal() {
builder.addStatement("return");
builder.endControlFlow();
}
if (evaluateFinal == null) {
if (aggState.declaredType().isPrimitive()) {
builder.addStatement(switch (aggState.declaredType().toString()) {
case "boolean" -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)";
case "int" -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)";
Expand All @@ -534,6 +551,12 @@ private MethodSpec evaluateFinal() {
default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]");
});
} else {
requireStaticMethod(
declarationType,
requireType(BLOCK),
requireName("evaluateFinal"),
requireArgs(requireType(aggState.declaredType()), requireType(DRIVER_CONTEXT))
);
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType);
}
return builder.build();
Expand Down Expand Up @@ -593,6 +616,11 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
}
}

public TypeName combineArgType() {
var type = Types.fromString(elementType);
return block ? blockType(type) : type;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
Expand All @@ -34,10 +35,17 @@

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
import static org.elasticsearch.compute.gen.Methods.findMethod;
import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
import static org.elasticsearch.compute.gen.Methods.requireArgs;
import static org.elasticsearch.compute.gen.Methods.requireName;
import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireType;
import static org.elasticsearch.compute.gen.Methods.requireVoidType;
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
Expand Down Expand Up @@ -71,9 +79,6 @@ public class GroupingAggregatorImplementer {
private final List<TypeMirror> warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
private final ExecutableElement combineStates;
private final ExecutableElement evaluateFinal;
private final ExecutableElement combineIntermediate;
private final List<Parameter> createParameters;
private final ClassName implementation;
private final List<AggregatorImplementer.IntermediateStateDesc> intermediateState;
Expand All @@ -92,22 +97,23 @@ public GroupingAggregatorImplementer(
this.declarationType = declarationType;
this.warnExceptions = warnExceptions;

this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true);
this.init = requireStaticMethod(
declarationType,
// This should be more restrictive and require org.elasticsearch.compute.aggregation.GroupingAggregatorState
requirePrimitiveOrImplements(elements, Types.RELEASABLE),
requireName("init", "initGrouping"),
requireAnyArgs("<arbitrary init arguments>")
);
this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true);

this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
if (e.getParameters().size() == 0) {
return false;
}
TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
return Objects.equals(firstParamType.toString(), aggState.declaredType().toString());
});
// TODO support multiple parameters
this.combine = requireStaticMethod(
declarationType,
aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
requireName("combine"),
combineArgs(aggState, includeTimestampVector)
);
this.aggParam = AggregationParameter.create(combine.getParameters().get(combine.getParameters().size() - 1).asType());

this.combineStates = findMethod(declarationType, "combineStates");
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
this.createParameters = init.getParameters()
.stream()
.map(Parameter::from)
Expand All @@ -125,6 +131,25 @@ public GroupingAggregatorImplementer(
this.includeTimestampVector = includeTimestampVector;
}

private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
if (aggState.declaredType().isPrimitive()) {
return requireArgs(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"));
} else if (includeTimestampVector) {
return requireArgs(
requireType(aggState.declaredType()),
requireType(TypeName.INT),
requireType(TypeName.LONG), // @timestamp
requireAnyType("<aggregation input column type>")
);
} else {
return requireArgs(
requireType(aggState.declaredType()),
requireType(TypeName.INT),
requireAnyType("<aggregation input column type>")
);
}
}

public ClassName implementation() {
return implementation;
}
Expand Down Expand Up @@ -557,31 +582,33 @@ private MethodSpec addIntermediateInput() {
});
builder.endControlFlow();
} else {
builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType);
var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
requireStaticMethod(
declarationType,
requireVoidType(),
requireName("combineIntermediate"),
requireArgs(
Stream.of(
Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
)
);

builder.addStatement(
"$T.combineIntermediate(state, groupId, "
+ intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
+ (stateHasBlock ? ", groupPosition + positionOffset" : "")
+ ")",
declarationType
);
}
builder.endControlFlow();
}
return builder.build();
}

String intermediateStateRowAccess() {
String rowAccess = intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "));
if (intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block)) {
rowAccess += ", groupPosition + positionOffset";
}
return rowAccess;
}

private void combineStates(MethodSpec.Builder builder) {
if (combineStates == null) {
builder.beginControlFlow("if (inState.hasValue(position))");
builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
builder.endControlFlow();
return;
}
builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
}

private MethodSpec addIntermediateRowInput() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput");
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
Expand All @@ -593,7 +620,24 @@ private MethodSpec addIntermediateRowInput() {
builder.endControlFlow();
builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation);
builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
combineStates(builder);
if (aggState.declaredType().isPrimitive()) {
builder.beginControlFlow("if (inState.hasValue(position))");
builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
builder.endControlFlow();
} else {
requireStaticMethod(
declarationType,
requireVoidType(),
requireName("combineStates"),
requireArgs(
requireType(aggState.declaredType()),
requireType(TypeName.INT),
requireType(aggState.declaredType()),
requireType(TypeName.INT)
)
);
builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
}
return builder.build();
}

Expand All @@ -617,9 +661,15 @@ private MethodSpec evaluateFinal() {
.addParameter(INT_VECTOR, "selected")
.addParameter(DRIVER_CONTEXT, "driverContext");

if (evaluateFinal == null) {
if (aggState.declaredType().isPrimitive()) {
builder.addStatement("blocks[offset] = state.toValuesBlock(selected, driverContext)");
} else {
requireStaticMethod(
declarationType,
requireType(BLOCK),
requireName("evaluateFinal"),
requireArgs(requireType(aggState.declaredType()), requireType(INT_VECTOR), requireType(DRIVER_CONTEXT))
);
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)", declarationType);
}
return builder.build();
Expand Down
Loading