1515import org .elasticsearch .common .collect .Iterators ;
1616import org .elasticsearch .compute .data .Block ;
1717import org .elasticsearch .compute .data .BlockFactory ;
18+ import org .elasticsearch .compute .data .BytesRefBlock ;
1819import org .elasticsearch .compute .data .ElementType ;
1920import org .elasticsearch .compute .data .Page ;
2021import org .elasticsearch .compute .operator .BreakingBytesRefBuilder ;
2324import org .elasticsearch .core .Releasable ;
2425import org .elasticsearch .core .Releasables ;
2526
27+ import java .nio .charset .Charset ;
2628import java .util .ArrayList ;
2729import java .util .Arrays ;
2830import java .util .Collections ;
2931import java .util .Iterator ;
3032import java .util .List ;
33+ import java .util .Map ;
34+ import java .util .TreeMap ;
3135
3236/**
3337 * An operator that sorts "rows" of values by encoding the values to sort on, as bytes (using BytesRef). Each data type is encoded
@@ -194,6 +198,16 @@ private void writeValues(int position, BreakingBytesRefBuilder values) {
194198 }
195199 }
196200
201+ public record Partition (int channel ) {
202+
203+ private static final long SHALLOW_SIZE = RamUsageEstimator .shallowSizeOfInstance (Partition .class );
204+
205+ @ Override
206+ public String toString () {
207+ return "Partition[channel=" + this .channel + "]" ;
208+ }
209+ }
210+
197211 public record SortOrder (int channel , boolean asc , boolean nullsFirst ) {
198212
199213 private static final long SHALLOW_SIZE = RamUsageEstimator .shallowSizeOfInstance (SortOrder .class );
@@ -224,6 +238,7 @@ public record TopNOperatorFactory(
224238 int topCount ,
225239 List <ElementType > elementTypes ,
226240 List <TopNEncoder > encoders ,
241+ List <Partition > partitions ,
227242 List <SortOrder > sortOrders ,
228243 int maxPageSize
229244 ) implements OperatorFactory {
@@ -243,6 +258,7 @@ public TopNOperator get(DriverContext driverContext) {
243258 topCount ,
244259 elementTypes ,
245260 encoders ,
261+ partitions ,
246262 sortOrders ,
247263 maxPageSize
248264 );
@@ -256,6 +272,8 @@ public String describe() {
256272 + elementTypes
257273 + ", encoders="
258274 + encoders
275+ + ", partitions="
276+ + partitions
259277 + ", sortOrders="
260278 + sortOrders
261279 + "]" ;
@@ -264,12 +282,14 @@ public String describe() {
264282
265283 private final BlockFactory blockFactory ;
266284 private final CircuitBreaker breaker ;
267- private final Queue inputQueue ;
285+ private final Map < String , Queue > inputQueues ;
268286
287+ private final int topCount ;
269288 private final int maxPageSize ;
270289
271290 private final List <ElementType > elementTypes ;
272291 private final List <TopNEncoder > encoders ;
292+ private final List <Partition > partitions ;
273293 private final List <SortOrder > sortOrders ;
274294
275295 private Row spare ;
@@ -304,16 +324,19 @@ public TopNOperator(
304324 int topCount ,
305325 List <ElementType > elementTypes ,
306326 List <TopNEncoder > encoders ,
327+ List <Partition > partitions ,
307328 List <SortOrder > sortOrders ,
308329 int maxPageSize
309330 ) {
310331 this .blockFactory = blockFactory ;
311332 this .breaker = breaker ;
333+ this .topCount = topCount ;
312334 this .maxPageSize = maxPageSize ;
313335 this .elementTypes = elementTypes ;
314336 this .encoders = encoders ;
337+ this .partitions = partitions ;
315338 this .sortOrders = sortOrders ;
316- this .inputQueue = new Queue ( topCount );
339+ this .inputQueues = new TreeMap <>( );
317340 }
318341
319342 static int compareRows (Row r1 , Row r2 ) {
@@ -385,6 +408,8 @@ public void addInput(Page page) {
385408 spareKeysPreAllocSize = Math .max (spare .keys .length (), spareKeysPreAllocSize / 2 );
386409 spareValuesPreAllocSize = Math .max (spare .values .length (), spareValuesPreAllocSize / 2 );
387410
411+ String partitionKey = getPartitionKey (page , i );
412+ Queue inputQueue = inputQueues .computeIfAbsent (partitionKey , key -> new Queue (topCount ));
388413 spare = inputQueue .insertWithOverflow (spare );
389414 }
390415 } finally {
@@ -394,6 +419,28 @@ public void addInput(Page page) {
394419 }
395420 }
396421
422+ /**
423+ * Calculates the partition key of the i-th row of the given page.
424+ *
425+ * @param page page for which the partition key should be calculated
426+ * @param i row index
427+ * @return partition key of the i-th row of the given page
428+ */
429+ private String getPartitionKey (Page page , int i ) {
430+ if (partitions .isEmpty ()) {
431+ return "" ;
432+ }
433+ assert page .getPositionCount () > 0 ;
434+ StringBuilder builder = new StringBuilder ();
435+ for (Partition partition : partitions ) {
436+ try (var block = page .getBlock (partition .channel ).filter (i )) {
437+ BytesRef partitionFieldValue = ((BytesRefBlock ) block ).getBytesRef (i , new BytesRef ());
438+ builder .append (partitionFieldValue .utf8ToString ());
439+ }
440+ }
441+ return builder .toString ();
442+ }
443+
397444 @ Override
398445 public void finish () {
399446 if (output == null ) {
@@ -407,14 +454,17 @@ private Iterator<Page> toPages() {
407454 spare .close ();
408455 spare = null ;
409456 }
410- if (inputQueue .size () == 0 ) {
411- return Collections .emptyIterator ();
412- }
413- List <Row > list = new ArrayList <>(inputQueue .size ());
414- List <Page > result = new ArrayList <>();
415- ResultBuilder [] builders = null ;
416457 boolean success = false ;
458+ List <Row > list = null ;
459+ ResultBuilder [] builders = null ;
460+ List <Page > result = new ArrayList <>();
461+ // TODO: optimize case where all the queues are empty
417462 try {
463+ for (var entry : inputQueues .entrySet ()) {
464+ Queue inputQueue = entry .getValue ();
465+
466+ list = new ArrayList <>(inputQueue .size ());
467+ builders = null ;
418468 while (inputQueue .size () > 0 ) {
419469 list .add (inputQueue .pop ());
420470 }
@@ -483,6 +533,7 @@ private Iterator<Page> toPages() {
483533 }
484534 }
485535 assert builders == null ;
536+ }
486537 success = true ;
487538 return result .iterator ();
488539 } finally {
@@ -524,20 +575,20 @@ public Page getOutput() {
524575
525576 @ Override
526577 public void close () {
578+ List <Releasable > releasables = new ArrayList <>();
579+ releasables .addAll (inputQueues .values ().stream ().map (Releasables ::wrap ).toList ());
580+ releasables .add (output == null ? null : Releasables .wrap (() -> Iterators .map (output , p -> p ::releaseBlocks )));
527581 /*
528582 * If we close before calling finish then spare and inputQueue will be live rows
529583 * that need closing. If we close after calling finish then the output iterator
530584 * will contain pages of results that have yet to be returned.
531585 */
532- Releasables .closeExpectNoException (
533- spare ,
534- inputQueue == null ? null : Releasables .wrap (inputQueue ),
535- output == null ? null : Releasables .wrap (() -> Iterators .map (output , p -> p ::releaseBlocks ))
536- );
586+ Releasables .closeExpectNoException (spare , Releasables .wrap (releasables ));
537587 }
538588
539- private static long SHALLOW_SIZE = RamUsageEstimator .shallowSizeOfInstance (TopNOperator .class ) + RamUsageEstimator
540- .shallowSizeOfInstance (List .class ) * 3 ;
589+ private static long SHALLOW_SIZE = RamUsageEstimator .shallowSizeOfInstance (TopNOperator .class )
590+ + RamUsageEstimator .shallowSizeOfInstance (List .class ) * 4
591+ + RamUsageEstimator .shallowSizeOfInstance (Map .class );
541592
542593 @ Override
543594 public long ramBytesUsed () {
@@ -548,25 +599,34 @@ public long ramBytesUsed() {
548599 // These lists may slightly under-count, but it's not likely to be by much.
549600 size += RamUsageEstimator .alignObjectSize (arrHeader + ref * elementTypes .size ());
550601 size += RamUsageEstimator .alignObjectSize (arrHeader + ref * encoders .size ());
602+ size += RamUsageEstimator .alignObjectSize (arrHeader + ref * partitions .size ());
603+ size += partitions .size () * Partition .SHALLOW_SIZE ;
551604 size += RamUsageEstimator .alignObjectSize (arrHeader + ref * sortOrders .size ());
552605 size += sortOrders .size () * SortOrder .SHALLOW_SIZE ;
553- size += inputQueue .ramBytesUsed ();
606+ long ramBytesUsedSum = inputQueues .entrySet ().stream ()
607+ .mapToLong (e -> e .getKey ().getBytes (Charset .defaultCharset ()).length + e .getValue ().ramBytesUsed ())
608+ .sum ();
609+ size += ramBytesUsedSum ;
554610 return size ;
555611 }
556612
557613 @ Override
558614 public Status status () {
559- return new TopNOperatorStatus (inputQueue .size (), ramBytesUsed (), pagesReceived , pagesEmitted , rowsReceived , rowsEmitted );
615+ int queueSizeSum = inputQueues .values ().stream ().mapToInt (Queue ::size ).sum ();
616+ return new TopNOperatorStatus (queueSizeSum , ramBytesUsed (), pagesReceived , pagesEmitted , rowsReceived , rowsEmitted );
560617 }
561618
562619 @ Override
563620 public String toString () {
621+ int queueSizeSum = inputQueues .values ().stream ().mapToInt (Queue ::size ).sum ();
564622 return "TopNOperator[count="
565- + inputQueue
623+ + queueSizeSum + "/" + topCount
566624 + ", elementTypes="
567625 + elementTypes
568626 + ", encoders="
569627 + encoders
628+ + ", partitions="
629+ + partitions
570630 + ", sortOrders="
571631 + sortOrders
572632 + "]" ;
0 commit comments