AggregatorImplementer.java
@@ -37,6 +37,7 @@
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Methods.getMethod;
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
@@ -53,7 +54,6 @@
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
@@ -62,6 +62,8 @@
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.WARNINGS;
import static org.elasticsearch.compute.gen.Types.blockType;
import static org.elasticsearch.compute.gen.Types.fromString;
import static org.elasticsearch.compute.gen.Types.scratchType;
import static org.elasticsearch.compute.gen.Types.vectorType;

/**
@@ -85,7 +87,7 @@ public class AggregatorImplementer {

private final AggregationState aggState;
private final List<Argument> aggParams;
private final boolean hasOnlyBlockArguments;
private final boolean tryToUseVectors;

public AggregatorImplementer(
Elements elements,
@@ -119,7 +121,8 @@ public AggregatorImplementer(
return a;
}).filter(a -> a instanceof PositionArgument == false).toList();

this.hasOnlyBlockArguments = this.aggParams.stream().allMatch(a -> a instanceof BlockArgument);
this.tryToUseVectors = aggParams.stream().anyMatch(a -> (a instanceof BlockArgument) == false)
&& aggParams.stream().noneMatch(a -> a.supportsVectorReadAccess() == false);

this.createParameters = init.getParameters()
.stream()
@@ -199,7 +202,7 @@ private TypeSpec type() {
builder.addMethod(addRawInput());
builder.addMethod(addRawInputExploded(true));
builder.addMethod(addRawInputExploded(false));
if (hasOnlyBlockArguments == false) {
if (tryToUseVectors) {
builder.addMethod(addRawVector(false));
builder.addMethod(addRawVector(true));
}
@@ -340,16 +343,18 @@ private MethodSpec addRawInputExploded(boolean hasMask) {
builder.addStatement("$T $L = page.getBlock(channels.get($L))", a.dataType(true), a.blockName(), i);
}

for (Argument a : aggParams) {
String rawBlock = "addRawBlock("
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ (hasMask ? ", mask" : "")
+ ")";
if (tryToUseVectors) {
for (Argument a : aggParams) {
String rawBlock = "addRawBlock("
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ (hasMask ? ", mask" : "")
+ ")";

a.resolveVectors(builder, rawBlock, "return");
a.resolveVectors(builder, rawBlock, "return");
}
}

builder.addStatement(invokeAddRaw(hasOnlyBlockArguments, hasMask));
builder.addStatement(invokeAddRaw(tryToUseVectors == false, hasMask));
return builder.build();
}

@@ -499,9 +504,9 @@ private MethodSpec.Builder initAddRaw(boolean blockStyle, boolean masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
for (Argument a : aggParams) {
if (a.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T $L = new $T()", BYTES_REF, a.scratchName(), BYTES_REF);
if (a.scratchType() != null) {
// Add scratch var that will be used for some blocks/vectors, e.g. for bytes_ref
builder.addStatement("$T $L = new $T()", a.scratchType(), a.scratchName(), a.scratchType());
}
}
return builder;
@@ -610,8 +615,8 @@ private MethodSpec addIntermediateInput() {
).map(Methods::requireType).toArray(TypeMatcher[]::new)
)
);
if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
for (IntermediateStateDesc interState : intermediateState) {
interState.addScratchDeclaration(builder);
Contributor Author:
I think this was buggy prior to my change, but without causing a defect?
If there were multiple BYTES_REF state-members, they would have shared the same scratch, which seems incorrect?

Member:
That seems possible. I imagine we don't have any intermediate states with two strings.

}
builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
}
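To make the scratch discussion in the review thread above concrete, here is a minimal sketch of the two code shapes the generator can emit, assuming a hypothetical aggregator with two BYTES_REF intermediate states named `min` and `max` (the state names, the `ScratchSketch` wrapper, and its methods are illustrative only and not part of this PR):

```java
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefVector;

// Illustrative only: how the generated intermediate-state reads differ for a
// hypothetical aggregator with two BYTES_REF intermediate states, "min" and "max".
final class ScratchSketch {
    // Before this change: one shared scratch for all BYTES_REF states, so both
    // reads reuse the same buffer within a single combineIntermediate call.
    static void sharedScratch(BytesRefVector min, BytesRefVector max) {
        BytesRef scratch = new BytesRef();
        BytesRef minValue = min.getBytesRef(0, scratch);
        BytesRef maxValue = max.getBytesRef(0, scratch);
        // Both values refer to the same scratch object, so the first read is
        // clobbered by the second — the issue the thread above speculates about.
    }

    // After this change: addScratchDeclaration() emits a per-state scratch.
    static void perStateScratch(BytesRefVector min, BytesRefVector max) {
        BytesRef minScratch = new BytesRef();
        BytesRef maxScratch = new BytesRef();
        BytesRef minValue = min.getBytesRef(0, minScratch);
        BytesRef maxValue = max.getBytesRef(0, maxScratch);
        // Each value keeps its own scratch buffer.
    }
}
```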
@@ -706,13 +711,25 @@ public String access(String position) {
if (block) {
return name();
}
String s = name() + "." + vectorAccessorName(elementType()) + "(" + position;
if (elementType().equals("BYTES_REF")) {
s += ", scratch";
String s = name() + ".";
if (vectorType(elementType) != null) {
s += vectorAccessorName(elementType()) + "(" + position;
} else {
s += getMethod(fromString(elementType())) + "(" + name() + ".getFirstValueIndex(" + position + ")";
}
if (scratchType(elementType()) != null) {
s += ", " + name() + "Scratch";
}
return s + ")";
}

public void addScratchDeclaration(MethodSpec.Builder builder) {
ClassName scratchType = scratchType(elementType());
if (scratchType != null) {
builder.addStatement("$T $L = new $T()", scratchType, name() + "Scratch", scratchType);
}
}

public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("Block $L = page.getBlock(channels.get($L))", name + "Uncast", offset);
ClassName blockType = blockType(elementType());
@@ -721,7 +738,7 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("return");
builder.endControlFlow();
}
if (block) {
if (block || vectorType(elementType) == null) {
builder.addStatement("$T $L = ($T) $L", blockType, name, blockType, name + "Uncast");
} else {
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
@@ -732,6 +749,7 @@ public TypeName combineArgType() {
var type = Types.fromString(elementType);
return block ? blockType(type) : type;
}

}

/**
EvaluatorImplementer.java
@@ -51,6 +51,7 @@ public class EvaluatorImplementer {
private final ProcessFunction processFunction;
private final ClassName implementation;
private final boolean processOutputsMultivalued;
private final boolean vectorsUnsupported;
private final boolean allNullsIsNull;

public EvaluatorImplementer(
@@ -69,6 +70,8 @@ public EvaluatorImplementer(
declarationType.getSimpleName() + extraName + "Evaluator"
);
this.processOutputsMultivalued = this.processFunction.hasBlockType;
boolean anyParameterNotSupportingVectors = this.processFunction.args.stream().anyMatch(a -> a.supportsVectorReadAccess() == false);
vectorsUnsupported = processOutputsMultivalued || anyParameterNotSupportingVectors;
this.allNullsIsNull = allNullsIsNull;
}

@@ -101,7 +104,7 @@ private TypeSpec type() {
builder.addMethod(eval());
builder.addMethod(processFunction.baseRamBytesUsed());

if (processOutputsMultivalued) {
if (vectorsUnsupported) {
if (processFunction.args.stream().anyMatch(x -> x instanceof FixedArgument == false)) {
builder.addMethod(realEval(true));
}
@@ -145,7 +148,7 @@ private MethodSpec eval() {
builder.addModifiers(Modifier.PUBLIC).returns(BLOCK).addParameter(PAGE, "page");
processFunction.args.forEach(a -> a.evalToBlock(builder));
String invokeBlockEval = invokeRealEval(true);
if (processOutputsMultivalued) {
if (vectorsUnsupported) {
builder.addStatement(invokeBlockEval);
} else {
processFunction.args.forEach(a -> a.resolveVectors(builder, invokeBlockEval));
GroupingAggregatorImplementer.java
@@ -51,7 +51,6 @@
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_EVALUATOR_CONTEXT;
@@ -93,6 +92,7 @@ public class GroupingAggregatorImplementer {
private final AggregationState aggState;
private final List<Argument> aggParams;
private final boolean hasOnlyBlockArguments;
private final boolean allArgumentsSupportVectors;

public GroupingAggregatorImplementer(
Elements elements,
@@ -128,6 +128,7 @@ public GroupingAggregatorImplementer(
}).filter(a -> a instanceof PositionArgument == false).toList();

this.hasOnlyBlockArguments = this.aggParams.stream().allMatch(a -> a instanceof BlockArgument);
this.allArgumentsSupportVectors = aggParams.stream().noneMatch(a -> a.supportsVectorReadAccess() == false);

this.createParameters = init.getParameters()
.stream()
@@ -204,7 +205,7 @@ private TypeSpec type() {
builder.addMethod(prepareProcessRawInputPage());
for (ClassName groupIdClass : GROUP_IDS_CLASSES) {
builder.addMethod(addRawInputLoop(groupIdClass, false));
if (hasOnlyBlockArguments == false) {
if (hasOnlyBlockArguments == false && allArgumentsSupportVectors) {
builder.addMethod(addRawInputLoop(groupIdClass, true));
}
builder.addMethod(addIntermediateInput(groupIdClass));
@@ -330,26 +331,31 @@ private MethodSpec prepareProcessRawInputPage() {
builder.addStatement("$T $L = page.getBlock(channels.get($L))", a.dataType(true), a.blockName(), i);
}

for (Argument a : aggParams) {
builder.addStatement(
"$T $L = $L.asVector()",
vectorType(a.elementType()),
(a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName(),
a.blockName()
);
builder.beginControlFlow("if ($L == null)", (a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName());
{
String groupIdTrackingStatement = "maybeEnableGroupIdTracking(seenGroupIds, "
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ ")";

if (allArgumentsSupportVectors) {

for (Argument a : aggParams) {
builder.addStatement(
"maybeEnableGroupIdTracking(seenGroupIds, "
+ aggParams.stream().map(arg -> arg.blockName()).collect(joining(", "))
+ ")"
"$T $L = $L.asVector()",
vectorType(a.elementType()),
(a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName(),
a.blockName()
);
returnAddInput(builder, false);
builder.beginControlFlow("if ($L == null)", (a instanceof BlockArgument) ? (a.name() + "Vector") : a.vectorName());
{
builder.addStatement(groupIdTrackingStatement);
returnAddInput(builder, false);
}
builder.endControlFlow();
}
builder.endControlFlow();
returnAddInput(builder, true);
} else {
builder.addStatement(groupIdTrackingStatement);
returnAddInput(builder, false);
}

returnAddInput(builder, true);
return builder.build();
}

@@ -443,9 +449,9 @@ private MethodSpec addRawInputLoop(TypeName groupsType, boolean valuesAreVector)
);
}
for (Argument a : aggParams) {
if (a.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T $L = new $T()", BYTES_REF, a.scratchName(), BYTES_REF);
if (a.scratchType() != null) {
// Add scratch var that will be used for some blocks/vectors, e.g. for bytes_ref
builder.addStatement("$T $L = new $T()", a.scratchType(), a.scratchName(), a.scratchType());
}
}

@@ -645,11 +651,7 @@ private MethodSpec addIntermediateInput(TypeName groupsType) {
.collect(Collectors.joining(", "));
builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType);
} else {
if (intermediateState.stream()
.map(AggregatorImplementer.IntermediateStateDesc::elementType)
.anyMatch(n -> n.equals("BYTES_REF"))) {
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
intermediateState.forEach(state -> state.addScratchDeclaration(builder));
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
{
if (groupsIsBlock) {
Methods.java
@@ -314,6 +314,9 @@ public static String getMethod(TypeName elementType) {
if (elementType.equals(TypeName.FLOAT)) {
return "getFloat";
}
if (elementType.equals(Types.EXPONENTIAL_HISTOGRAM)) {
return "getExponentialHistogram";
}
throw new IllegalArgumentException("unknown get method for [" + elementType + "]");
}

Types.java
@@ -52,13 +52,16 @@ public class Types {
public static final ClassName LONG_BLOCK = ClassName.get(DATA_PACKAGE, "LongBlock");
public static final ClassName DOUBLE_BLOCK = ClassName.get(DATA_PACKAGE, "DoubleBlock");
public static final ClassName FLOAT_BLOCK = ClassName.get(DATA_PACKAGE, "FloatBlock");
public static final ClassName EXPONENTIAL_HISTOGRAM_BLOCK = ClassName.get(DATA_PACKAGE, "ExponentialHistogramBlock");
public static final ClassName EXPONENTIAL_HISTOGRAM_SCRATCH = ClassName.get(DATA_PACKAGE, "ExponentialHistogramScratch");

static final ClassName BOOLEAN_BLOCK_BUILDER = BOOLEAN_BLOCK.nestedClass("Builder");
static final ClassName BYTES_REF_BLOCK_BUILDER = BYTES_REF_BLOCK.nestedClass("Builder");
static final ClassName INT_BLOCK_BUILDER = INT_BLOCK.nestedClass("Builder");
static final ClassName LONG_BLOCK_BUILDER = LONG_BLOCK.nestedClass("Builder");
static final ClassName DOUBLE_BLOCK_BUILDER = DOUBLE_BLOCK.nestedClass("Builder");
static final ClassName FLOAT_BLOCK_BUILDER = FLOAT_BLOCK.nestedClass("Builder");
static final ClassName EXPONENTIAL_HISTOGRAM_BLOCK_BUILDER = ClassName.get(DATA_PACKAGE, "ExponentialHistogramBlockBuilder");

static final ClassName ELEMENT_TYPE = ClassName.get(DATA_PACKAGE, "ElementType");

@@ -133,24 +136,32 @@ public class Types {
static final ClassName SOURCE = ClassName.get("org.elasticsearch.xpack.esql.core.tree", "Source");

public static final ClassName BYTES_REF = ClassName.get("org.apache.lucene.util", "BytesRef");
public static final ClassName EXPONENTIAL_HISTOGRAM = ClassName.get("org.elasticsearch.exponentialhistogram", "ExponentialHistogram");

public static final ClassName RELEASABLE = ClassName.get("org.elasticsearch.core", "Releasable");
public static final ClassName RELEASABLES = ClassName.get("org.elasticsearch.core", "Releasables");

private record TypeDef(TypeName type, String alias, ClassName block, ClassName vector) {
private record TypeDef(TypeName type, String alias, ClassName block, ClassName vector, ClassName scratch) {

public static TypeDef of(TypeName type, String alias, String block, String vector) {
return new TypeDef(type, alias, ClassName.get(DATA_PACKAGE, block), ClassName.get(DATA_PACKAGE, vector));
public static TypeDef of(TypeName type, String alias, String block, String vector, ClassName scratch) {
return new TypeDef(
type,
alias,
ClassName.get(DATA_PACKAGE, block),
vector == null ? null : ClassName.get(DATA_PACKAGE, vector),
scratch
);
}
}

private static final Map<String, TypeDef> TYPES = Stream.of(
TypeDef.of(TypeName.BOOLEAN, "BOOLEAN", "BooleanBlock", "BooleanVector"),
TypeDef.of(TypeName.INT, "INT", "IntBlock", "IntVector"),
TypeDef.of(TypeName.LONG, "LONG", "LongBlock", "LongVector"),
TypeDef.of(TypeName.FLOAT, "FLOAT", "FloatBlock", "FloatVector"),
TypeDef.of(TypeName.DOUBLE, "DOUBLE", "DoubleBlock", "DoubleVector"),
TypeDef.of(BYTES_REF, "BYTES_REF", "BytesRefBlock", "BytesRefVector")
TypeDef.of(TypeName.BOOLEAN, "BOOLEAN", "BooleanBlock", "BooleanVector", null),
TypeDef.of(TypeName.INT, "INT", "IntBlock", "IntVector", null),
TypeDef.of(TypeName.LONG, "LONG", "LongBlock", "LongVector", null),
TypeDef.of(TypeName.FLOAT, "FLOAT", "FloatBlock", "FloatVector", null),
TypeDef.of(TypeName.DOUBLE, "DOUBLE", "DoubleBlock", "DoubleVector", null),
TypeDef.of(BYTES_REF, "BYTES_REF", "BytesRefBlock", "BytesRefVector", BYTES_REF),
TypeDef.of(EXPONENTIAL_HISTOGRAM, "EXPONENTIAL_HISTOGRAM", "ExponentialHistogramBlock", null, EXPONENTIAL_HISTOGRAM_SCRATCH)
)
.flatMap(def -> Stream.of(def.type.toString(), def.type + "[]", def.alias).map(alias -> Map.entry(alias, def)))
.collect(toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
@@ -183,6 +194,14 @@ static ClassName vectorType(String elementType) {
return findRequired(elementType, "vector").vector;
}

public static ClassName scratchType(String elementType) {
TypeDef typeDef = TYPES.get(elementType);
if (typeDef != null) {
return typeDef.scratch;
}
return null;
}

static ClassName builderType(TypeName resultType) {
if (resultType.equals(BOOLEAN_BLOCK)) {
return BOOLEAN_BLOCK_BUILDER;
@@ -220,6 +239,9 @@ static ClassName builderType(TypeName resultType) {
if (resultType.equals(FLOAT_VECTOR)) {
return FLOAT_VECTOR_BUILDER;
}
if (resultType.equals(EXPONENTIAL_HISTOGRAM_BLOCK)) {
return EXPONENTIAL_HISTOGRAM_BLOCK_BUILDER;
}
throw new IllegalArgumentException("unknown builder type for [" + resultType + "]");
}

@@ -261,6 +283,9 @@ public static TypeName elementType(TypeName t) {
if (t.equals(FLOAT_BLOCK) || t.equals(FLOAT_VECTOR) || t.equals(FLOAT_BLOCK_BUILDER)) {
return TypeName.FLOAT;
}
if (t.equals(EXPONENTIAL_HISTOGRAM_BLOCK) || t.equals(EXPONENTIAL_HISTOGRAM_BLOCK_BUILDER)) {
return EXPONENTIAL_HISTOGRAM;
}
throw new IllegalArgumentException("unknown element type for [" + t + "]");
}
