@@ -73,6 +73,7 @@ public class AggregatorBenchmark {
7373 static final int BLOCK_LENGTH = 8 * 1024 ;
7474 private static final int OP_COUNT = 1024 ;
7575 private static final int GROUPS = 5 ;
76+ private static final int TOP_N_LIMIT = 3 ;
7677
7778 private static final BlockFactory blockFactory = BlockFactory .getInstance (
7879 new NoopCircuitBreaker ("noop" ),
@@ -90,6 +91,7 @@ public class AggregatorBenchmark {
9091 private static final String TWO_ORDINALS = "two_" + ORDINALS ;
9192 private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS ;
9293 private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS ;
94+ private static final String TOP_N_LONGS = "top_n_" + LONGS ;
9395
9496 private static final String VECTOR_DOUBLES = "vector_doubles" ;
9597 private static final String HALF_NULL_DOUBLES = "half_null_doubles" ;
@@ -147,7 +149,8 @@ static void selfTest() {
147149 TWO_BYTES_REFS ,
148150 TWO_ORDINALS ,
149151 LONGS_AND_BYTES_REFS ,
150- TWO_LONGS_AND_BYTES_REFS }
152+ TWO_LONGS_AND_BYTES_REFS ,
153+ TOP_N_LONGS }
151154 )
152155 public String grouping ;
153156
@@ -161,8 +164,7 @@ static void selfTest() {
161164 public String filter ;
162165
163166 private static Operator operator (DriverContext driverContext , String grouping , String op , String dataType , String filter ) {
164-
165- if (grouping .equals ("none" )) {
167+ if (grouping .equals (NONE )) {
166168 return new AggregationOperator (
167169 List .of (supplier (op , dataType , filter ).aggregatorFactory (AggregatorMode .SINGLE , List .of (0 )).apply (driverContext )),
168170 driverContext
@@ -188,6 +190,9 @@ private static Operator operator(DriverContext driverContext, String grouping, S
188190 new BlockHash .GroupSpec (1 , ElementType .LONG ),
189191 new BlockHash .GroupSpec (2 , ElementType .BYTES_REF )
190192 );
193+ case TOP_N_LONGS -> List .of (
194+ new BlockHash .GroupSpec (0 , ElementType .LONG , false , new BlockHash .TopNDef (0 , true , true , TOP_N_LIMIT ))
195+ );
191196 default -> throw new IllegalArgumentException ("unsupported grouping [" + grouping + "]" );
192197 };
193198 return new HashAggregationOperator (
@@ -271,10 +276,14 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
271276 case BOOLEANS -> 2 ;
272277 default -> GROUPS ;
273278 };
279+ int availableGroups = switch (grouping ) {
280+ case TOP_N_LONGS -> TOP_N_LIMIT ;
281+ default -> groups ;
282+ };
274283 switch (op ) {
275284 case AVG -> {
276285 DoubleBlock dValues = (DoubleBlock ) values ;
277- for (int g = 0 ; g < groups ; g ++) {
286+ for (int g = 0 ; g < availableGroups ; g ++) {
278287 long group = g ;
279288 long sum = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).sum ();
280289 long count = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).count ();
@@ -286,7 +295,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
286295 }
287296 case COUNT -> {
288297 LongBlock lValues = (LongBlock ) values ;
289- for (int g = 0 ; g < groups ; g ++) {
298+ for (int g = 0 ; g < availableGroups ; g ++) {
290299 long group = g ;
291300 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).count () * opCount ;
292301 if (lValues .getLong (g ) != expected ) {
@@ -296,7 +305,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
296305 }
297306 case COUNT_DISTINCT -> {
298307 LongBlock lValues = (LongBlock ) values ;
299- for (int g = 0 ; g < groups ; g ++) {
308+ for (int g = 0 ; g < availableGroups ; g ++) {
300309 long group = g ;
301310 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).distinct ().count ();
302311 long count = lValues .getLong (g );
@@ -310,15 +319,15 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
310319 switch (dataType ) {
311320 case LONGS -> {
312321 LongBlock lValues = (LongBlock ) values ;
313- for (int g = 0 ; g < groups ; g ++) {
322+ for (int g = 0 ; g < availableGroups ; g ++) {
314323 if (lValues .getLong (g ) != (long ) g ) {
315324 throw new AssertionError (prefix + "expected [" + g + "] but was [" + lValues .getLong (g ) + "]" );
316325 }
317326 }
318327 }
319328 case DOUBLES -> {
320329 DoubleBlock dValues = (DoubleBlock ) values ;
321- for (int g = 0 ; g < groups ; g ++) {
330+ for (int g = 0 ; g < availableGroups ; g ++) {
322331 if (dValues .getDouble (g ) != (long ) g ) {
323332 throw new AssertionError (prefix + "expected [" + g + "] but was [" + dValues .getDouble (g ) + "]" );
324333 }
@@ -331,7 +340,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
331340 switch (dataType ) {
332341 case LONGS -> {
333342 LongBlock lValues = (LongBlock ) values ;
334- for (int g = 0 ; g < groups ; g ++) {
343+ for (int g = 0 ; g < availableGroups ; g ++) {
335344 long group = g ;
336345 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).max ().getAsLong ();
337346 if (lValues .getLong (g ) != expected ) {
@@ -341,7 +350,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
341350 }
342351 case DOUBLES -> {
343352 DoubleBlock dValues = (DoubleBlock ) values ;
344- for (int g = 0 ; g < groups ; g ++) {
353+ for (int g = 0 ; g < availableGroups ; g ++) {
345354 long group = g ;
346355 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).max ().getAsLong ();
347356 if (dValues .getDouble (g ) != expected ) {
@@ -356,7 +365,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
356365 switch (dataType ) {
357366 case LONGS -> {
358367 LongBlock lValues = (LongBlock ) values ;
359- for (int g = 0 ; g < groups ; g ++) {
368+ for (int g = 0 ; g < availableGroups ; g ++) {
360369 long group = g ;
361370 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).sum () * opCount ;
362371 if (lValues .getLong (g ) != expected ) {
@@ -366,7 +375,7 @@ private static void checkGrouped(String prefix, String grouping, String op, Stri
366375 }
367376 case DOUBLES -> {
368377 DoubleBlock dValues = (DoubleBlock ) values ;
369- for (int g = 0 ; g < groups ; g ++) {
378+ for (int g = 0 ; g < availableGroups ; g ++) {
370379 long group = g ;
371380 long expected = LongStream .range (0 , BLOCK_LENGTH ).filter (l -> l % groups == group ).sum () * opCount ;
372381 if (dValues .getDouble (g ) != expected ) {
@@ -391,6 +400,14 @@ private static void checkGroupingBlock(String prefix, String grouping, Block blo
391400 }
392401 }
393402 }
403+ case TOP_N_LONGS -> {
404+ LongBlock groups = (LongBlock ) block ;
405+ for (int g = 0 ; g < TOP_N_LIMIT ; g ++) {
406+ if (groups .getLong (g ) != (long ) g ) {
407+ throw new AssertionError (prefix + "bad group expected [" + g + "] but was [" + groups .getLong (g ) + "]" );
408+ }
409+ }
410+ }
394411 case INTS -> {
395412 IntBlock groups = (IntBlock ) block ;
396413 for (int g = 0 ; g < GROUPS ; g ++) {
@@ -495,7 +512,7 @@ private static void checkUngrouped(String prefix, String op, String dataType, Pa
495512
496513 private static Page page (BlockFactory blockFactory , String grouping , String blockType ) {
497514 Block dataBlock = dataBlock (blockFactory , blockType );
498- if (grouping .equals ("none" )) {
515+ if (grouping .equals (NONE )) {
499516 return new Page (dataBlock );
500517 }
501518 List <Block > blocks = groupingBlocks (grouping , blockType );
@@ -564,7 +581,7 @@ private static Block groupingBlock(String grouping, String blockType) {
564581 default -> throw new UnsupportedOperationException ("bad grouping [" + grouping + "]" );
565582 };
566583 return switch (grouping ) {
567- case LONGS -> {
584+ case TOP_N_LONGS , LONGS -> {
568585 var builder = blockFactory .newLongBlockBuilder (BLOCK_LENGTH );
569586 for (int i = 0 ; i < BLOCK_LENGTH ; i ++) {
570587 for (int v = 0 ; v < valuesPerGroup ; v ++) {
0 commit comments