Skip to content

Commit 0478be2

Browse files
committed
ESQL: Speed up VALUES for many buckets
Speeds up the VALUES agg when collecting from many buckets. Specifically, this speeds up the algorithm used to `finish` the aggregation. Most specifically, this makes the algorithm more tolerant to large numbers of groups being collected. The old algorithm was `O(n^2)` with the number of groups. The new one is `O(n)` ``` (groups) 1 219.683 ± 1.069 -> 223.477 ± 1.990 ms/op 1000 426.323 ± 75.963 -> 463.670 ± 7.275 ms/op 100000 36690.871 ± 4656.350 -> 7800.332 ± 2775.869 ms/op 200000 89422.113 ± 2972.606 -> 21920.288 ± 3427.962 ms/op 400000 timed out at 10 minutes -> 40051.524 ± 2011.706 ms/op ``` The `1` group version was not changed at all. That's just noise in the measurement. The small bump in the `1000` case is almost certainly worth it and real. The huge drop in the `100000` case is quite real.
1 parent 268413b commit 0478be2

File tree

7 files changed

+730
-177
lines changed

7 files changed

+730
-177
lines changed
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.benchmark.compute.operator;
11+
12+
import org.apache.lucene.util.BytesRef;
13+
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
14+
import org.elasticsearch.common.util.BigArrays;
15+
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
16+
import org.elasticsearch.compute.aggregation.AggregatorMode;
17+
import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
18+
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
19+
import org.elasticsearch.compute.data.Block;
20+
import org.elasticsearch.compute.data.BlockFactory;
21+
import org.elasticsearch.compute.data.BytesRefBlock;
22+
import org.elasticsearch.compute.data.ElementType;
23+
import org.elasticsearch.compute.data.LongBlock;
24+
import org.elasticsearch.compute.data.LongVector;
25+
import org.elasticsearch.compute.data.Page;
26+
import org.elasticsearch.compute.operator.AggregationOperator;
27+
import org.elasticsearch.compute.operator.DriverContext;
28+
import org.elasticsearch.compute.operator.HashAggregationOperator;
29+
import org.elasticsearch.compute.operator.Operator;
30+
import org.openjdk.jmh.annotations.Benchmark;
31+
import org.openjdk.jmh.annotations.BenchmarkMode;
32+
import org.openjdk.jmh.annotations.Fork;
33+
import org.openjdk.jmh.annotations.Measurement;
34+
import org.openjdk.jmh.annotations.Mode;
35+
import org.openjdk.jmh.annotations.OutputTimeUnit;
36+
import org.openjdk.jmh.annotations.Param;
37+
import org.openjdk.jmh.annotations.Scope;
38+
import org.openjdk.jmh.annotations.State;
39+
import org.openjdk.jmh.annotations.Warmup;
40+
41+
import java.util.ArrayList;
42+
import java.util.HashSet;
43+
import java.util.List;
44+
import java.util.Set;
45+
import java.util.concurrent.TimeUnit;
46+
47+
@Warmup(iterations = 5)
48+
@Measurement(iterations = 7)
49+
@BenchmarkMode(Mode.AverageTime)
50+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
51+
@State(Scope.Thread)
52+
@Fork(1)
53+
public class ValuesAggregatorBenchmark {
54+
static final int MIN_BLOCK_LENGTH = 8 * 1024;
55+
private static final int OP_COUNT = 1024;
56+
private static final BytesRef[] KEYWORDS = new BytesRef[] {
57+
new BytesRef("Tokyo"),
58+
new BytesRef("Delhi"),
59+
new BytesRef("Shanghai"),
60+
new BytesRef("São Paulo"),
61+
new BytesRef("Mexico City"),
62+
new BytesRef("Cairo") };
63+
64+
private static final BlockFactory blockFactory = BlockFactory.getInstance(
65+
new NoopCircuitBreaker("noop"),
66+
BigArrays.NON_RECYCLING_INSTANCE // TODO real big arrays?
67+
);
68+
69+
static {
70+
// Smoke test all the expected values and force loading subclasses more like prod
71+
try {
72+
for (String groups : ValuesAggregatorBenchmark.class.getField("groups").getAnnotationsByType(Param.class)[0].value()) {
73+
for (String dataType : ValuesAggregatorBenchmark.class.getField("dataType").getAnnotationsByType(Param.class)[0].value()) {
74+
run(Integer.parseInt(groups), dataType, 10);
75+
}
76+
}
77+
} catch (NoSuchFieldException e) {
78+
throw new AssertionError();
79+
}
80+
}
81+
82+
private static final String BYTES_REF = "BytesRef";
83+
84+
@Param({ "1", "1000", /*"1000000"*/ })
85+
public int groups;
86+
87+
@Param({ BYTES_REF })
88+
public String dataType;
89+
90+
private static Operator operator(DriverContext driverContext, int groups, String dataType) {
91+
if (groups == 1) {
92+
return new AggregationOperator(
93+
List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
94+
driverContext
95+
);
96+
}
97+
List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
98+
return new HashAggregationOperator(
99+
List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
100+
() -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
101+
driverContext
102+
);
103+
}
104+
105+
private static AggregatorFunctionSupplier supplier(String dataType) {
106+
return switch (dataType) {
107+
case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier();
108+
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
109+
};
110+
}
111+
112+
private static void checkExpected(int groups, String dataType, Page page) {
113+
String prefix = String.format("[%s][%s]", groups, dataType);
114+
int positions = page.getPositionCount();
115+
if (positions != groups) {
116+
throw new IllegalArgumentException(prefix + " expected " + groups + " positions, got " + positions);
117+
}
118+
if (groups == 1) {
119+
checkUngrouped(prefix, dataType, page);
120+
return;
121+
}
122+
checkGrouped(prefix, groups, dataType, page);
123+
}
124+
125+
private static void checkGrouped(String prefix, int groups, String dataType, Page page) {
126+
LongVector groupsVector = page.<LongBlock>getBlock(0).asVector();
127+
for (int p = 0; p < groups; p++) {
128+
long group = groupsVector.getLong(p);
129+
if (group != p) {
130+
throw new IllegalArgumentException(prefix + "[" + p + "] expected group " + p + " but was " + groups);
131+
}
132+
}
133+
switch (dataType) {
134+
case BYTES_REF -> {
135+
BytesRefBlock values = page.getBlock(1);
136+
// Build the expected values
137+
List<Set<BytesRef>> expected = new ArrayList<>(groups);
138+
for (int g = 0; g < groups; g++) {
139+
expected.add(new HashSet<>(KEYWORDS.length));
140+
}
141+
int blockLength = blockLength(groups);
142+
for (int p = 0; p < blockLength; p++) {
143+
expected.get(p % groups).add(KEYWORDS[p % KEYWORDS.length]);
144+
}
145+
146+
// Check them
147+
for (int p = 0; p < groups; p++) {
148+
checkExpectedBytesRef(prefix, values, p, expected.get(p));
149+
}
150+
}
151+
default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
152+
}
153+
}
154+
155+
private static void checkUngrouped(String prefix, String dataType, Page page) {
156+
switch (dataType) {
157+
case BYTES_REF -> {
158+
BytesRefBlock values = page.getBlock(0);
159+
checkExpectedBytesRef(prefix, values, 0, Set.of(KEYWORDS));
160+
}
161+
default -> throw new IllegalArgumentException(prefix + " unsupported data type " + dataType);
162+
}
163+
}
164+
165+
private static void checkExpectedBytesRef(String prefix, BytesRefBlock values, int position, Set<BytesRef> expected) {
166+
int valueCount = values.getValueCount(position);
167+
if (valueCount != expected.size()) {
168+
throw new IllegalArgumentException(
169+
prefix + "[" + position + "] expected " + expected.size() + " values but count was " + valueCount
170+
);
171+
}
172+
BytesRef scratch = new BytesRef();
173+
for (int i = values.getFirstValueIndex(position); i < valueCount; i++) {
174+
BytesRef v = values.getBytesRef(i, scratch);
175+
if (expected.contains(v) == false) {
176+
throw new IllegalArgumentException(prefix + "[" + position + "] expected " + v + " to be in " + expected);
177+
}
178+
}
179+
}
180+
181+
private static Page page(int groups, String dataType) {
182+
Block dataBlock = dataBlock(groups, dataType);
183+
if (groups == 1) {
184+
return new Page(dataBlock);
185+
}
186+
return new Page(groupingBlock(groups), dataBlock);
187+
}
188+
189+
private static Block dataBlock(int groups, String dataType) {
190+
return switch (dataType) {
191+
case BYTES_REF -> {
192+
int blockLength = blockLength(groups);
193+
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
194+
for (int i = 0; i < blockLength; i++) {
195+
builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]);
196+
}
197+
yield builder.build();
198+
}
199+
}
200+
default -> throw new IllegalArgumentException("unsupported data type " + dataType);
201+
};
202+
}
203+
204+
private static Block groupingBlock(int groups) {
205+
int blockLength = blockLength(groups);
206+
try (LongVector.Builder builder = blockFactory.newLongVectorBuilder(blockLength)) {
207+
for (int i = 0; i < blockLength; i++) {
208+
builder.appendLong(i % groups);
209+
}
210+
return builder.build().asBlock();
211+
}
212+
}
213+
214+
@Benchmark
215+
public void run() {
216+
run(groups, dataType, OP_COUNT);
217+
}
218+
219+
private static void run(int groups, String dataType, int opCount) {
220+
DriverContext driverContext = driverContext();
221+
try (Operator operator = operator(driverContext, groups, dataType)) {
222+
Page page = page(groups, dataType);
223+
for (int i = 0; i < opCount; i++) {
224+
operator.addInput(page.shallowCopy());
225+
}
226+
operator.finish();
227+
checkExpected(groups, dataType, operator.getOutput());
228+
}
229+
}
230+
231+
static DriverContext driverContext() {
232+
return new DriverContext(BigArrays.NON_RECYCLING_INSTANCE, blockFactory);
233+
}
234+
235+
static int blockLength(int groups) {
236+
return Math.max(MIN_BLOCK_LENGTH, groups);
237+
}
238+
}

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

Lines changed: 77 additions & 27 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)