Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
import static org.elasticsearch.compute.gen.Methods.requireArgs;
Expand Down Expand Up @@ -76,6 +77,7 @@ public class AggregatorImplementer {
private final List<TypeMirror> warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
private final ExecutableElement first;
private final List<Parameter> createParameters;
private final ClassName implementation;
private final List<IntermediateStateDesc> intermediateState;
Expand Down Expand Up @@ -114,6 +116,18 @@ public AggregatorImplementer(
.filter(f -> false == f.type().equals(BIG_ARRAYS) && false == f.type().equals(DRIVER_CONTEXT))
.toList();

this.first = aggState.declaredType.isPrimitive()
? null
: optionalStaticMethod(
declarationType,
requireVoidType(),
requireName("first"),
requireArgs(combine.getParameters().stream().map(p -> requireType(TypeName.get(p.asType()))).toArray(TypeMatcher[]::new))
);
if (this.aggState.hasSeen == false && this.first != null) {
throw new IllegalArgumentException("[first] method not supported without [seen] on agg state");
}

this.implementation = ClassName.get(
elements.getPackageOf(declarationType).toString(),
(declarationType.getSimpleName() + "AggregatorFunction").replace("AggregatorAggregator", "Aggregator")
Expand Down Expand Up @@ -339,10 +353,17 @@ private MethodSpec addRawVector(boolean masked) {
builder.addComment("This type does not support vectors because all values are multi-valued");
return builder.build();
}

if (first != null) {
builder.addComment("Find the first value up front in the Vector path which is more complex but should be faster");
builder.addStatement("int valuesPosition = 0");
addRawVectorWithFirst(builder, true, masked);
addRawVectorWithFirst(builder, false, masked);
return builder.build();
}
if (aggState.hasSeen()) {
builder.addStatement("state.seen(true)");
}

builder.beginControlFlow(
"for (int valuesPosition = 0; valuesPosition < $L.getPositionCount(); valuesPosition++)",
aggParams.getFirst().vectorName()
Expand All @@ -354,13 +375,39 @@ private MethodSpec addRawVector(boolean masked) {
for (AggregationParameter p : aggParams) {
p.read(builder, true);
}
combineRawInput(builder);

combineRawInput(builder, false);
}
builder.endControlFlow();
return builder.build();
}

private void addRawVectorWithFirst(MethodSpec.Builder builder, boolean firstPass, boolean masked) {
builder.beginControlFlow(
firstPass
? "while (state.seen() == false && valuesPosition < $L.getPositionCount())"
: "while (valuesPosition < $L.getPositionCount())",
aggParams.getFirst().vectorName()
);
{
if (masked) {
builder.beginControlFlow("if (mask.getBoolean(valuesPosition) == false)");
builder.addStatement("valuesPosition++");
builder.addStatement("continue");
builder.endControlFlow();
}
for (AggregationParameter p : aggParams) {
p.read(builder, true);
}
combineRawInput(builder, firstPass);
builder.addStatement("valuesPosition++");
if (firstPass) {
builder.addStatement("state.seen(true)");
builder.addStatement("break");
}
}
builder.endControlFlow();
}

private MethodSpec addRawBlock(boolean masked) {
MethodSpec.Builder builder = initAddRaw(false, masked);

Expand All @@ -374,9 +421,6 @@ private MethodSpec addRawBlock(boolean masked) {
builder.addStatement("continue");
builder.endControlFlow();
}
if (aggState.hasSeen()) {
builder.addStatement("state.seen(true)");
}

if (aggParams.getFirst().isArray()) {
if (aggParams.size() > 1) {
Expand Down Expand Up @@ -412,7 +456,24 @@ private MethodSpec addRawBlock(boolean masked) {
);
p.read(builder, false);
}
combineRawInput(builder);
if (first != null) {
builder.addComment("Check seen in every iteration to save on complexity in the Block path");
builder.beginControlFlow("if (state.seen())");
{
combineRawInput(builder, false);
}
builder.nextControlFlow("else");
{
builder.addStatement("state.seen(true)");
combineRawInput(builder, true);
}
builder.endControlFlow();
} else {
if (aggState.hasSeen()) {
builder.addStatement("state.seen(true)");
}
combineRawInput(builder, false);
}
for (AggregationParameter p : aggParams) {
builder.endControlFlow();
}
Expand Down Expand Up @@ -443,22 +504,26 @@ private MethodSpec.Builder initAddRaw(boolean valuesAreVector, boolean masked) {
return builder;
}

private void combineRawInput(MethodSpec.Builder builder) {
private void combineRawInput(MethodSpec.Builder builder, boolean useFirst) {
TypeName returnType = TypeName.get(combine.getReturnType());
warningsBlock(builder, () -> invokeCombineRawInput(returnType, builder));
warningsBlock(builder, () -> invokeCombineRawInput(returnType, builder, useFirst));
}

private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder builder) {
private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder builder, boolean useFirst) {
StringBuilder pattern = new StringBuilder();
List<Object> params = new ArrayList<>();
if (returnType.isPrimitive()) {
if (useFirst) {
throw new IllegalArgumentException("[first] not supported with primitive");
}
pattern.append("state.$TValue($T.combine(state.$TValue()");
params.add(returnType);
params.add(declarationType);
params.add(returnType);
} else if (returnType == TypeName.VOID) {
pattern.append("$T.combine(state");
pattern.append("$T.$L(state");
params.add(declarationType);
params.add(useFirst ? first.getSimpleName() : combine.getSimpleName());
} else {
throw new IllegalArgumentException("combine must return void or a primitive");
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading