Skip to content

Commit e660ebb

Browse files
authored
ESQL: Standardize block args (#135160)
Standardizes how aggs and scalars work with arguments that take all values at a position. Aggs were using an arrays of values that we had to copy all of the values into the array. And allocate it. Scalars passed down the `Block` and the scalar read from the block on it's own. That's generally more efficient and not a lot harder. So I standardized on that. Previously scalars that took a `Block` parameter also took an implicit builder and position parameter. But aggs don't need the builder. And *do* need the position. This makes both of those parameters explicit rather than implicit.
1 parent 485099c commit e660ebb

30 files changed

+398
-262
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.ann;
9+
10+
import java.lang.annotation.ElementType;
11+
import java.lang.annotation.Retention;
12+
import java.lang.annotation.RetentionPolicy;
13+
import java.lang.annotation.Target;
14+
15+
/**
16+
* Used on parameters on methods annotated with {@link Evaluator} or in
17+
* {@link Aggregator} or {@link GroupingAggregator} to indicate an argument
18+
* that is the position in a block.
19+
*/
20+
@Target(ElementType.PARAMETER)
21+
@Retention(RetentionPolicy.SOURCE)
22+
public @interface Position {
23+
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import org.elasticsearch.compute.ann.IntermediateState;
2020
import org.elasticsearch.compute.gen.Methods.TypeMatcher;
2121
import org.elasticsearch.compute.gen.argument.Argument;
22-
import org.elasticsearch.compute.gen.argument.ArrayArgument;
22+
import org.elasticsearch.compute.gen.argument.BlockArgument;
23+
import org.elasticsearch.compute.gen.argument.PositionArgument;
2324
import org.elasticsearch.compute.gen.argument.StandardArgument;
2425

2526
import java.util.ArrayList;
@@ -110,12 +111,13 @@ public AggregatorImplementer(
110111
requireName("combine"),
111112
requireArgsStartsWith(requireType(aggState.declaredType()), requireAnyType("<aggregation input column type>"))
112113
);
113-
this.aggParams = combine.getParameters().stream().skip(1).map(v -> {
114+
this.aggParams = combine.getParameters().stream().skip(1).flatMap(v -> {
114115
Argument a = Argument.fromParameter(types, v);
115116
return switch (a) {
116-
case StandardArgument sa -> new AggregationParameter(sa.name(), sa.type(), false);
117-
case ArrayArgument aa -> new AggregationParameter(aa.name(), aa.componentType(), true);
118-
default -> throw new IllegalArgumentException("unsupported argument [" + a + "]");
117+
case StandardArgument sa -> Stream.of(new AggregationParameter(sa.name(), sa.type(), false));
118+
case BlockArgument ba -> Stream.of(new AggregationParameter(ba.name(), Types.elementType(ba.type()), true));
119+
case PositionArgument pa -> Stream.of();
120+
default -> throw new IllegalArgumentException("unsupported argument [" + declarationType + "][" + a + "]");
119121
};
120122
}).toList();
121123

@@ -435,22 +437,10 @@ private MethodSpec addRawBlock(boolean masked) {
435437
if (aggParams.size() > 1) {
436438
throw new IllegalArgumentException("array mode not supported for multiple args");
437439
}
438-
builder.addStatement("int start = $L.getFirstValueIndex(p)", aggParams.getFirst().blockName());
439-
builder.addStatement("int end = start + $L.getValueCount(p)", aggParams.getFirst().blockName());
440-
// TODO move this to the top of the loop
441-
builder.addStatement(
442-
"$L[] valuesArray = new $L[end - start]",
443-
aggParams.getFirst().arrayType(),
444-
aggParams.getFirst().arrayType()
440+
warningsBlock(
441+
builder,
442+
() -> builder.addStatement("$T.combine(state, p, $L)", declarationType, aggParams.getFirst().blockName())
445443
);
446-
builder.beginControlFlow("for (int i = start; i < end; i++)");
447-
builder.addStatement(
448-
"valuesArray[i-start] = $L.get$L(i)",
449-
aggParams.getFirst().blockName(),
450-
capitalize(aggParams.getFirst().arrayType())
451-
);
452-
builder.endControlFlow();
453-
combineRawInputForArray(builder, "valuesArray");
454444
} else {
455445
if (first == null && aggState.hasSeen()) {
456446
builder.addStatement("state.seen(true)");
@@ -547,10 +537,6 @@ private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder build
547537
builder.addStatement(pattern.toString(), params.toArray());
548538
}
549539

550-
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
551-
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
552-
}
553-
554540
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
555541
if (warnExceptions.isEmpty() == false) {
556542
builder.beginControlFlow("try");

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/ConsumeProcessor.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.compute.gen;
99

1010
import org.elasticsearch.compute.ann.Fixed;
11+
import org.elasticsearch.compute.ann.Position;
1112

1213
import java.util.List;
1314
import java.util.Set;
@@ -44,7 +45,8 @@ public Set<String> getSupportedAnnotationTypes() {
4445
"org.elasticsearch.xpack.esql.expression.function.MapParam",
4546
"org.elasticsearch.rest.ServerlessScope",
4647
"org.elasticsearch.xcontent.ParserConstructor",
47-
Fixed.class.getName()
48+
Fixed.class.getName(),
49+
Position.class.getName()
4850
);
4951
}
5052

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorImplementer.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ private MethodSpec realEval(boolean blockStyle) {
219219

220220
StringBuilder pattern = new StringBuilder();
221221
List<Object> args = new ArrayList<>();
222-
pattern.append(processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(");
222+
pattern.append("$T.$N(");
223223
args.add(declarationType);
224224
args.add(processFunction.function.getSimpleName());
225225
processFunction.args.stream().forEach(a -> {
@@ -312,10 +312,7 @@ static class ProcessFunction {
312312
}
313313
builderArg = ba;
314314
} else if (arg instanceof BlockArgument) {
315-
if (builderArg != null && args.size() == 2 && hasBlockType == false) {
316-
args.clear();
317-
hasBlockType = true;
318-
}
315+
hasBlockType = true;
319316
}
320317
args.add(arg);
321318
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationParameter;
2121
import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationState;
2222
import org.elasticsearch.compute.gen.argument.Argument;
23-
import org.elasticsearch.compute.gen.argument.ArrayArgument;
23+
import org.elasticsearch.compute.gen.argument.BlockArgument;
24+
import org.elasticsearch.compute.gen.argument.PositionArgument;
2425
import org.elasticsearch.compute.gen.argument.StandardArgument;
2526

2627
import java.util.ArrayList;
@@ -37,7 +38,6 @@
3738
import javax.lang.model.util.Elements;
3839

3940
import static java.util.stream.Collectors.joining;
40-
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
4141
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
4242
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
4343
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
@@ -118,12 +118,13 @@ public GroupingAggregatorImplementer(
118118
requireName("combine"),
119119
combineArgs(aggState)
120120
);
121-
this.aggParams = combine.getParameters().stream().skip(aggState.declaredType().isPrimitive() ? 1 : 2).map(v -> {
121+
this.aggParams = combine.getParameters().stream().skip(aggState.declaredType().isPrimitive() ? 1 : 2).flatMap(v -> {
122122
Argument a = Argument.fromParameter(types, v);
123123
return switch (a) {
124-
case StandardArgument sa -> new AggregationParameter(sa.name(), sa.type(), false);
125-
case ArrayArgument aa -> new AggregationParameter(aa.name(), aa.componentType(), true);
126-
default -> throw new IllegalArgumentException("unsupported argument [" + a + "]");
124+
case StandardArgument sa -> Stream.of(new AggregationParameter(sa.name(), sa.type(), false));
125+
case BlockArgument ba -> Stream.of(new AggregationParameter(ba.name(), Types.elementType(ba.type()), true));
126+
case PositionArgument pa -> Stream.of();
127+
default -> throw new IllegalArgumentException("unsupported argument [" + declarationType + "][" + a + "]");
127128
};
128129
}).toList();
129130

@@ -476,21 +477,14 @@ private MethodSpec addRawInputLoop(TypeName groupsType, boolean valuesAreVector)
476477
if (aggParams.size() > 1) {
477478
throw new IllegalArgumentException("array mode not supported for multiple args");
478479
}
479-
String arrayType = aggParams.getFirst().type().toString().replace("[]", "");
480-
builder.addStatement("int valuesStart = $L.getFirstValueIndex(valuesPosition)", aggParams.getFirst().blockName());
481-
builder.addStatement(
482-
"int valuesEnd = valuesStart + $L.getValueCount(valuesPosition)",
483-
aggParams.getFirst().blockName()
484-
);
485-
builder.addStatement("$L[] valuesArray = new $L[valuesEnd - valuesStart]", arrayType, arrayType);
486-
builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
487-
builder.addStatement(
488-
"valuesArray[v-valuesStart] = $L.get$L(v)",
489-
aggParams.getFirst().blockName(),
490-
capitalize(aggParams.getFirst().arrayType())
480+
warningsBlock(
481+
builder,
482+
() -> builder.addStatement(
483+
"$T.combine(state, groupId, valuesPosition, $L)",
484+
declarationType,
485+
aggParams.getFirst().blockName()
486+
)
491487
);
492-
builder.endControlFlow();
493-
combineRawInputForArray(builder, "valuesArray");
494488
} else {
495489
for (AggregationParameter p : aggParams) {
496490
builder.addStatement("int $L = $L.getFirstValueIndex(valuesPosition)", p.startName(), p.blockName());
@@ -536,6 +530,9 @@ private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder build
536530
pattern.append("$T.combine(state, groupId");
537531
params.add(declarationType);
538532
}
533+
if (aggParams.getFirst().isArray()) {
534+
pattern.append(", p");
535+
}
539536
for (AggregationParameter p : aggParams) {
540537
pattern.append(", $L");
541538
params.add(p.valueName());
@@ -547,10 +544,6 @@ private void invokeCombineRawInput(TypeName returnType, MethodSpec.Builder build
547544
builder.addStatement(pattern.toString(), params.toArray());
548545
}
549546

550-
private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
551-
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
552-
}
553-
554547
private boolean shouldWrapAddInput(boolean valuesAreVector) {
555548
return optionalStaticMethod(
556549
declarationType,

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/argument/Argument.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import com.squareup.javapoet.TypeSpec;
1414

1515
import org.elasticsearch.compute.ann.Fixed;
16+
import org.elasticsearch.compute.ann.Position;
1617
import org.elasticsearch.compute.gen.Types;
1718

1819
import java.util.List;
@@ -41,6 +42,12 @@ static Argument fromParameter(javax.lang.model.util.Types types, VariableElement
4142
Types.extendsSuper(types, v.asType(), "org.elasticsearch.core.Releasable")
4243
);
4344
}
45+
46+
Position position = v.getAnnotation(Position.class);
47+
if (position != null) {
48+
return new PositionArgument();
49+
}
50+
4451
if (type instanceof ClassName c
4552
&& c.simpleName().equals("Builder")
4653
&& c.enclosingClassName() != null
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.gen.argument;
9+
10+
import com.squareup.javapoet.MethodSpec;
11+
import com.squareup.javapoet.TypeName;
12+
import com.squareup.javapoet.TypeSpec;
13+
14+
import java.util.List;
15+
16+
/**
17+
* The position in a block.
18+
*/
19+
public record PositionArgument() implements Argument {
20+
@Override
21+
public TypeName dataType(boolean blockStyle) {
22+
return TypeName.INT;
23+
}
24+
25+
@Override
26+
public String paramName(boolean blockStyle) {
27+
// No need to pass it
28+
return null;
29+
}
30+
31+
@Override
32+
public void declareField(TypeSpec.Builder builder) {
33+
// Nothing to do
34+
}
35+
36+
@Override
37+
public void declareFactoryField(TypeSpec.Builder builder) {
38+
// Nothing to do
39+
}
40+
41+
@Override
42+
public void implementCtor(MethodSpec.Builder builder) {
43+
// Nothing to do
44+
}
45+
46+
@Override
47+
public void implementFactoryCtor(MethodSpec.Builder builder) {
48+
// Nothing to do
49+
}
50+
51+
@Override
52+
public String factoryInvocation(MethodSpec.Builder factoryMethodBuilder) {
53+
return null;
54+
}
55+
56+
@Override
57+
public void evalToBlock(MethodSpec.Builder builder) {
58+
// nothing to do
59+
}
60+
61+
@Override
62+
public void closeEvalToBlock(MethodSpec.Builder builder) {
63+
// nothing to do
64+
}
65+
66+
@Override
67+
public void resolveVectors(MethodSpec.Builder builder, String invokeBlockEval) {
68+
// nothing to do
69+
}
70+
71+
@Override
72+
public void createScratch(MethodSpec.Builder builder) {
73+
// nothing to do
74+
}
75+
76+
@Override
77+
public void skipNull(MethodSpec.Builder builder) {
78+
// nothing to do
79+
}
80+
81+
@Override
82+
public void allBlocksAreNull(MethodSpec.Builder builder) {
83+
// nothing to do
84+
}
85+
86+
@Override
87+
public void read(MethodSpec.Builder builder, boolean blockStyle) {
88+
// nothing to do
89+
}
90+
91+
@Override
92+
public void buildInvocation(StringBuilder pattern, List<Object> args, boolean blockStyle) {
93+
pattern.append("p");
94+
}
95+
96+
@Override
97+
public void buildToStringInvocation(StringBuilder pattern, List<Object> args, String prefix) {
98+
// nothing to do
99+
}
100+
101+
@Override
102+
public String closeInvocation() {
103+
return null;
104+
}
105+
106+
@Override
107+
public void sumBaseRamBytesUsed(MethodSpec.Builder builder) {}
108+
}

x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/spatial/SpatialExtentCartesianShapeDocValuesAggregatorFunction.java

Lines changed: 2 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)