1717import org .elasticsearch .compute .aggregation .CountAggregatorFunction ;
1818import org .elasticsearch .compute .aggregation .CountDistinctDoubleAggregatorFunctionSupplier ;
1919import org .elasticsearch .compute .aggregation .CountDistinctLongAggregatorFunctionSupplier ;
20+ import org .elasticsearch .compute .aggregation .FilteredAggregatorFunctionSupplier ;
2021import org .elasticsearch .compute .aggregation .MaxDoubleAggregatorFunctionSupplier ;
2122import org .elasticsearch .compute .aggregation .MaxLongAggregatorFunctionSupplier ;
2223import org .elasticsearch .compute .aggregation .MinDoubleAggregatorFunctionSupplier ;
2728import org .elasticsearch .compute .data .Block ;
2829import org .elasticsearch .compute .data .BlockFactory ;
2930import org .elasticsearch .compute .data .BooleanBlock ;
31+ import org .elasticsearch .compute .data .BooleanVector ;
3032import org .elasticsearch .compute .data .BytesRefBlock ;
3133import org .elasticsearch .compute .data .DoubleBlock ;
3234import org .elasticsearch .compute .data .ElementType ;
3537import org .elasticsearch .compute .data .Page ;
3638import org .elasticsearch .compute .operator .AggregationOperator ;
3739import org .elasticsearch .compute .operator .DriverContext ;
40+ import org .elasticsearch .compute .operator .EvalOperator ;
3841import org .elasticsearch .compute .operator .HashAggregationOperator ;
3942import org .elasticsearch .compute .operator .Operator ;
4043import org .openjdk .jmh .annotations .Benchmark ;
@@ -94,13 +97,20 @@ public class AggregatorBenchmark {
9497
9598 private static final String NONE = "none" ;
9699
100+ private static final String CONSTANT_TRUE = "constant_true" ;
101+ private static final String ALL_TRUE = "all_true" ;
102+ private static final String HALF_TRUE = "half_true" ;
103+ private static final String CONSTANT_FALSE = "constant_false" ;
104+
97105 static {
98106 // Smoke test all the expected values and force loading subclasses more like prod
99107 try {
100108 for (String grouping : AggregatorBenchmark .class .getField ("grouping" ).getAnnotationsByType (Param .class )[0 ].value ()) {
101109 for (String op : AggregatorBenchmark .class .getField ("op" ).getAnnotationsByType (Param .class )[0 ].value ()) {
102110 for (String blockType : AggregatorBenchmark .class .getField ("blockType" ).getAnnotationsByType (Param .class )[0 ].value ()) {
103- run (grouping , op , blockType , 50 );
111+ for (String filter : AggregatorBenchmark .class .getField ("filter" ).getAnnotationsByType (Param .class )[0 ].value ()) {
112+ run (grouping , op , blockType , filter , 10 );
113+ }
104114 }
105115 }
106116 }
@@ -118,10 +128,14 @@ public class AggregatorBenchmark {
118128 @ Param ({ VECTOR_LONGS , HALF_NULL_LONGS , VECTOR_DOUBLES , HALF_NULL_DOUBLES })
119129 public String blockType ;
120130
121- private static Operator operator (DriverContext driverContext , String grouping , String op , String dataType ) {
131+ @ Param ({ NONE , CONSTANT_TRUE , ALL_TRUE , HALF_TRUE , CONSTANT_FALSE })
132+ public String filter ;
133+
134+ private static Operator operator (DriverContext driverContext , String grouping , String op , String dataType , String filter ) {
135+
122136 if (grouping .equals ("none" )) {
123137 return new AggregationOperator (
124- List .of (supplier (op , dataType , 0 ).aggregatorFactory (AggregatorMode .SINGLE ).apply (driverContext )),
138+ List .of (supplier (op , dataType , filter , 0 ).aggregatorFactory (AggregatorMode .SINGLE ).apply (driverContext )),
125139 driverContext
126140 );
127141 }
@@ -144,14 +158,14 @@ private static Operator operator(DriverContext driverContext, String grouping, S
144158 default -> throw new IllegalArgumentException ("unsupported grouping [" + grouping + "]" );
145159 };
146160 return new HashAggregationOperator (
147- List .of (supplier (op , dataType , groups .size ()).groupingAggregatorFactory (AggregatorMode .SINGLE )),
161+ List .of (supplier (op , dataType , filter , groups .size ()).groupingAggregatorFactory (AggregatorMode .SINGLE )),
148162 () -> BlockHash .build (groups , driverContext .blockFactory (), 16 * 1024 , false ),
149163 driverContext
150164 );
151165 }
152166
153- private static AggregatorFunctionSupplier supplier (String op , String dataType , int dataChannel ) {
154- return switch (op ) {
167+ private static AggregatorFunctionSupplier supplier (String op , String dataType , String filter , int dataChannel ) {
168+ return filtered ( switch (op ) {
155169 case COUNT -> CountAggregatorFunction .supplier (List .of (dataChannel ));
156170 case COUNT_DISTINCT -> switch (dataType ) {
157171 case LONGS -> new CountDistinctLongAggregatorFunctionSupplier (List .of (dataChannel ), 3000 );
@@ -174,10 +188,22 @@ private static AggregatorFunctionSupplier supplier(String op, String dataType, i
174188 default -> throw new IllegalArgumentException ("unsupported data type [" + dataType + "]" );
175189 };
176190 default -> throw new IllegalArgumentException ("unsupported op [" + op + "]" );
177- };
191+ }, filter ) ;
178192 }
179193
180- private static void checkExpected (String grouping , String op , String blockType , String dataType , Page page , int opCount ) {
194+ private static void checkExpected (
195+ String grouping ,
196+ String op ,
197+ String blockType ,
198+ String filter ,
199+ String dataType ,
200+ Page page ,
201+ int opCount
202+ ) {
203+ if (filter .equals (CONSTANT_FALSE ) || filter .equals (HALF_TRUE )) {
204+ // We don't verify these because it's hard to get the right answer.
205+ return ;
206+ }
181207 String prefix = String .format ("[%s][%s][%s] " , grouping , op , blockType );
182208 if (grouping .equals ("none" )) {
183209 checkUngrouped (prefix , op , dataType , page , opCount );
@@ -559,27 +585,73 @@ private static BytesRef bytesGroup(int group) {
559585 });
560586 }
561587
588+ private static AggregatorFunctionSupplier filtered (AggregatorFunctionSupplier agg , String filter ) {
589+ if (filter .equals ("none" )) {
590+ return agg ;
591+ }
592+ BooleanBlock mask = mask (filter ).asBlock ();
593+ return new FilteredAggregatorFunctionSupplier (agg , context -> new EvalOperator .ExpressionEvaluator () {
594+ @ Override
595+ public Block eval (Page page ) {
596+ mask .incRef ();
597+ return mask ;
598+ }
599+
600+ @ Override
601+ public void close () {
602+ mask .close ();
603+ }
604+ });
605+ }
606+
607+ private static BooleanVector mask (String filter ) {
608+ // Usually BLOCK_LENGTH is the count of positions, but sometimes the blocks are longer
609+ int positionCount = BLOCK_LENGTH * 10 ;
610+ return switch (filter ) {
611+ case CONSTANT_TRUE -> blockFactory .newConstantBooleanVector (true , positionCount );
612+ case ALL_TRUE -> {
613+ try (BooleanVector .Builder builder = blockFactory .newBooleanVectorFixedBuilder (positionCount )) {
614+ for (int i = 0 ; i < positionCount ; i ++) {
615+ builder .appendBoolean (true );
616+ }
617+ yield builder .build ();
618+ }
619+ }
620+ case HALF_TRUE -> {
621+ try (BooleanVector .Builder builder = blockFactory .newBooleanVectorFixedBuilder (positionCount )) {
622+ for (int i = 0 ; i < positionCount ; i ++) {
623+ builder .appendBoolean (i % 2 == 0 );
624+ }
625+ yield builder .build ();
626+ }
627+ }
628+ case CONSTANT_FALSE -> blockFactory .newConstantBooleanVector (false , positionCount );
629+ default -> throw new IllegalArgumentException ("unsupported filter [" + filter + "]" );
630+ };
631+ }
632+
562633 @ Benchmark
563634 @ OperationsPerInvocation (OP_COUNT * BLOCK_LENGTH )
564635 public void run () {
565- run (grouping , op , blockType , OP_COUNT );
636+ run (grouping , op , blockType , filter , OP_COUNT );
566637 }
567638
568- private static void run (String grouping , String op , String blockType , int opCount ) {
639+ private static void run (String grouping , String op , String blockType , String filter , int opCount ) {
640+ // System.err.printf("[%s][%s][%s][%s][%s]\n", grouping, op, blockType, filter, opCount);
569641 String dataType = switch (blockType ) {
570642 case VECTOR_LONGS , HALF_NULL_LONGS -> LONGS ;
571643 case VECTOR_DOUBLES , HALF_NULL_DOUBLES -> DOUBLES ;
572644 default -> throw new IllegalArgumentException ();
573645 };
574646
575647 DriverContext driverContext = driverContext ();
576- try (Operator operator = operator (driverContext , grouping , op , dataType )) {
648+ try (Operator operator = operator (driverContext , grouping , op , dataType , filter )) {
577649 Page page = page (driverContext .blockFactory (), grouping , blockType );
578650 for (int i = 0 ; i < opCount ; i ++) {
579651 operator .addInput (page .shallowCopy ());
580652 }
581653 operator .finish ();
582- checkExpected (grouping , op , blockType , dataType , operator .getOutput (), opCount );
654+ checkExpected (grouping , op , blockType , filter , dataType , operator .getOutput (), opCount );
583655 }
584656 }
585657
0 commit comments