Skip to content

Commit bcb6b77

Browse files
committed
Optimize ordinal inputs in Values aggregation (elastic#127849)
Currently, time-series aggregations use the `values` aggregation to collect dimension values. While we might introduce a specialized aggregation for this in the future, for now we are using `values`, and its inputs are likely to be ordinal blocks. This change speeds up the `values` aggregation when the inputs are ordinal-based. Execution time was reduced from 461ms to 192ms for 10000 groups (first row: before this change; second row: after).
```
ValuesAggregatorBenchmark.run BytesRef 10000 avgt 7 461.938 ± 6.089 ms/op
ValuesAggregatorBenchmark.run BytesRef 10000 avgt 7 192.898 ± 1.781 ms/op
```
1 parent c833893 commit bcb6b77

File tree

7 files changed

+266
-16
lines changed

7 files changed

+266
-16
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
import org.elasticsearch.compute.data.Block;
2222
import org.elasticsearch.compute.data.BlockFactory;
2323
import org.elasticsearch.compute.data.BytesRefBlock;
24+
import org.elasticsearch.compute.data.BytesRefVector;
2425
import org.elasticsearch.compute.data.ElementType;
2526
import org.elasticsearch.compute.data.IntBlock;
27+
import org.elasticsearch.compute.data.IntVector;
2628
import org.elasticsearch.compute.data.LongBlock;
2729
import org.elasticsearch.compute.data.LongVector;
30+
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
2831
import org.elasticsearch.compute.data.Page;
2932
import org.elasticsearch.compute.operator.AggregationOperator;
3033
import org.elasticsearch.compute.operator.DriverContext;
@@ -275,11 +278,18 @@ private static Block dataBlock(int groups, String dataType) {
275278
int blockLength = blockLength(groups);
276279
return switch (dataType) {
277280
case BYTES_REF -> {
278-
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
281+
try (
282+
BytesRefVector.Builder dict = blockFactory.newBytesRefVectorBuilder(blockLength);
283+
IntVector.Builder ords = blockFactory.newIntVectorBuilder(blockLength)
284+
) {
285+
final int dictLength = Math.min(blockLength, KEYWORDS.length);
286+
for (int i = 0; i < dictLength; i++) {
287+
dict.appendBytesRef(KEYWORDS[i]);
288+
}
279289
for (int i = 0; i < blockLength; i++) {
280-
builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]);
290+
ords.appendInt(i % dictLength);
281291
}
282-
yield builder.build();
292+
yield new OrdinalBytesRefVector(ords.build(), dict.build()).asBlock();
283293
}
284294
}
285295
case INT -> {

docs/changelog/127849.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127849
2+
summary: Optimize ordinal inputs in Values aggregation
3+
area: "ES|QL"
4+
type: enhancement
5+
issues: []

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import static java.util.stream.Collectors.joining;
3737
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
38+
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
3839
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
3940
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
4041
import static org.elasticsearch.compute.gen.Methods.requireArgs;
@@ -332,10 +333,32 @@ private MethodSpec prepareProcessPage() {
332333
builder.beginControlFlow("if (valuesBlock.mayHaveNulls())");
333334
builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
334335
builder.endControlFlow();
335-
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)));
336+
if (shouldWrapAddInput(blockType(aggParam.type()))) {
337+
builder.addStatement(
338+
"var addInput = $L",
339+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra))
340+
);
341+
builder.addStatement("return $T.wrapAddInput(addInput, state, valuesBlock)", declarationType);
342+
} else {
343+
builder.addStatement(
344+
"return $L",
345+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra))
346+
);
347+
}
336348
}
337349
builder.endControlFlow();
338-
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)));
350+
if (shouldWrapAddInput(vectorType(aggParam.type()))) {
351+
builder.addStatement(
352+
"var addInput = $L",
353+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra))
354+
);
355+
builder.addStatement("return $T.wrapAddInput(addInput, state, valuesVector)", declarationType);
356+
} else {
357+
builder.addStatement(
358+
"return $L",
359+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra))
360+
);
361+
}
339362
return builder.build();
340363
}
341364

@@ -525,6 +548,15 @@ private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVar
525548
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
526549
}
527550

551+
private boolean shouldWrapAddInput(TypeName valuesType) {
552+
return optionalStaticMethod(
553+
declarationType,
554+
requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT),
555+
requireName("wrapAddInput"),
556+
requireArgs(requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT), requireType(aggState.declaredType()), requireType(valuesType))
557+
) != null;
558+
}
559+
528560
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
529561
if (warnExceptions.isEmpty() == false) {
530562
builder.beginControlFlow("try");

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,31 @@ static ExecutableElement requireStaticMethod(
5959
TypeMatcher returnTypeMatcher,
6060
NameMatcher nameMatcher,
6161
ArgumentMatcher argumentMatcher
62+
) {
63+
ExecutableElement method = optionalStaticMethod(declarationType, returnTypeMatcher, nameMatcher, argumentMatcher);
64+
if (method == null) {
65+
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
66+
var signatures = nameMatcher.names.stream()
67+
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
68+
.collect(joining(" or "));
69+
throw new IllegalArgumentException(message + signatures);
70+
}
71+
return method;
72+
}
73+
74+
static ExecutableElement optionalStaticMethod(
75+
TypeElement declarationType,
76+
TypeMatcher returnTypeMatcher,
77+
NameMatcher nameMatcher,
78+
ArgumentMatcher argumentMatcher
6279
) {
6380
return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream())
6481
.filter(method -> method.getModifiers().contains(Modifier.STATIC))
6582
.filter(method -> nameMatcher.test(method.getSimpleName().toString()))
6683
.filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType())))
6784
.filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList()))
6885
.findFirst()
69-
.orElseThrow(() -> {
70-
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
71-
var signatures = nameMatcher.names.stream()
72-
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
73-
.collect(joining(" or "));
74-
return new IllegalArgumentException(message + signatures);
75-
});
86+
.orElse(null);
7687
}
7788

7889
static NameMatcher requireName(String... names) {

x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.aggregation;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
12+
import org.elasticsearch.compute.data.BytesRefBlock;
13+
import org.elasticsearch.compute.data.BytesRefVector;
14+
import org.elasticsearch.compute.data.IntArrayBlock;
15+
import org.elasticsearch.compute.data.IntBigArrayBlock;
16+
import org.elasticsearch.compute.data.IntBlock;
17+
import org.elasticsearch.compute.data.IntVector;
18+
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
19+
import org.elasticsearch.core.Releasables;
20+
21+
final class ValuesBytesRefAggregators {
22+
static GroupingAggregatorFunction.AddInput wrapAddInput(
23+
GroupingAggregatorFunction.AddInput delegate,
24+
ValuesBytesRefAggregator.GroupingState state,
25+
BytesRefBlock values
26+
) {
27+
OrdinalBytesRefBlock valuesOrdinal = values.asOrdinals();
28+
if (valuesOrdinal == null) {
29+
return delegate;
30+
}
31+
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
32+
final IntVector hashIds;
33+
BytesRef spare = new BytesRef();
34+
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
35+
for (int p = 0; p < dict.getPositionCount(); p++) {
36+
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
37+
}
38+
hashIds = hashIdsBuilder.build();
39+
}
40+
IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock();
41+
return new GroupingAggregatorFunction.AddInput() {
42+
@Override
43+
public void add(int positionOffset, IntArrayBlock groupIds) {
44+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
45+
if (groupIds.isNull(groupPosition)) {
46+
continue;
47+
}
48+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
49+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
50+
for (int g = groupStart; g < groupEnd; g++) {
51+
int groupId = groupIds.getInt(g);
52+
if (ordinalIds.isNull(groupPosition + positionOffset)) {
53+
continue;
54+
}
55+
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
56+
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
57+
for (int v = valuesStart; v < valuesEnd; v++) {
58+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
59+
}
60+
}
61+
}
62+
}
63+
64+
@Override
65+
public void add(int positionOffset, IntBigArrayBlock groupIds) {
66+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
67+
if (groupIds.isNull(groupPosition)) {
68+
continue;
69+
}
70+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
71+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
72+
for (int g = groupStart; g < groupEnd; g++) {
73+
int groupId = groupIds.getInt(g);
74+
if (ordinalIds.isNull(groupPosition + positionOffset)) {
75+
continue;
76+
}
77+
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
78+
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
79+
for (int v = valuesStart; v < valuesEnd; v++) {
80+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
81+
}
82+
}
83+
}
84+
}
85+
86+
@Override
87+
public void add(int positionOffset, IntVector groupIds) {
88+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
89+
int groupId = groupIds.getInt(groupPosition);
90+
if (ordinalIds.isNull(groupPosition + positionOffset)) {
91+
continue;
92+
}
93+
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
94+
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
95+
for (int v = valuesStart; v < valuesEnd; v++) {
96+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
97+
}
98+
}
99+
}
100+
101+
@Override
102+
public void close() {
103+
Releasables.close(hashIds, delegate);
104+
}
105+
};
106+
}
107+
108+
static GroupingAggregatorFunction.AddInput wrapAddInput(
109+
GroupingAggregatorFunction.AddInput delegate,
110+
ValuesBytesRefAggregator.GroupingState state,
111+
BytesRefVector values
112+
) {
113+
var valuesOrdinal = values.asOrdinals();
114+
if (valuesOrdinal == null) {
115+
return delegate;
116+
}
117+
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
118+
final IntVector hashIds;
119+
BytesRef spare = new BytesRef();
120+
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
121+
for (int p = 0; p < dict.getPositionCount(); p++) {
122+
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
123+
}
124+
hashIds = hashIdsBuilder.build();
125+
}
126+
var ordinalIds = valuesOrdinal.getOrdinalsVector();
127+
return new GroupingAggregatorFunction.AddInput() {
128+
@Override
129+
public void add(int positionOffset, IntArrayBlock groupIds) {
130+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
131+
if (groupIds.isNull(groupPosition)) {
132+
continue;
133+
}
134+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
135+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
136+
for (int g = groupStart; g < groupEnd; g++) {
137+
int groupId = groupIds.getInt(g);
138+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
139+
}
140+
}
141+
}
142+
143+
@Override
144+
public void add(int positionOffset, IntBigArrayBlock groupIds) {
145+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
146+
if (groupIds.isNull(groupPosition)) {
147+
continue;
148+
}
149+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
150+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
151+
for (int g = groupStart; g < groupEnd; g++) {
152+
int groupId = groupIds.getInt(g);
153+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
154+
}
155+
}
156+
}
157+
158+
@Override
159+
public void add(int positionOffset, IntVector groupIds) {
160+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
161+
int groupId = groupIds.getInt(groupPosition);
162+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
163+
}
164+
}
165+
166+
@Override
167+
public void close() {
168+
Releasables.close(hashIds, delegate);
169+
}
170+
};
171+
}
172+
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ $endif$
8787
return new GroupingState(bigArrays);
8888
}
8989

90+
$if(BytesRef)$
91+
/**
 * Optional hook for block-typed values; this template section is only emitted inside {@code $if(BytesRef)$}.
 * Forwards unchanged to {@code ValuesBytesRefAggregators.wrapAddInput}.
 *
 * @param delegate the add-input to wrap
 * @param state    the grouping state of this aggregator
 * @param values   the values block being aggregated
 * @return whatever {@code ValuesBytesRefAggregators.wrapAddInput} returns for these arguments
 */
public static GroupingAggregatorFunction.AddInput wrapAddInput(
92+
GroupingAggregatorFunction.AddInput delegate,
93+
GroupingState state,
94+
BytesRefBlock values
95+
) {
96+
return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values);
97+
}
98+
99+
/**
 * Optional hook for vector-typed values; this template section is only emitted inside {@code $if(BytesRef)$}.
 * Forwards unchanged to {@code ValuesBytesRefAggregators.wrapAddInput}.
 *
 * @param delegate the add-input to wrap
 * @param state    the grouping state of this aggregator
 * @param values   the values vector being aggregated
 * @return whatever {@code ValuesBytesRefAggregators.wrapAddInput} returns for these arguments
 */
public static GroupingAggregatorFunction.AddInput wrapAddInput(
100+
GroupingAggregatorFunction.AddInput delegate,
101+
GroupingState state,
102+
BytesRefVector values
103+
) {
104+
return ValuesBytesRefAggregators.wrapAddInput(delegate, state, values);
105+
}
106+
$endif$
107+
90108
public static void combine(GroupingState state, int groupId, $type$ v) {
91109
$if(long)$
92110
state.values.add(groupId, v);
@@ -234,8 +252,8 @@ $if(long||double)$
234252
private final LongLongHash values;
235253

236254
$elseif(BytesRef)$
237-
private final LongLongHash values;
238-
private final BytesRefHash bytes;
255+
final LongLongHash values;
256+
BytesRefHash bytes;
239257

240258
$elseif(int||float)$
241259
private final LongHash values;

0 commit comments

Comments
 (0)