@@ -73,6 +73,7 @@ public class AggregatorBenchmark {
7373 static final int BLOCK_LENGTH = 8 * 1024 ;
7474 private static final int OP_COUNT = 1024 ;
7575 private static final int GROUPS = 5 ;
76+ private static final int TOP_N_LIMIT = 3 ;
7677
7778 private static final BlockFactory blockFactory = BlockFactory .getInstance (
7879 new NoopCircuitBreaker ("noop" ),
@@ -90,6 +91,7 @@ public class AggregatorBenchmark {
9091 private static final String TWO_ORDINALS = "two_" + ORDINALS ;
9192 private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS ;
9293 private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS ;
94+ private static final String TOP_N_LONGS = "top_n_" + LONGS ;
9395
9496 private static final String VECTOR_DOUBLES = "vector_doubles" ;
9597 private static final String HALF_NULL_DOUBLES = "half_null_doubles" ;
@@ -147,7 +149,8 @@ static void selfTest() {
147149 TWO_BYTES_REFS ,
148150 TWO_ORDINALS ,
149151 LONGS_AND_BYTES_REFS ,
150- TWO_LONGS_AND_BYTES_REFS }
152+ TWO_LONGS_AND_BYTES_REFS ,
153+ TOP_N_LONGS }
151154 )
152155 public String grouping ;
153156
@@ -161,8 +164,7 @@ static void selfTest() {
161164 public String filter ;
162165
163166 private static Operator operator (DriverContext driverContext , String grouping , String op , String dataType , String filter ) {
164-
165- if (grouping .equals ("none" )) {
167+ if (grouping .equals (NONE )) {
166168 return new AggregationOperator (
167169 List .of (supplier (op , dataType , filter ).aggregatorFactory (AggregatorMode .SINGLE , List .of (0 )).apply (driverContext )),
168170 driverContext
@@ -188,6 +190,12 @@ private static Operator operator(DriverContext driverContext, String grouping, S
188190 new BlockHash .GroupSpec (1 , ElementType .LONG ),
189191 new BlockHash .GroupSpec (2 , ElementType .BYTES_REF )
190192 );
193+ case TOP_N_LONGS -> List .of (new BlockHash .GroupSpec (
194+ 0 ,
195+ ElementType .LONG ,
196+ false ,
197+ new BlockHash .TopNDef (0 , true , true , TOP_N_LIMIT )
198+ ));
191199 default -> throw new IllegalArgumentException ("unsupported grouping [" + grouping + "]" );
192200 };
193201 return new HashAggregationOperator (
@@ -271,10 +279,14 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
271279 case BOOLEANS -> 2 ;
272280 default -> GROUPS ;
273281 };
282+ int availableGroups = switch (grouping ) {
283+ case TOP_N_LONGS -> TOP_N_LIMIT ;
284+ default -> groups ;
285+ };
274286 switch (op ) {
275287 case AVG -> {
276288 DoubleBlock dValues = (DoubleBlock ) values ;
277- for (int g = 0 ; g < groups ; g ++) {
289+ for (int g = 0 ; g < availableGroups ; g ++) {
278290 long group = g ;
279291 long sum = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).sum ();
280292 long count = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).count ();
@@ -286,7 +298,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
286298 }
287299 case COUNT -> {
288300 LongBlock lValues = (LongBlock ) values ;
289- for (int g = 0 ; g < groups ; g ++) {
301+ for (int g = 0 ; g < availableGroups ; g ++) {
290302 long group = g ;
291303 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).count () * opCount ;
292304 if (lValues .getLong (g ) != expected ) {
@@ -296,7 +308,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
296308 }
297309 case COUNT_DISTINCT -> {
298310 LongBlock lValues = (LongBlock ) values ;
299- for (int g = 0 ; g < groups ; g ++) {
311+ for (int g = 0 ; g < availableGroups ; g ++) {
300312 long group = g ;
301313 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).distinct ().count ();
302314 long count = lValues .getLong (g );
@@ -310,15 +322,15 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
310322 switch (dataType ) {
311323 case LONGS -> {
312324 LongBlock lValues = (LongBlock ) values ;
313- for (int g = 0 ; g < groups ; g ++) {
325+ for (int g = 0 ; g < availableGroups ; g ++) {
314326 if (lValues .getLong (g ) != (long ) g ) {
315327 throw new AssertionError (prefix + "expected [" + g + "] but was [" + lValues .getLong (g ) + "]" );
316328 }
317329 }
318330 }
319331 case DOUBLES -> {
320332 DoubleBlock dValues = (DoubleBlock ) values ;
321- for (int g = 0 ; g < groups ; g ++) {
333+ for (int g = 0 ; g < availableGroups ; g ++) {
322334 if (dValues .getDouble (g ) != (long ) g ) {
323335 throw new AssertionError (prefix + "expected [" + g + "] but was [" + dValues .getDouble (g ) + "]" );
324336 }
@@ -331,7 +343,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
331343 switch (dataType ) {
332344 case LONGS -> {
333345 LongBlock lValues = (LongBlock ) values ;
334- for (int g = 0 ; g < groups ; g ++) {
346+ for (int g = 0 ; g < availableGroups ; g ++) {
335347 long group = g ;
336348 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).max ().getAsLong ();
337349 if (lValues .getLong (g ) != expected ) {
@@ -341,7 +353,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
341353 }
342354 case DOUBLES -> {
343355 DoubleBlock dValues = (DoubleBlock ) values ;
344- for (int g = 0 ; g < groups ; g ++) {
356+ for (int g = 0 ; g < availableGroups ; g ++) {
345357 long group = g ;
346358 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).max ().getAsLong ();
347359 if (dValues .getDouble (g ) != expected ) {
@@ -356,7 +368,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
356368 switch (dataType ) {
357369 case LONGS -> {
358370 LongBlock lValues = (LongBlock ) values ;
359- for (int g = 0 ; g < groups ; g ++) {
371+ for (int g = 0 ; g < availableGroups ; g ++) {
360372 long group = g ;
361373 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).sum () * opCount ;
362374 if (lValues .getLong (g ) != expected ) {
@@ -366,7 +378,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
366378 }
367379 case DOUBLES -> {
368380 DoubleBlock dValues = (DoubleBlock ) values ;
369- for (int g = 0 ; g < groups ; g ++) {
381+ for (int g = 0 ; g < availableGroups ; g ++) {
370382 long group = g ;
371383 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).sum () * opCount ;
372384 if (dValues .getDouble (g ) != expected ) {
@@ -391,6 +403,14 @@ private static void checkGroupingBlock(String prefix, String grouping, Block blo
391403 }
392404 }
393405 }
406+ case TOP_N_LONGS -> {
407+ LongBlock groups = (LongBlock ) block ;
408+ for (int g = 0 ; g < TOP_N_LIMIT ; g ++) {
409+ if (groups .getLong (g ) != (long ) g ) {
410+ throw new AssertionError (prefix + "bad group expected [" + g + "] but was [" + groups .getLong (g ) + "]" );
411+ }
412+ }
413+ }
394414 case INTS -> {
395415 IntBlock groups = (IntBlock ) block ;
396416 for (int g = 0 ; g < GROUPS ; g ++) {
@@ -495,7 +515,7 @@ private static void checkUngrouped(String prefix, String op, String dataType, Pa
495515
496516 private static Page page (BlockFactory blockFactory , String grouping , String blockType ) {
497517 Block dataBlock = dataBlock (blockFactory , blockType );
498- if (grouping .equals ("none" )) {
518+ if (grouping .equals (NONE )) {
499519 return new Page (dataBlock );
500520 }
501521 List <Block > blocks = groupingBlocks (grouping , blockType );
@@ -564,7 +584,7 @@ private static Block groupingBlock(String grouping, String blockType) {
564584 default -> throw new UnsupportedOperationException ("bad grouping [" + grouping + "]" );
565585 };
566586 return switch (grouping ) {
567- case LONGS -> {
587+ case TOP_N_LONGS , LONGS -> {
568588 var builder = blockFactory .newLongBlockBuilder (BLOCK_LENGTH );
569589 for (int i = 0 ; i < BLOCK_LENGTH ; i ++) {
570590 for (int v = 0 ; v < valuesPerGroup ; v ++) {
0 commit comments