Skip to content

Commit 8cc77ac

Browse files
committed
Introduce grouping in TopN operator
1 parent 20c02f4 commit 8cc77ac

File tree

16 files changed

+335
-49
lines changed

16 files changed

+335
-49
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/TopNBenchmark.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ private static Operator operator(String data, int topCount) {
125125
topCount,
126126
elementTypes,
127127
encoders,
128+
List.of(),
128129
IntStream.range(0, count).mapToObj(c -> new TopNOperator.SortOrder(c, false, false)).toList(),
129130
16 * 1024
130131
);

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Page.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.io.IOException;
1717
import java.util.Arrays;
1818
import java.util.Objects;
19+
import java.util.stream.IntStream;
1920

2021
/**
2122
* A page is a column-oriented data abstraction that allows data to be passed between operators in
@@ -311,4 +312,9 @@ public Page filter(int... positions) {
311312
}
312313
return new Page(filteredBlocks);
313314
}
315+
316+
public Page subPage(int fromIndex, int toIndex) {
317+
// TODO: optimize!
318+
return filter(IntStream.range(fromIndex, toIndex).toArray());
319+
}
314320
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.collect.Iterators;
1616
import org.elasticsearch.compute.data.Block;
1717
import org.elasticsearch.compute.data.BlockFactory;
18+
import org.elasticsearch.compute.data.BytesRefBlock;
1819
import org.elasticsearch.compute.data.ElementType;
1920
import org.elasticsearch.compute.data.Page;
2021
import org.elasticsearch.compute.operator.BreakingBytesRefBuilder;
@@ -23,11 +24,14 @@
2324
import org.elasticsearch.core.Releasable;
2425
import org.elasticsearch.core.Releasables;
2526

27+
import java.nio.charset.Charset;
2628
import java.util.ArrayList;
2729
import java.util.Arrays;
2830
import java.util.Collections;
2931
import java.util.Iterator;
3032
import java.util.List;
33+
import java.util.Map;
34+
import java.util.TreeMap;
3135

3236
/**
3337
* An operator that sorts "rows" of values by encoding the values to sort on, as bytes (using BytesRef). Each data type is encoded
@@ -194,6 +198,16 @@ private void writeValues(int position, BreakingBytesRefBuilder values) {
194198
}
195199
}
196200

201+
public record Partition(int channel) {
202+
203+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Partition.class);
204+
205+
@Override
206+
public String toString() {
207+
return "Partition[channel=" + this.channel + "]";
208+
}
209+
}
210+
197211
public record SortOrder(int channel, boolean asc, boolean nullsFirst) {
198212

199213
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(SortOrder.class);
@@ -224,6 +238,7 @@ public record TopNOperatorFactory(
224238
int topCount,
225239
List<ElementType> elementTypes,
226240
List<TopNEncoder> encoders,
241+
List<Partition> partitions,
227242
List<SortOrder> sortOrders,
228243
int maxPageSize
229244
) implements OperatorFactory {
@@ -243,6 +258,7 @@ public TopNOperator get(DriverContext driverContext) {
243258
topCount,
244259
elementTypes,
245260
encoders,
261+
partitions,
246262
sortOrders,
247263
maxPageSize
248264
);
@@ -256,6 +272,8 @@ public String describe() {
256272
+ elementTypes
257273
+ ", encoders="
258274
+ encoders
275+
+ ", partitions="
276+
+ partitions
259277
+ ", sortOrders="
260278
+ sortOrders
261279
+ "]";
@@ -264,12 +282,14 @@ public String describe() {
264282

265283
private final BlockFactory blockFactory;
266284
private final CircuitBreaker breaker;
267-
private final Queue inputQueue;
285+
private final Map<String, Queue> inputQueues;
268286

287+
private final int topCount;
269288
private final int maxPageSize;
270289

271290
private final List<ElementType> elementTypes;
272291
private final List<TopNEncoder> encoders;
292+
private final List<Partition> partitions;
273293
private final List<SortOrder> sortOrders;
274294

275295
private Row spare;
@@ -304,16 +324,19 @@ public TopNOperator(
304324
int topCount,
305325
List<ElementType> elementTypes,
306326
List<TopNEncoder> encoders,
327+
List<Partition> partitions,
307328
List<SortOrder> sortOrders,
308329
int maxPageSize
309330
) {
310331
this.blockFactory = blockFactory;
311332
this.breaker = breaker;
333+
this.topCount = topCount;
312334
this.maxPageSize = maxPageSize;
313335
this.elementTypes = elementTypes;
314336
this.encoders = encoders;
337+
this.partitions = partitions;
315338
this.sortOrders = sortOrders;
316-
this.inputQueue = new Queue(topCount);
339+
this.inputQueues = new TreeMap<>();
317340
}
318341

319342
static int compareRows(Row r1, Row r2) {
@@ -385,6 +408,8 @@ public void addInput(Page page) {
385408
spareKeysPreAllocSize = Math.max(spare.keys.length(), spareKeysPreAllocSize / 2);
386409
spareValuesPreAllocSize = Math.max(spare.values.length(), spareValuesPreAllocSize / 2);
387410

411+
String partitionKey = getPartitionKey(page, i);
412+
Queue inputQueue = inputQueues.computeIfAbsent(partitionKey, key -> new Queue(topCount));
388413
spare = inputQueue.insertWithOverflow(spare);
389414
}
390415
} finally {
@@ -394,6 +419,28 @@ public void addInput(Page page) {
394419
}
395420
}
396421

422+
/**
423+
* Calculates the partition key of the i-th row of the given page.
424+
*
425+
* @param page page for which the partition key should be calculated
426+
* @param i row index
427+
* @return partition key of the i-th row of the given page
428+
*/
429+
private String getPartitionKey(Page page, int i) {
430+
if (partitions.isEmpty()) {
431+
return "";
432+
}
433+
assert page.getPositionCount() > 0;
434+
StringBuilder builder = new StringBuilder();
435+
for (Partition partition : partitions) {
436+
try (var block = page.getBlock(partition.channel).filter(i)) {
437+
BytesRef partitionFieldValue = ((BytesRefBlock) block).getBytesRef(i, new BytesRef());
438+
builder.append(partitionFieldValue.utf8ToString());
439+
}
440+
}
441+
return builder.toString();
442+
}
443+
397444
@Override
398445
public void finish() {
399446
if (output == null) {
@@ -407,14 +454,17 @@ private Iterator<Page> toPages() {
407454
spare.close();
408455
spare = null;
409456
}
410-
if (inputQueue.size() == 0) {
411-
return Collections.emptyIterator();
412-
}
413-
List<Row> list = new ArrayList<>(inputQueue.size());
414-
List<Page> result = new ArrayList<>();
415-
ResultBuilder[] builders = null;
416457
boolean success = false;
458+
List<Row> list = null;
459+
ResultBuilder[] builders = null;
460+
List<Page> result = new ArrayList<>();
461+
// TODO: optimize case where all the queues are empty
417462
try {
463+
for (var entry : inputQueues.entrySet()) {
464+
Queue inputQueue = entry.getValue();
465+
466+
list = new ArrayList<>(inputQueue.size());
467+
builders = null;
418468
while (inputQueue.size() > 0) {
419469
list.add(inputQueue.pop());
420470
}
@@ -483,6 +533,7 @@ private Iterator<Page> toPages() {
483533
}
484534
}
485535
assert builders == null;
536+
}
486537
success = true;
487538
return result.iterator();
488539
} finally {
@@ -524,20 +575,20 @@ public Page getOutput() {
524575

525576
@Override
526577
public void close() {
578+
List<Releasable> releasables = new ArrayList<>();
579+
releasables.addAll(inputQueues.values().stream().map(Releasables::wrap).toList());
580+
releasables.add(output == null ? null : Releasables.wrap(() -> Iterators.map(output, p -> p::releaseBlocks)));
527581
/*
528582
* If we close before calling finish then spare and inputQueue will be live rows
529583
* that need closing. If we close after calling finish then the output iterator
530584
* will contain pages of results that have yet to be returned.
531585
*/
532-
Releasables.closeExpectNoException(
533-
spare,
534-
inputQueue == null ? null : Releasables.wrap(inputQueue),
535-
output == null ? null : Releasables.wrap(() -> Iterators.map(output, p -> p::releaseBlocks))
536-
);
586+
Releasables.closeExpectNoException(spare, Releasables.wrap(releasables));
537587
}
538588

539-
private static long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.class) + RamUsageEstimator
540-
.shallowSizeOfInstance(List.class) * 3;
589+
private static long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.class)
590+
+ RamUsageEstimator.shallowSizeOfInstance(List.class) * 4
591+
+ RamUsageEstimator.shallowSizeOfInstance(Map.class);
541592

542593
@Override
543594
public long ramBytesUsed() {
@@ -548,25 +599,34 @@ public long ramBytesUsed() {
548599
// These lists may slightly under-count, but it's not likely to be by much.
549600
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * elementTypes.size());
550601
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * encoders.size());
602+
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * partitions.size());
603+
size += partitions.size() * Partition.SHALLOW_SIZE;
551604
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * sortOrders.size());
552605
size += sortOrders.size() * SortOrder.SHALLOW_SIZE;
553-
size += inputQueue.ramBytesUsed();
606+
long ramBytesUsedSum = inputQueues.entrySet().stream()
607+
.mapToLong(e -> e.getKey().getBytes(Charset.defaultCharset()).length + e.getValue().ramBytesUsed())
608+
.sum();
609+
size += ramBytesUsedSum;
554610
return size;
555611
}
556612

557613
@Override
558614
public Status status() {
559-
return new TopNOperatorStatus(inputQueue.size(), ramBytesUsed(), pagesReceived, pagesEmitted, rowsReceived, rowsEmitted);
615+
int queueSizeSum = inputQueues.values().stream().mapToInt(Queue::size).sum();
616+
return new TopNOperatorStatus(queueSizeSum, ramBytesUsed(), pagesReceived, pagesEmitted, rowsReceived, rowsEmitted);
560617
}
561618

562619
@Override
563620
public String toString() {
621+
int queueSizeSum = inputQueues.values().stream().mapToInt(Queue::size).sum();
564622
return "TopNOperator[count="
565-
+ inputQueue
623+
+ queueSizeSum + "/" + topCount
566624
+ ", elementTypes="
567625
+ elementTypes
568626
+ ", encoders="
569627
+ encoders
628+
+ ", partitions="
629+
+ partitions
570630
+ ", sortOrders="
571631
+ sortOrders
572632
+ "]";

0 commit comments

Comments
 (0)