Skip to content

Commit 7ca3ce3

Browse files
committed
Included TopNLongBlockHash into BlockHash.ubuild() and GroupSpec logics, and added it to benchmark
1 parent 29fa4d1 commit 7ca3ce3

File tree

4 files changed

+72
-19
lines changed

4 files changed

+72
-19
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public class AggregatorBenchmark {
7373
static final int BLOCK_LENGTH = 8 * 1024;
7474
private static final int OP_COUNT = 1024;
7575
private static final int GROUPS = 5;
76+
private static final int TOP_N_LIMIT = 3;
7677

7778
private static final BlockFactory blockFactory = BlockFactory.getInstance(
7879
new NoopCircuitBreaker("noop"),
@@ -90,6 +91,7 @@ public class AggregatorBenchmark {
9091
private static final String TWO_ORDINALS = "two_" + ORDINALS;
9192
private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS;
9293
private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS;
94+
private static final String TOP_N_LONGS = "top_n_" + LONGS;
9395

9496
private static final String VECTOR_DOUBLES = "vector_doubles";
9597
private static final String HALF_NULL_DOUBLES = "half_null_doubles";
@@ -147,7 +149,8 @@ static void selfTest() {
147149
TWO_BYTES_REFS,
148150
TWO_ORDINALS,
149151
LONGS_AND_BYTES_REFS,
150-
TWO_LONGS_AND_BYTES_REFS }
152+
TWO_LONGS_AND_BYTES_REFS,
153+
TOP_N_LONGS }
151154
)
152155
public String grouping;
153156

@@ -161,8 +164,7 @@ static void selfTest() {
161164
public String filter;
162165

163166
private static Operator operator(DriverContext driverContext, String grouping, String op, String dataType, String filter) {
164-
165-
if (grouping.equals("none")) {
167+
if (grouping.equals(NONE)) {
166168
return new AggregationOperator(
167169
List.of(supplier(op, dataType, filter).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
168170
driverContext
@@ -188,6 +190,12 @@ private static Operator operator(DriverContext driverContext, String grouping, S
188190
new BlockHash.GroupSpec(1, ElementType.LONG),
189191
new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
190192
);
193+
case TOP_N_LONGS -> List.of(new BlockHash.GroupSpec(
194+
0,
195+
ElementType.LONG,
196+
false,
197+
new BlockHash.TopNDef(0, true, true, TOP_N_LIMIT)
198+
));
191199
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
192200
};
193201
return new HashAggregationOperator(
@@ -271,10 +279,14 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
271279
case BOOLEANS -> 2;
272280
default -> GROUPS;
273281
};
282+
int availableGroups = switch (grouping) {
283+
case TOP_N_LONGS -> TOP_N_LIMIT;
284+
default -> groups;
285+
};
274286
switch (op) {
275287
case AVG -> {
276288
DoubleBlock dValues = (DoubleBlock) values;
277-
for (int g = 0; g < groups; g++) {
289+
for (int g = 0; g < availableGroups; g++) {
278290
long group = g;
279291
long sum = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum();
280292
long count = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count();
@@ -286,7 +298,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
286298
}
287299
case COUNT -> {
288300
LongBlock lValues = (LongBlock) values;
289-
for (int g = 0; g < groups; g++) {
301+
for (int g = 0; g < availableGroups; g++) {
290302
long group = g;
291303
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).count() * opCount;
292304
if (lValues.getLong(g) != expected) {
@@ -296,7 +308,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
296308
}
297309
case COUNT_DISTINCT -> {
298310
LongBlock lValues = (LongBlock) values;
299-
for (int g = 0; g < groups; g++) {
311+
for (int g = 0; g < availableGroups; g++) {
300312
long group = g;
301313
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).distinct().count();
302314
long count = lValues.getLong(g);
@@ -310,15 +322,15 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
310322
switch (dataType) {
311323
case LONGS -> {
312324
LongBlock lValues = (LongBlock) values;
313-
for (int g = 0; g < groups; g++) {
325+
for (int g = 0; g < availableGroups; g++) {
314326
if (lValues.getLong(g) != (long) g) {
315327
throw new AssertionError(prefix + "expected [" + g + "] but was [" + lValues.getLong(g) + "]");
316328
}
317329
}
318330
}
319331
case DOUBLES -> {
320332
DoubleBlock dValues = (DoubleBlock) values;
321-
for (int g = 0; g < groups; g++) {
333+
for (int g = 0; g < availableGroups; g++) {
322334
if (dValues.getDouble(g) != (long) g) {
323335
throw new AssertionError(prefix + "expected [" + g + "] but was [" + dValues.getDouble(g) + "]");
324336
}
@@ -331,7 +343,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
331343
switch (dataType) {
332344
case LONGS -> {
333345
LongBlock lValues = (LongBlock) values;
334-
for (int g = 0; g < groups; g++) {
346+
for (int g = 0; g < availableGroups; g++) {
335347
long group = g;
336348
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
337349
if (lValues.getLong(g) != expected) {
@@ -341,7 +353,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
341353
}
342354
case DOUBLES -> {
343355
DoubleBlock dValues = (DoubleBlock) values;
344-
for (int g = 0; g < groups; g++) {
356+
for (int g = 0; g < availableGroups; g++) {
345357
long group = g;
346358
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).max().getAsLong();
347359
if (dValues.getDouble(g) != expected) {
@@ -356,7 +368,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
356368
switch (dataType) {
357369
case LONGS -> {
358370
LongBlock lValues = (LongBlock) values;
359-
for (int g = 0; g < groups; g++) {
371+
for (int g = 0; g < availableGroups; g++) {
360372
long group = g;
361373
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
362374
if (lValues.getLong(g) != expected) {
@@ -366,7 +378,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
366378
}
367379
case DOUBLES -> {
368380
DoubleBlock dValues = (DoubleBlock) values;
369-
for (int g = 0; g < groups; g++) {
381+
for (int g = 0; g < availableGroups; g++) {
370382
long group = g;
371383
long expected = LongStream.range(0, BLOCK_LENGTH).filter(l -> l % groups == group).sum() * opCount;
372384
if (dValues.getDouble(g) != expected) {
@@ -391,6 +403,14 @@ private static void checkGroupingBlock(String prefix, String grouping, Block blo
391403
}
392404
}
393405
}
406+
case TOP_N_LONGS -> {
407+
LongBlock groups = (LongBlock) block;
408+
for (int g = 0; g < TOP_N_LIMIT; g++) {
409+
if (groups.getLong(g) != (long) g) {
410+
throw new AssertionError(prefix + "bad group expected [" + g + "] but was [" + groups.getLong(g) + "]");
411+
}
412+
}
413+
}
394414
case INTS -> {
395415
IntBlock groups = (IntBlock) block;
396416
for (int g = 0; g < GROUPS; g++) {
@@ -495,7 +515,7 @@ private static void checkUngrouped(String prefix, String op, String dataType, Pa
495515

496516
private static Page page(BlockFactory blockFactory, String grouping, String blockType) {
497517
Block dataBlock = dataBlock(blockFactory, blockType);
498-
if (grouping.equals("none")) {
518+
if (grouping.equals(NONE)) {
499519
return new Page(dataBlock);
500520
}
501521
List<Block> blocks = groupingBlocks(grouping, blockType);
@@ -564,7 +584,7 @@ private static Block groupingBlock(String grouping, String blockType) {
564584
default -> throw new UnsupportedOperationException("bad grouping [" + grouping + "]");
565585
};
566586
return switch (grouping) {
567-
case LONGS -> {
587+
case TOP_N_LONGS, LONGS -> {
568588
var builder = blockFactory.newLongBlockBuilder(BLOCK_LENGTH);
569589
for (int i = 0; i < BLOCK_LENGTH; i++) {
570590
for (int v = 0; v < valuesPerGroup; v++) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.compute.data.IntBlock;
2424
import org.elasticsearch.compute.data.IntVector;
2525
import org.elasticsearch.compute.data.Page;
26+
import org.elasticsearch.core.Nullable;
2627
import org.elasticsearch.core.Releasable;
2728
import org.elasticsearch.core.ReleasableIterator;
2829
import org.elasticsearch.index.analysis.AnalysisRegistry;
@@ -113,13 +114,30 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
113114
@Override
114115
public abstract BitArray seenGroupIds(BigArrays bigArrays);
115116

117+
/**
118+
* Configuration for a BlockHash group spec that is later sorted and limited (Top-N).
119+
* <p>
120+
* Part of a performance improvement to avoid aggregating groups that will not be used.
121+
* </p>
122+
*
123+
* @param order The order of this group in the sort, starting at 0
124+
* @param asc True if this group will be sorted ascending. False if descending.
125+
* @param nullsFirst True if the nulls should be the first elements in the TopN. False if they should be kept last.
126+
* @param limit The number of elements to keep, including nulls.
127+
*/
128+
public record TopNDef(int order, boolean asc, boolean nullsFirst, int limit) {}
129+
116130
/**
117131
* @param isCategorize Whether this group is a CATEGORIZE() or not.
118132
* May be changed in the future when more stateful grouping functions are added.
119133
*/
120-
public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
134+
public record GroupSpec(int channel, ElementType elementType, boolean isCategorize, @Nullable TopNDef topNDef) {
121135
public GroupSpec(int channel, ElementType elementType) {
122-
this(channel, elementType, false);
136+
this(channel, elementType, false, null);
137+
}
138+
139+
public GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
140+
this(channel, elementType, isCategorize, null);
123141
}
124142
}
125143

@@ -134,7 +152,18 @@ public GroupSpec(int channel, ElementType elementType) {
134152
*/
135153
public static BlockHash build(List<GroupSpec> groups, BlockFactory blockFactory, int emitBatchSize, boolean allowBrokenOptimizations) {
136154
if (groups.size() == 1) {
137-
return newForElementType(groups.get(0).channel(), groups.get(0).elementType(), blockFactory);
155+
GroupSpec group = groups.get(0);
156+
if (group.topNDef() != null && group.elementType() == ElementType.LONG) {
157+
TopNDef topNDef = group.topNDef();
158+
return new LongTopNBlockHash(
159+
group.channel(),
160+
topNDef.asc(),
161+
topNDef.nullsFirst(),
162+
topNDef.limit(),
163+
blockFactory
164+
);
165+
}
166+
return newForElementType(group.channel(), group.elementType(), blockFactory);
138167
}
139168
if (groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
140169
switch (groups.size()) {

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/TopNBlockHashTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ private void hashBatchesCallbackOnLast(Consumer<OrdsAndKeys> callback, Block[]..
363363
private BlockHash buildBlockHash(int emitBatchSize, Block... values) {
364364
List<BlockHash.GroupSpec> specs = new ArrayList<>(values.length);
365365
for (int c = 0; c < values.length; c++) {
366-
specs.add(new BlockHash.GroupSpec(c, values[c].elementType()));
366+
specs.add(new BlockHash.GroupSpec(c, values[c].elementType(), false, topNDef(c)));
367367
}
368368
assert forcePackedHash == false : "Packed TopN hash not implemented yet";
369369
/*return forcePackedHash
@@ -386,4 +386,8 @@ private String topNParametersString(int differentValues, int unusedInsertedValue
386386
+ ", entries="
387387
+ Math.min(differentValues, limit + unusedInsertedValues);
388388
}
389+
390+
private BlockHash.TopNDef topNDef(int order) {
391+
return new BlockHash.TopNDef(order, asc, nullsFirst, limit);
392+
}
389393
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ BlockHash.GroupSpec toHashGroupSpec() {
360360
throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
361361
}
362362

363-
return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
363+
return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize, null);
364364
}
365365

366366
ElementType elementType() {

0 commit comments

Comments
 (0)