Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Methods.getMethod;
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
Expand All @@ -53,7 +54,6 @@
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
Expand All @@ -62,6 +62,8 @@
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.WARNINGS;
import static org.elasticsearch.compute.gen.Types.blockType;
import static org.elasticsearch.compute.gen.Types.fromString;
import static org.elasticsearch.compute.gen.Types.scratchType;
import static org.elasticsearch.compute.gen.Types.vectorType;

/**
Expand All @@ -85,7 +87,7 @@ public class AggregatorImplementer {

private final AggregationState aggState;
private final List<Argument> aggParams;
private final boolean hasOnlyBlockArguments;
private final boolean tryToUseVectors;

public AggregatorImplementer(
Elements elements,
Expand Down Expand Up @@ -119,7 +121,8 @@ public AggregatorImplementer(
return a;
}).filter(a -> a instanceof PositionArgument == false).toList();

this.hasOnlyBlockArguments = this.aggParams.stream().allMatch(a -> a instanceof BlockArgument);
this.tryToUseVectors = aggParams.stream().anyMatch(a -> (a instanceof BlockArgument) == false)
&& aggParams.stream().noneMatch(a -> a instanceof StandardArgument && a.hasVector() == false);

this.createParameters = init.getParameters()
.stream()
Expand Down Expand Up @@ -199,7 +202,7 @@ private TypeSpec type() {
builder.addMethod(addRawInput());
builder.addMethod(addRawInputExploded(true));
builder.addMethod(addRawInputExploded(false));
if (hasOnlyBlockArguments == false) {
if (tryToUseVectors) {
builder.addMethod(addRawVector(false));
builder.addMethod(addRawVector(true));
}
Expand Down Expand Up @@ -340,16 +343,18 @@ private MethodSpec addRawInputExploded(boolean hasMask) {
builder.addStatement("$T $L = page.getBlock(channels.get($L))", a.dataType(true), a.blockName(), i);
}

for (Argument a : aggParams) {
String rawBlock = "addRawBlock("
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ (hasMask ? ", mask" : "")
+ ")";
if (tryToUseVectors) {
for (Argument a : aggParams) {
String rawBlock = "addRawBlock("
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ (hasMask ? ", mask" : "")
+ ")";

a.resolveVectors(builder, rawBlock, "return");
a.resolveVectors(builder, rawBlock, "return");
}
}

builder.addStatement(invokeAddRaw(hasOnlyBlockArguments, hasMask));
builder.addStatement(invokeAddRaw(tryToUseVectors == false, hasMask));
return builder.build();
}

Expand Down Expand Up @@ -499,9 +504,9 @@ private MethodSpec.Builder initAddRaw(boolean blockStyle, boolean masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
for (Argument a : aggParams) {
if (a.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T $L = new $T()", BYTES_REF, a.scratchName(), BYTES_REF);
if (a.scratchType() != null) {
// Add scratch var that will be used for some blocks/vectors, e.g. for bytes_ref
builder.addStatement("$T $L = new $T()", a.scratchType(), a.scratchName(), a.scratchType());
}
}
return builder;
Expand Down Expand Up @@ -610,8 +615,8 @@ private MethodSpec addIntermediateInput() {
).map(Methods::requireType).toArray(TypeMatcher[]::new)
)
);
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
for (IntermediateStateDesc interState : intermediateState) {
interState.addScratchDeclaration(builder);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was buggy prior to my change, but without causing a defect?

If there were multiple BYTES_REF state-members, they would have shared the same scratch, which seems incorrect?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems possible. I imagine we don't have any intermediate states with two strings.

}
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
}
Expand Down Expand Up @@ -706,13 +711,25 @@ public String access(String position) {
if (block) {
return name();
}
String s = name() + "." + vectorAccessorName(elementType()) + "(" + position;
if (elementType().equals("BYTES_REF")) {
s += ", scratch";
String s = name() + ".";
if (vectorType(elementType) != null) {
s += vectorAccessorName(elementType()) + "(" + position;
} else {
s += getMethod(fromString(elementType())) + "(" + name() + ".getFirstValueIndex(" + position + ")";
}
if (scratchType(elementType()) != null) {
s += ", " + name() + "Scratch";
}
return s + ")";
}

public void addScratchDeclaration(MethodSpec.Builder builder) {
ClassName scratchType = scratchType(elementType());
if (scratchType != null) {
builder.addStatement("$T $L = new $T()", scratchType, name() + "Scratch", scratchType);
}
}

public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("Block $L = page.getBlock(channels.get($L))", name + "Uncast", offset);
ClassName blockType = blockType(elementType());
Expand All @@ -721,7 +738,7 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("return");
builder.endControlFlow();
}
if (block) {
if (block || vectorType(elementType) == null) {
builder.addStatement("$T $L = ($T) $L", blockType, name, blockType, name + "Uncast");
} else {
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
Expand All @@ -732,6 +749,7 @@ public TypeName combineArgType() {
var type = Types.fromString(elementType);
return block ? blockType(type) : type;
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.compute.gen.argument.BlockArgument;
import org.elasticsearch.compute.gen.argument.BuilderArgument;
import org.elasticsearch.compute.gen.argument.FixedArgument;
import org.elasticsearch.compute.gen.argument.PositionArgument;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -51,6 +52,7 @@ public class EvaluatorImplementer {
private final ProcessFunction processFunction;
private final ClassName implementation;
private final boolean processOutputsMultivalued;
private final boolean vectorsUnsupported;
private final boolean allNullsIsNull;

public EvaluatorImplementer(
Expand All @@ -69,6 +71,10 @@ public EvaluatorImplementer(
declarationType.getSimpleName() + extraName + "Evaluator"
);
this.processOutputsMultivalued = this.processFunction.hasBlockType;
boolean anyParameterNotSupportingVectors = this.processFunction.args.stream()
.filter(a -> a instanceof FixedArgument == false && a instanceof PositionArgument == false)
.anyMatch(a -> a.hasVector() == false);
vectorsUnsupported = processOutputsMultivalued || anyParameterNotSupportingVectors;
this.allNullsIsNull = allNullsIsNull;
}

Expand Down Expand Up @@ -101,7 +107,7 @@ private TypeSpec type() {
builder.addMethod(eval());
builder.addMethod(processFunction.baseRamBytesUsed());

if (processOutputsMultivalued) {
if (vectorsUnsupported) {
if (processFunction.args.stream().anyMatch(x -> x instanceof FixedArgument == false)) {
builder.addMethod(realEval(true));
}
Expand Down Expand Up @@ -145,7 +151,7 @@ private MethodSpec eval() {
builder.addModifiers(Modifier.PUBLIC).returns(BLOCK).addParameter(PAGE, "page");
processFunction.args.forEach(a -> a.evalToBlock(builder));
String invokeBlockEval = invokeRealEval(true);
if (processOutputsMultivalued) {
if (vectorsUnsupported) {
builder.addStatement(invokeBlockEval);
} else {
processFunction.args.forEach(a -> a.resolveVectors(builder, invokeBlockEval));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_EVALUATOR_CONTEXT;
Expand Down Expand Up @@ -93,6 +92,7 @@ public class GroupingAggregatorImplementer {
private final AggregationState aggState;
private final List<Argument> aggParams;
private final boolean hasOnlyBlockArguments;
private final boolean allArgumentsSupportVectors;

public GroupingAggregatorImplementer(
Elements elements,
Expand Down Expand Up @@ -128,6 +128,7 @@ public GroupingAggregatorImplementer(
}).filter(a -> a instanceof PositionArgument == false).toList();

this.hasOnlyBlockArguments = this.aggParams.stream().allMatch(a -> a instanceof BlockArgument);
this.allArgumentsSupportVectors = aggParams.stream().noneMatch(a -> a instanceof StandardArgument && a.hasVector() == false);

this.createParameters = init.getParameters()
.stream()
Expand Down Expand Up @@ -204,7 +205,7 @@ private TypeSpec type() {
builder.addMethod(prepareProcessRawInputPage());
for (ClassName groupIdClass : GROUP_IDS_CLASSES) {
builder.addMethod(addRawInputLoop(groupIdClass, false));
if (hasOnlyBlockArguments == false) {
if (hasOnlyBlockArguments == false && allArgumentsSupportVectors) {
builder.addMethod(addRawInputLoop(groupIdClass, true));
}
builder.addMethod(addIntermediateInput(groupIdClass));
Expand Down Expand Up @@ -330,26 +331,31 @@ private MethodSpec prepareProcessRawInputPage() {
builder.addStatement("$T $L = page.getBlock(channels.get($L))", a.dataType(true), a.blockName(), i);
}

for (Argument a : aggParams) {
builder.addStatement(
"$T $L = $L.asVector()",
vectorType(a.elementType()),
(a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName(),
a.blockName()
);
builder.beginControlFlow("if ($L == null)", (a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName());
{
String groupIdTrackingStatement = "maybeEnableGroupIdTracking(seenGroupIds, "
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ ")";

if (allArgumentsSupportVectors) {

for (Argument a : aggParams) {
builder.addStatement(
"maybeEnableGroupIdTracking(seenGroupIds, "
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ ")"
"$T $L = $L.asVector()",
vectorType(a.elementType()),
(a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName(),
a.blockName()
);
returnAddInput(builder, false);
builder.beginControlFlow("if ($L == null)", (a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName());
{
builder.addStatement(groupIdTrackingStatement);
returnAddInput(builder, false);
}
builder.endControlFlow();
}
builder.endControlFlow();
returnAddInput(builder, true);
} else {
builder.addStatement(groupIdTrackingStatement);
returnAddInput(builder, false);
}

returnAddInput(builder, true);
return builder.build();
}

Expand Down Expand Up @@ -443,9 +449,9 @@ private MethodSpec addRawInputLoop(TypeName groupsType, boolean valuesAreVector)
);
}
for (Argument a : aggParams) {
if (a.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T $L = new $T()", BYTES_REF, a.scratchName(), BYTES_REF);
if (a.scratchType() != null) {
// Add scratch var that will be used for some blocks/vectors, e.g. for bytes_ref
builder.addStatement("$T $L = new $T()", a.scratchType(), a.scratchName(), a.scratchType());
}
}

Expand Down Expand Up @@ -645,11 +651,7 @@ private MethodSpec addIntermediateInput(TypeName groupsType) {
.collect(Collectors.joining(", "));
builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType);
} else {
if (intermediateState.stream()
.map(AggregatorImplementer.IntermediateStateDesc::elementType)
.anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
intermediateState.forEach(state -> state.addScratchDeclaration(builder));
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
{
if (groupsIsBlock) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ public static String getMethod(TypeName elementType) {
if (elementType.equals(TypeName.FLOAT)) {
return "getFloat";
}
if (elementType.equals(Types.EXPONENTIAL_HISTOGRAM)) {
return "getExponentialHistogram";
}
throw new IllegalArgumentException("unknown get method for [" + elementType + "]");
}

Expand Down
Loading