Skip to content

Commit cff81d9

Browse files
committed
Introduce grouping in TopN operator
1 parent 20c02f4 commit cff81d9

File tree

15 files changed

+328
-49
lines changed

15 files changed

+328
-49
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/TopNBenchmark.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ private static Operator operator(String data, int topCount) {
125125
topCount,
126126
elementTypes,
127127
encoders,
128+
List.of(),
128129
IntStream.range(0, count).mapToObj(c -> new TopNOperator.SortOrder(c, false, false)).toList(),
129130
16 * 1024
130131
);

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.collect.Iterators;
1616
import org.elasticsearch.compute.data.Block;
1717
import org.elasticsearch.compute.data.BlockFactory;
18+
import org.elasticsearch.compute.data.BytesRefBlock;
1819
import org.elasticsearch.compute.data.ElementType;
1920
import org.elasticsearch.compute.data.Page;
2021
import org.elasticsearch.compute.operator.BreakingBytesRefBuilder;
@@ -23,11 +24,14 @@
2324
import org.elasticsearch.core.Releasable;
2425
import org.elasticsearch.core.Releasables;
2526

27+
import java.nio.charset.Charset;
2628
import java.util.ArrayList;
2729
import java.util.Arrays;
2830
import java.util.Collections;
2931
import java.util.Iterator;
3032
import java.util.List;
33+
import java.util.Map;
34+
import java.util.TreeMap;
3135

3236
/**
3337
* An operator that sorts "rows" of values by encoding the values to sort on, as bytes (using BytesRef). Each data type is encoded
@@ -194,6 +198,16 @@ private void writeValues(int position, BreakingBytesRefBuilder values) {
194198
}
195199
}
196200

201+
public record Partition(int channel) {
202+
203+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Partition.class);
204+
205+
@Override
206+
public String toString() {
207+
return "Partition[channel=" + this.channel + "]";
208+
}
209+
}
210+
197211
public record SortOrder(int channel, boolean asc, boolean nullsFirst) {
198212

199213
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(SortOrder.class);
@@ -224,6 +238,7 @@ public record TopNOperatorFactory(
224238
int topCount,
225239
List<ElementType> elementTypes,
226240
List<TopNEncoder> encoders,
241+
List<Partition> partitions,
227242
List<SortOrder> sortOrders,
228243
int maxPageSize
229244
) implements OperatorFactory {
@@ -243,6 +258,7 @@ public TopNOperator get(DriverContext driverContext) {
243258
topCount,
244259
elementTypes,
245260
encoders,
261+
partitions,
246262
sortOrders,
247263
maxPageSize
248264
);
@@ -256,6 +272,8 @@ public String describe() {
256272
+ elementTypes
257273
+ ", encoders="
258274
+ encoders
275+
+ ", partitions="
276+
+ partitions
259277
+ ", sortOrders="
260278
+ sortOrders
261279
+ "]";
@@ -264,12 +282,14 @@ public String describe() {
264282

265283
private final BlockFactory blockFactory;
266284
private final CircuitBreaker breaker;
267-
private final Queue inputQueue;
285+
private final Map<String, Queue> inputQueues;
268286

287+
private final int topCount;
269288
private final int maxPageSize;
270289

271290
private final List<ElementType> elementTypes;
272291
private final List<TopNEncoder> encoders;
292+
private final List<Partition> partitions;
273293
private final List<SortOrder> sortOrders;
274294

275295
private Row spare;
@@ -304,16 +324,19 @@ public TopNOperator(
304324
int topCount,
305325
List<ElementType> elementTypes,
306326
List<TopNEncoder> encoders,
327+
List<Partition> partitions,
307328
List<SortOrder> sortOrders,
308329
int maxPageSize
309330
) {
310331
this.blockFactory = blockFactory;
311332
this.breaker = breaker;
333+
this.topCount = topCount;
312334
this.maxPageSize = maxPageSize;
313335
this.elementTypes = elementTypes;
314336
this.encoders = encoders;
337+
this.partitions = partitions;
315338
this.sortOrders = sortOrders;
316-
this.inputQueue = new Queue(topCount);
339+
this.inputQueues = new TreeMap<>();
317340
}
318341

319342
static int compareRows(Row r1, Row r2) {
@@ -385,6 +408,8 @@ public void addInput(Page page) {
385408
spareKeysPreAllocSize = Math.max(spare.keys.length(), spareKeysPreAllocSize / 2);
386409
spareValuesPreAllocSize = Math.max(spare.values.length(), spareValuesPreAllocSize / 2);
387410

411+
String partitionKey = getPartitionKey(page, i);
412+
Queue inputQueue = inputQueues.computeIfAbsent(partitionKey, key -> new Queue(topCount));
388413
spare = inputQueue.insertWithOverflow(spare);
389414
}
390415
} finally {
@@ -394,6 +419,28 @@ public void addInput(Page page) {
394419
}
395420
}
396421

422+
/**
423+
* Calculates the partition key of the i-th row of the given page.
424+
*
425+
* @param page page for which the partition key should be calculated
426+
* @param i row index
427+
* @return partition key of the i-th row of the given page
428+
*/
429+
private String getPartitionKey(Page page, int i) {
430+
if (partitions.isEmpty()) {
431+
return "";
432+
}
433+
assert page.getPositionCount() > 0;
434+
StringBuilder builder = new StringBuilder();
435+
for (Partition partition : partitions) {
436+
try (var block = page.getBlock(partition.channel).filter(i)) {
437+
BytesRef partitionFieldValue = ((BytesRefBlock) block).getBytesRef(i, new BytesRef());
438+
builder.append(partitionFieldValue.utf8ToString());
439+
}
440+
}
441+
return builder.toString();
442+
}
443+
397444
@Override
398445
public void finish() {
399446
if (output == null) {
@@ -407,14 +454,17 @@ private Iterator<Page> toPages() {
407454
spare.close();
408455
spare = null;
409456
}
410-
if (inputQueue.size() == 0) {
411-
return Collections.emptyIterator();
412-
}
413-
List<Row> list = new ArrayList<>(inputQueue.size());
414-
List<Page> result = new ArrayList<>();
415-
ResultBuilder[] builders = null;
416457
boolean success = false;
458+
List<Row> list = null;
459+
ResultBuilder[] builders = null;
460+
List<Page> result = new ArrayList<>();
461+
// TODO: optimize case where all the queues are empty
417462
try {
463+
for (var entry : inputQueues.entrySet()) {
464+
Queue inputQueue = entry.getValue();
465+
466+
list = new ArrayList<>(inputQueue.size());
467+
builders = null;
418468
while (inputQueue.size() > 0) {
419469
list.add(inputQueue.pop());
420470
}
@@ -483,6 +533,7 @@ private Iterator<Page> toPages() {
483533
}
484534
}
485535
assert builders == null;
536+
}
486537
success = true;
487538
return result.iterator();
488539
} finally {
@@ -524,20 +575,20 @@ public Page getOutput() {
524575

525576
@Override
526577
public void close() {
578+
List<Releasable> releasables = new ArrayList<>();
579+
releasables.addAll(inputQueues.values().stream().map(Releasables::wrap).toList());
580+
releasables.add(output == null ? null : Releasables.wrap(() -> Iterators.map(output, p -> p::releaseBlocks)));
527581
/*
528582
* If we close before calling finish then spare and inputQueue will be live rows
529583
* that need closing. If we close after calling finish then the output iterator
530584
* will contain pages of results that have yet to be returned.
531585
*/
532-
Releasables.closeExpectNoException(
533-
spare,
534-
inputQueue == null ? null : Releasables.wrap(inputQueue),
535-
output == null ? null : Releasables.wrap(() -> Iterators.map(output, p -> p::releaseBlocks))
536-
);
586+
Releasables.closeExpectNoException(spare, Releasables.wrap(releasables));
537587
}
538588

539-
private static long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.class) + RamUsageEstimator
540-
.shallowSizeOfInstance(List.class) * 3;
589+
private static long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.class)
590+
+ RamUsageEstimator.shallowSizeOfInstance(List.class) * 4
591+
+ RamUsageEstimator.shallowSizeOfInstance(Map.class);
541592

542593
@Override
543594
public long ramBytesUsed() {
@@ -548,25 +599,34 @@ public long ramBytesUsed() {
548599
// These lists may slightly under-count, but it's not likely to be by much.
549600
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * elementTypes.size());
550601
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * encoders.size());
602+
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * partitions.size());
603+
size += partitions.size() * Partition.SHALLOW_SIZE;
551604
size += RamUsageEstimator.alignObjectSize(arrHeader + ref * sortOrders.size());
552605
size += sortOrders.size() * SortOrder.SHALLOW_SIZE;
553-
size += inputQueue.ramBytesUsed();
606+
long ramBytesUsedSum = inputQueues.entrySet().stream()
607+
.mapToLong(e -> e.getKey().getBytes(Charset.defaultCharset()).length + e.getValue().ramBytesUsed())
608+
.sum();
609+
size += ramBytesUsedSum;
554610
return size;
555611
}
556612

557613
@Override
558614
public Status status() {
559-
return new TopNOperatorStatus(inputQueue.size(), ramBytesUsed(), pagesReceived, pagesEmitted, rowsReceived, rowsEmitted);
615+
int queueSizeSum = inputQueues.values().stream().mapToInt(Queue::size).sum();
616+
return new TopNOperatorStatus(queueSizeSum, ramBytesUsed(), pagesReceived, pagesEmitted, rowsReceived, rowsEmitted);
560617
}
561618

562619
@Override
563620
public String toString() {
621+
int queueSizeSum = inputQueues.values().stream().mapToInt(Queue::size).sum();
564622
return "TopNOperator[count="
565-
+ inputQueue
623+
+ queueSizeSum + "/" + topCount
566624
+ ", elementTypes="
567625
+ elementTypes
568626
+ ", encoders="
569627
+ encoders
628+
+ ", partitions="
629+
+ partitions
570630
+ ", sortOrders="
571631
+ sortOrders
572632
+ "]";

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ protected TopNOperator.TopNOperatorFactory simple(SimpleOptions options) {
135135
4,
136136
List.of(LONG),
137137
List.of(DEFAULT_UNSORTABLE),
138+
List.of(),
138139
List.of(new TopNOperator.SortOrder(0, true, false)),
139140
pageSize
140141
);
@@ -143,15 +144,15 @@ protected TopNOperator.TopNOperatorFactory simple(SimpleOptions options) {
143144
@Override
144145
protected Matcher<String> expectedDescriptionOfSimple() {
145146
return equalTo(
146-
"TopNOperator[count=4, elementTypes=[LONG], encoders=[DefaultUnsortable], "
147+
"TopNOperator[count=4, elementTypes=[LONG], encoders=[DefaultUnsortable], partitions=[], "
147148
+ "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]]]"
148149
);
149150
}
150151

151152
@Override
152153
protected Matcher<String> expectedToStringOfSimple() {
153154
return equalTo(
154-
"TopNOperator[count=0/4, elementTypes=[LONG], encoders=[DefaultUnsortable], "
155+
"TopNOperator[count=0/4, elementTypes=[LONG], encoders=[DefaultUnsortable], partitions=[], "
155156
+ "sortOrders=[SortOrder[channel=0, asc=true, nullsFirst=false]]]"
156157
);
157158
}
@@ -216,6 +217,7 @@ public long accumulateObject(Object o, long shallowSize, Map<Field, Object> fiel
216217
topCount,
217218
List.of(LONG),
218219
List.of(DEFAULT_UNSORTABLE),
220+
List.of(),
219221
List.of(new TopNOperator.SortOrder(0, true, false)),
220222
pageSize
221223
).get(context)
@@ -554,6 +556,7 @@ public void testCollectAllValues() {
554556
topCount,
555557
elementTypes,
556558
encoders,
559+
List.of(),
557560
List.of(new TopNOperator.SortOrder(0, false, false)),
558561
randomPageSize()
559562
)
@@ -643,6 +646,7 @@ public void testCollectAllValues_RandomMultiValues() {
643646
topCount,
644647
elementTypes,
645648
encoders,
649+
List.of(),
646650
List.of(new TopNOperator.SortOrder(0, false, false)),
647651
randomPageSize()
648652
)
@@ -677,6 +681,7 @@ private List<Tuple<Long, Long>> topNTwoColumns(
677681
limit,
678682
elementTypes,
679683
encoder,
684+
List.of(),
680685
sortOrders,
681686
randomPageSize()
682687
)
@@ -704,6 +709,7 @@ public void testTopNManyDescriptionAndToString() {
704709
10,
705710
List.of(BYTES_REF, BYTES_REF),
706711
List.of(UTF8, new FixedLengthTopNEncoder(fixedLength)),
712+
List.of(),
707713
List.of(new TopNOperator.SortOrder(1, false, false), new TopNOperator.SortOrder(3, false, true)),
708714
randomPageSize()
709715
);
@@ -712,7 +718,7 @@ public void testTopNManyDescriptionAndToString() {
712718
.collect(Collectors.joining(", "));
713719
String tail = ", elementTypes=[BYTES_REF, BYTES_REF], encoders=[UTF8TopNEncoder, FixedLengthTopNEncoder["
714720
+ fixedLength
715-
+ "]], sortOrders=["
721+
+ "]], partitions=[], sortOrders=["
716722
+ sorts
717723
+ "]]";
718724
assertThat(factory.describe(), equalTo("TopNOperator[count=10" + tail));
@@ -946,6 +952,7 @@ private void assertSortingOnMV(
946952
topCount,
947953
List.of(blockType),
948954
List.of(encoder),
955+
List.of(),
949956
List.of(sortOrders),
950957
randomPageSize()
951958
)
@@ -1076,6 +1083,7 @@ public void testRandomMultiValuesTopN() {
10761083
topCount,
10771084
elementTypes,
10781085
encoders,
1086+
List.of(),
10791087
uniqueOrders.stream().toList(),
10801088
rows
10811089
),
@@ -1119,6 +1127,7 @@ public void testIPSortingSingleValue() throws UnknownHostException {
11191127
ips.size(),
11201128
List.of(BYTES_REF),
11211129
List.of(TopNEncoder.IP),
1130+
List.of(),
11221131
List.of(new TopNOperator.SortOrder(0, asc, randomBoolean())),
11231132
randomPageSize()
11241133
)
@@ -1245,6 +1254,7 @@ private void assertIPSortingOnMultiValues(
12451254
ips.size(),
12461255
List.of(BYTES_REF),
12471256
List.of(TopNEncoder.IP),
1257+
List.of(),
12481258
List.of(new TopNOperator.SortOrder(0, asc, nullsFirst)),
12491259
randomPageSize()
12501260
)
@@ -1332,6 +1342,7 @@ public void testZeroByte() {
13321342
2,
13331343
List.of(BYTES_REF, INT),
13341344
List.of(TopNEncoder.UTF8, DEFAULT_UNSORTABLE),
1345+
List.of(),
13351346
List.of(
13361347
new TopNOperator.SortOrder(0, true, randomBoolean()),
13371348
new TopNOperator.SortOrder(1, randomBoolean(), randomBoolean())
@@ -1371,6 +1382,7 @@ public void testErrorBeforeFullyDraining() {
13711382
topCount,
13721383
List.of(LONG),
13731384
List.of(DEFAULT_UNSORTABLE),
1385+
List.of(),
13741386
List.of(new TopNOperator.SortOrder(0, true, randomBoolean())),
13751387
maxPageSize
13761388
)
@@ -1406,6 +1418,7 @@ public void testCloseWithoutCompleting() {
14061418
2,
14071419
List.of(INT),
14081420
List.of(DEFAULT_UNSORTABLE),
1421+
List.of(),
14091422
List.of(new TopNOperator.SortOrder(0, randomBoolean(), randomBoolean())),
14101423
randomPageSize()
14111424
)
@@ -1429,6 +1442,7 @@ public void testRowResizes() {
14291442
10,
14301443
types,
14311444
encoders,
1445+
List.of(),
14321446
List.of(new TopNOperator.SortOrder(0, randomBoolean(), randomBoolean())),
14331447
randomPageSize()
14341448
)

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ public static List<NamedWriteableRegistry.Entry> expressions() {
126126
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
127127
entries.addAll(ExpressionCoreWritables.expressions());
128128
entries.add(UnsupportedAttribute.EXPRESSION_ENTRY);
129+
entries.add(Partition.ENTRY);
129130
entries.add(Order.ENTRY);
130131
return entries;
131132
}

0 commit comments

Comments
 (0)