3636import org .elasticsearch .compute .test .TestBlockBuilder ;
3737import org .elasticsearch .compute .test .TestBlockFactory ;
3838import org .elasticsearch .compute .test .TestDriverFactory ;
39- import org .elasticsearch .core .Releasables ;
4039import org .elasticsearch .core .Tuple ;
4140import org .elasticsearch .indices .CrankyCircuitBreakerService ;
4241import org .elasticsearch .test .ESTestCase ;
9089import static org .hamcrest .Matchers .equalTo ;
9190import static org .hamcrest .Matchers .greaterThan ;
9291import static org .hamcrest .Matchers .hasSize ;
93- import static org .hamcrest .Matchers .in ;
9492import static org .hamcrest .Matchers .is ;
9593import static org .hamcrest .Matchers .lessThan ;
9694import static org .hamcrest .Matchers .lessThanOrEqualTo ;
@@ -316,22 +314,192 @@ private List<Long> topNLong(List<Long> inputValues, int limit, boolean ascending
316314 }
317315
318316 public void testBasicTopNWithPartitionField () {
319- List <Tuple <Long , BytesRef >> values = denormalize (List .of (
320- tuple (new BytesRef ("a" ), Arrays .asList (2L , 1L , 4L , null , 5L , 10L , null , 20L , 4L , 100L )),
321- tuple (new BytesRef ("b" ), Arrays .asList (6L , 5L , 8L , null , 9L , 14L , null , 24L , 8L , 104L )),
322- tuple (new BytesRef ("c" ), Arrays .asList (-2L , -1L , -4L , null , -5L , -10L , null , -20L , -4L , -100L )),
323- tuple (new BytesRef ("" ), Arrays .asList (1L , 2L , 3L ))
324- ));
325- assertThat (topNLongWithPartitionField (driverContext (), values , 1 , true , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (1L , "" ), tuple (1L , "a" ), tuple (5L , "b" ), tuple (-100L , "c" )))));
326- assertThat (topNLongWithPartitionField (driverContext (), values , 1 , false , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (3L , "" ), tuple (100L , "a" ), tuple (104L , "b" ), tuple (-1L , "c" )))));
327- assertThat (topNLongWithPartitionField (driverContext (), values , 2 , true , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (1L , "" ), tuple (2L , "" ), tuple (1L , "a" ), tuple (2L , "a" ), tuple (5L , "b" ), tuple (6L , "b" ), tuple (-100L , "c" ), tuple (-20L , "c" )))));
328- assertThat (topNLongWithPartitionField (driverContext (), values , 2 , false , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (3L , "" ), tuple (2L , "" ), tuple (100L , "a" ), tuple (20L , "a" ), tuple (104L , "b" ), tuple (24L , "b" ), tuple (-1L , "c" ), tuple (-2L , "c" )))));
329- assertThat (topNLongWithPartitionField (driverContext (), values , 3 , true , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (1L , "" ), tuple (2L , "" ), tuple (3L , "" ), tuple (1L , "a" ), tuple (2L , "a" ), tuple (4L , "a" ), tuple (5L , "b" ), tuple (6L , "b" ), tuple (8L , "b" ), tuple (-100L , "c" ), tuple (-20L , "c" ), tuple (-10L , "c" )))));
330- assertThat (topNLongWithPartitionField (driverContext (), values , 3 , false , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (3L , "" ), tuple (2L , "" ), tuple (1L , "" ), tuple (100L , "a" ), tuple (20L , "a" ), tuple (10L , "a" ), tuple (104L , "b" ), tuple (24L , "b" ), tuple (14L , "b" ), tuple (-1L , "c" ), tuple (-2L , "c" ), tuple (-4L , "c" )))));
331- assertThat (topNLongWithPartitionField (driverContext (), values , 4 , true , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (1L , "" ), tuple (2L , "" ), tuple (3L , "" ), tuple (1L , "a" ), tuple (2L , "a" ), tuple (4L , "a" ), tuple (4L , "a" ), tuple (5L , "b" ), tuple (6L , "b" ), tuple (8L , "b" ), tuple (8L , "b" ), tuple (-100L , "c" ), tuple (-20L , "c" ), tuple (-10L , "c" ), tuple (-5L , "c" )))));
332- assertThat (topNLongWithPartitionField (driverContext (), values , 4 , false , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (3L , "" ), tuple (2L , "" ), tuple (1L , "" ), tuple (100L , "a" ), tuple (20L , "a" ), tuple (10L , "a" ), tuple (5L , "a" ), tuple (104L , "b" ), tuple (24L , "b" ), tuple (14L , "b" ), tuple (9L , "b" ), tuple (-1L , "c" ), tuple (-2L , "c" ), tuple (-4L , "c" ), tuple (-4L , "c" )))));
333- assertThat (topNLongWithPartitionField (driverContext (), values , 100 , true , false ), equalTo (stringsToBytesRefs (Arrays .asList (tuple (1L , "" ), tuple (2L , "" ), tuple (3L , "" ), tuple (1L , "a" ), tuple (2L , "a" ), tuple (4L , "a" ), tuple (4L , "a" ), tuple (5L , "a" ), tuple (10L , "a" ), tuple (20L , "a" ), tuple (100L , "a" ), tuple (null , "a" ), tuple (null , "a" ), tuple (5L , "b" ), tuple (6L , "b" ), tuple (8L , "b" ), tuple (8L , "b" ), tuple (9L , "b" ), tuple (14L , "b" ), tuple (24L , "b" ), tuple (104L , "b" ), tuple (null , "b" ), tuple (null , "b" ), tuple (-100L , "c" ), tuple (-20L , "c" ), tuple (-10L , "c" ), tuple (-5L , "c" ), tuple (-4L , "c" ), tuple (-4L , "c" ), tuple (-2L , "c" ), tuple (-1L , "c" ), tuple (null , "c" ), tuple (null , "c" )))));
334- assertThat (topNLongWithPartitionField (driverContext (), values , 100 , false , false ), equalTo (Arrays .asList (100L , 20L , 10L , 5L , 4L , 4L , 2L , 1L , null , null )));
317+ List <Tuple <Long , BytesRef >> values = denormalize (
318+ List .of (
319+ tuple (new BytesRef ("a" ), Arrays .asList (2L , 1L , 4L , null , 5L , 10L , null , 20L , 4L , 100L )),
320+ tuple (new BytesRef ("b" ), Arrays .asList (6L , 5L , 8L , null , 9L , 14L , null , 24L , 8L , 104L )),
321+ tuple (new BytesRef ("c" ), Arrays .asList (-2L , -1L , -4L , null , -5L , -10L , null , -20L , -4L , -100L )),
322+ tuple (new BytesRef ("" ), Arrays .asList (1L , 2L , 3L ))
323+ )
324+ );
325+ assertThat (
326+ topNLongWithPartitionField (driverContext (), values , 1 , true , false ),
327+ equalTo (stringsToBytesRefs (Arrays .asList (tuple (1L , "" ), tuple (1L , "a" ), tuple (5L , "b" ), tuple (-100L , "c" ))))
328+ );
329+ assertThat (
330+ topNLongWithPartitionField (driverContext (), values , 1 , false , false ),
331+ equalTo (stringsToBytesRefs (Arrays .asList (tuple (3L , "" ), tuple (100L , "a" ), tuple (104L , "b" ), tuple (-1L , "c" ))))
332+ );
333+ assertThat (
334+ topNLongWithPartitionField (driverContext (), values , 2 , true , false ),
335+ equalTo (
336+ stringsToBytesRefs (
337+ Arrays .asList (
338+ tuple (1L , "" ),
339+ tuple (2L , "" ),
340+ tuple (1L , "a" ),
341+ tuple (2L , "a" ),
342+ tuple (5L , "b" ),
343+ tuple (6L , "b" ),
344+ tuple (-100L , "c" ),
345+ tuple (-20L , "c" )
346+ )
347+ )
348+ )
349+ );
350+ assertThat (
351+ topNLongWithPartitionField (driverContext (), values , 2 , false , false ),
352+ equalTo (
353+ stringsToBytesRefs (
354+ Arrays .asList (
355+ tuple (3L , "" ),
356+ tuple (2L , "" ),
357+ tuple (100L , "a" ),
358+ tuple (20L , "a" ),
359+ tuple (104L , "b" ),
360+ tuple (24L , "b" ),
361+ tuple (-1L , "c" ),
362+ tuple (-2L , "c" )
363+ )
364+ )
365+ )
366+ );
367+ assertThat (
368+ topNLongWithPartitionField (driverContext (), values , 3 , true , false ),
369+ equalTo (
370+ stringsToBytesRefs (
371+ Arrays .asList (
372+ tuple (1L , "" ),
373+ tuple (2L , "" ),
374+ tuple (3L , "" ),
375+ tuple (1L , "a" ),
376+ tuple (2L , "a" ),
377+ tuple (4L , "a" ),
378+ tuple (5L , "b" ),
379+ tuple (6L , "b" ),
380+ tuple (8L , "b" ),
381+ tuple (-100L , "c" ),
382+ tuple (-20L , "c" ),
383+ tuple (-10L , "c" )
384+ )
385+ )
386+ )
387+ );
388+ assertThat (
389+ topNLongWithPartitionField (driverContext (), values , 3 , false , false ),
390+ equalTo (
391+ stringsToBytesRefs (
392+ Arrays .asList (
393+ tuple (3L , "" ),
394+ tuple (2L , "" ),
395+ tuple (1L , "" ),
396+ tuple (100L , "a" ),
397+ tuple (20L , "a" ),
398+ tuple (10L , "a" ),
399+ tuple (104L , "b" ),
400+ tuple (24L , "b" ),
401+ tuple (14L , "b" ),
402+ tuple (-1L , "c" ),
403+ tuple (-2L , "c" ),
404+ tuple (-4L , "c" )
405+ )
406+ )
407+ )
408+ );
409+ assertThat (
410+ topNLongWithPartitionField (driverContext (), values , 4 , true , false ),
411+ equalTo (
412+ stringsToBytesRefs (
413+ Arrays .asList (
414+ tuple (1L , "" ),
415+ tuple (2L , "" ),
416+ tuple (3L , "" ),
417+ tuple (1L , "a" ),
418+ tuple (2L , "a" ),
419+ tuple (4L , "a" ),
420+ tuple (4L , "a" ),
421+ tuple (5L , "b" ),
422+ tuple (6L , "b" ),
423+ tuple (8L , "b" ),
424+ tuple (8L , "b" ),
425+ tuple (-100L , "c" ),
426+ tuple (-20L , "c" ),
427+ tuple (-10L , "c" ),
428+ tuple (-5L , "c" )
429+ )
430+ )
431+ )
432+ );
433+ assertThat (
434+ topNLongWithPartitionField (driverContext (), values , 4 , false , false ),
435+ equalTo (
436+ stringsToBytesRefs (
437+ Arrays .asList (
438+ tuple (3L , "" ),
439+ tuple (2L , "" ),
440+ tuple (1L , "" ),
441+ tuple (100L , "a" ),
442+ tuple (20L , "a" ),
443+ tuple (10L , "a" ),
444+ tuple (5L , "a" ),
445+ tuple (104L , "b" ),
446+ tuple (24L , "b" ),
447+ tuple (14L , "b" ),
448+ tuple (9L , "b" ),
449+ tuple (-1L , "c" ),
450+ tuple (-2L , "c" ),
451+ tuple (-4L , "c" ),
452+ tuple (-4L , "c" )
453+ )
454+ )
455+ )
456+ );
457+ assertThat (
458+ topNLongWithPartitionField (driverContext (), values , 100 , true , false ),
459+ equalTo (
460+ stringsToBytesRefs (
461+ Arrays .asList (
462+ tuple (1L , "" ),
463+ tuple (2L , "" ),
464+ tuple (3L , "" ),
465+ tuple (1L , "a" ),
466+ tuple (2L , "a" ),
467+ tuple (4L , "a" ),
468+ tuple (4L , "a" ),
469+ tuple (5L , "a" ),
470+ tuple (10L , "a" ),
471+ tuple (20L , "a" ),
472+ tuple (100L , "a" ),
473+ tuple (null , "a" ),
474+ tuple (null , "a" ),
475+ tuple (5L , "b" ),
476+ tuple (6L , "b" ),
477+ tuple (8L , "b" ),
478+ tuple (8L , "b" ),
479+ tuple (9L , "b" ),
480+ tuple (14L , "b" ),
481+ tuple (24L , "b" ),
482+ tuple (104L , "b" ),
483+ tuple (null , "b" ),
484+ tuple (null , "b" ),
485+ tuple (-100L , "c" ),
486+ tuple (-20L , "c" ),
487+ tuple (-10L , "c" ),
488+ tuple (-5L , "c" ),
489+ tuple (-4L , "c" ),
490+ tuple (-4L , "c" ),
491+ tuple (-2L , "c" ),
492+ tuple (-1L , "c" ),
493+ tuple (null , "c" ),
494+ tuple (null , "c" )
495+ )
496+ )
497+ )
498+ );
499+ assertThat (
500+ topNLongWithPartitionField (driverContext (), values , 100 , false , false ),
501+ equalTo (Arrays .asList (100L , 20L , 10L , 5L , 4L , 4L , 2L , 1L , null , null ))
502+ );
335503 assertThat (topNLongWithPartitionField (driverContext (), values , 1 , true , true ), equalTo (Arrays .asList (new Long [] { null })));
336504 assertThat (topNLongWithPartitionField (driverContext (), values , 1 , false , true ), equalTo (Arrays .asList (new Long [] { null })));
337505 assertThat (topNLongWithPartitionField (driverContext (), values , 2 , true , true ), equalTo (Arrays .asList (null , null )));
@@ -340,8 +508,14 @@ public void testBasicTopNWithPartitionField() {
340508 assertThat (topNLongWithPartitionField (driverContext (), values , 3 , false , true ), equalTo (Arrays .asList (null , null , 100L )));
341509 assertThat (topNLongWithPartitionField (driverContext (), values , 4 , true , true ), equalTo (Arrays .asList (null , null , 1L , 2L )));
342510 assertThat (topNLongWithPartitionField (driverContext (), values , 4 , false , true ), equalTo (Arrays .asList (null , null , 100L , 20L )));
343- assertThat (topNLongWithPartitionField (driverContext (), values , 100 , true , true ), equalTo (Arrays .asList (null , null , 1L , 2L , 4L , 4L , 5L , 10L , 20L , 100L )));
344- assertThat (topNLongWithPartitionField (driverContext (), values , 100 , false , true ), equalTo (Arrays .asList (null , null , 100L , 20L , 10L , 5L , 4L , 4L , 2L , 1L )));
511+ assertThat (
512+ topNLongWithPartitionField (driverContext (), values , 100 , true , true ),
513+ equalTo (Arrays .asList (null , null , 1L , 2L , 4L , 4L , 5L , 10L , 20L , 100L ))
514+ );
515+ assertThat (
516+ topNLongWithPartitionField (driverContext (), values , 100 , false , true ),
517+ equalTo (Arrays .asList (null , null , 100L , 20L , 10L , 5L , 4L , 4L , 2L , 1L ))
518+ );
345519 }
346520
347521 private static List <Tuple <Long , BytesRef >> denormalize (List <Tuple <BytesRef , List <Long >>> values ) {
@@ -375,24 +549,42 @@ private void testRandomTopNWithPartitionField(boolean ascendingOrder, DriverCont
375549 final boolean nullsFirst = randomBoolean ();
376550 final int noPartitions = randomIntBetween (1 , 20 );
377551 final int limit = randomIntBetween (1 , 20 );
378- List <Tuple <Long , BytesRef >> inputValues = randomList (0 , 5000 , () ->
379- Tuple .tuple (
380- randomLongBetween (-10_000 , 10_000 ),
381- new BytesRef ("partition_" + randomIntBetween (1 , noPartitions ))
382- )
552+ List <Tuple <Long , BytesRef >> inputValues = randomList (
553+ 0 ,
554+ 5000 ,
555+ () -> Tuple .tuple (randomLongBetween (-10_000 , 10_000 ), new BytesRef ("partition_" + randomIntBetween (1 , noPartitions )))
383556 );
384557 List <Tuple <Long , BytesRef >> expectedOutputValues = inputValues .stream ()
385- .collect (groupingBy (Tuple ::v2 , TreeMap ::new , collectingAndThen (toList (),
386- values -> values .stream ().sorted (valueComparator ).limit (limit ).collect (toList ()))))
387- .entrySet ().stream ().flatMap (e -> e .getValue ().stream ()).toList ();
558+ .collect (
559+ groupingBy (
560+ Tuple ::v2 ,
561+ TreeMap ::new ,
562+ collectingAndThen (toList (), values -> values .stream ().sorted (valueComparator ).limit (limit ).collect (toList ()))
563+ )
564+ )
565+ .entrySet ()
566+ .stream ()
567+ .flatMap (e -> e .getValue ().stream ())
568+ .toList ();
388569
389- List <Tuple <Long , BytesRef >> actualOutputValues =
390- topNLongWithPartitionField (driverContext , inputValues , limit , ascendingOrder , nullsFirst );
570+ List <Tuple <Long , BytesRef >> actualOutputValues = topNLongWithPartitionField (
571+ driverContext ,
572+ inputValues ,
573+ limit ,
574+ ascendingOrder ,
575+ nullsFirst
576+ );
391577
392578 assertThat (actualOutputValues , equalTo (expectedOutputValues ));
393579 }
394580
395- private List <Tuple <Long , BytesRef >> topNLongWithPartitionField (DriverContext driverContext , List <Tuple <Long , BytesRef >> inputValues , int limit , boolean ascendingOrder , boolean nullsFirst ) {
581+ private List <Tuple <Long , BytesRef >> topNLongWithPartitionField (
582+ DriverContext driverContext ,
583+ List <Tuple <Long , BytesRef >> inputValues ,
584+ int limit ,
585+ boolean ascendingOrder ,
586+ boolean nullsFirst
587+ ) {
396588 List <Page > outputPages = new ArrayList <>();
397589 List <Tuple <Long , BytesRef >> actualOutputValues = new ArrayList <>();
398590 try (
@@ -424,7 +616,7 @@ private List<Tuple<Long, BytesRef>> topNLongWithPartitionField(DriverContext dri
424616 runDriver (driver );
425617 }
426618 assertDriverContext (driverContext );
427- // outputPages.forEach(Page::releaseBlocks);
619+ // outputPages.forEach(Page::releaseBlocks);
428620 return actualOutputValues ;
429621 }
430622
0 commit comments