77
88package org .elasticsearch .compute .aggregation ;
99
10+ import org .elasticsearch .common .util .BitArray ;
1011import org .elasticsearch .compute .data .Block ;
1112import org .elasticsearch .compute .data .BooleanBlock ;
1213import org .elasticsearch .compute .data .BooleanVector ;
2223public class PresentGroupingAggregatorFunction implements GroupingAggregatorFunction {
2324
2425 private static final List <IntermediateStateDesc > INTERMEDIATE_STATE_DESC = List .of (
25- new IntermediateStateDesc ("present" , ElementType .BOOLEAN ),
26- new IntermediateStateDesc ("seen" , ElementType .BOOLEAN )
26+ new IntermediateStateDesc ("present" , ElementType .BOOLEAN )
2727 );
2828
29- private final BooleanArrayState state ;
29+ private final BitArray state ;
3030 private final List <Integer > channels ;
3131 private final DriverContext driverContext ;
3232
3333 public static PresentGroupingAggregatorFunction create (DriverContext driverContext , List <Integer > inputChannels ) {
34- return new PresentGroupingAggregatorFunction (inputChannels , new BooleanArrayState ( driverContext .bigArrays (), false ), driverContext );
34+ return new PresentGroupingAggregatorFunction (inputChannels , new BitArray ( 1 , driverContext .bigArrays ()), driverContext );
3535 }
3636
3737 public static List <IntermediateStateDesc > intermediateStateDesc () {
3838 return INTERMEDIATE_STATE_DESC ;
3939 }
4040
41- private PresentGroupingAggregatorFunction (List <Integer > channels , BooleanArrayState state , DriverContext driverContext ) {
41+ private PresentGroupingAggregatorFunction (List <Integer > channels , BitArray state , DriverContext driverContext ) {
4242 this .channels = channels ;
4343 this .state = state ;
4444 this .driverContext = driverContext ;
@@ -57,10 +57,6 @@ public int intermediateBlockCount() {
5757 public AddInput prepareProcessRawInputPage (SeenGroupIds seenGroupIds , Page page ) {
5858 Block valuesBlock = page .getBlock (blockIndex ());
5959
60- if (valuesBlock .mayHaveNulls ()) {
61- state .enableGroupIdTracking (seenGroupIds );
62- }
63-
6460 return new AddInput () {
6561 @ Override
6662 public void add (int positionOffset , IntArrayBlock groupIds ) {
@@ -88,8 +84,7 @@ private void addRawInput(int positionOffset, IntVector groups, Block values) {
8884 if (values .isNull (position )) {
8985 continue ;
9086 }
91- int groupId = groups .getInt (groupPosition );
92- state .set (groupId , state .getOrDefault (groupId ) || values .getValueCount (position ) > 0 );
87+ state .set (groups .getInt (groupPosition ), true );
9388 }
9489 }
9590
@@ -102,8 +97,7 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, Block values)
10297 int groupStart = groups .getFirstValueIndex (groupPosition );
10398 int groupEnd = groupStart + groups .getValueCount (groupPosition );
10499 for (int g = groupStart ; g < groupEnd ; g ++) {
105- int groupId = groups .getInt (g );
106- state .set (groupId , state .getOrDefault (groupId ) || values .getValueCount (position ) > 0 );
100+ state .set (groups .getInt (g ), true );
107101 }
108102 }
109103 }
@@ -117,34 +111,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block valu
117111 int groupStart = groups .getFirstValueIndex (groupPosition );
118112 int groupEnd = groupStart + groups .getValueCount (groupPosition );
119113 for (int g = groupStart ; g < groupEnd ; g ++) {
120- int groupId = groups .getInt (g );
121- state .set (groupId , state .getOrDefault (groupId ) || values .getValueCount (position ) > 0 );
114+ state .set (groups .getInt (g ), true );
122115 }
123116 }
124117 }
125118
126119 @ Override
127- public void selectedMayContainUnseenGroups (SeenGroupIds seenGroupIds ) {
128- state .enableGroupIdTracking (seenGroupIds );
129- }
120+ public void selectedMayContainUnseenGroups (SeenGroupIds seenGroupIds ) {}
130121
131122 @ Override
132123 public void addIntermediateInput (int positionOffset , IntArrayBlock groups , Page page ) {
133124 assert channels .size () == intermediateBlockCount ();
134125 assert page .getBlockCount () >= blockIndex () + intermediateStateDesc ().size ();
135- state .enableGroupIdTracking (new SeenGroupIds .Empty ());
136126 BooleanVector present = page .<BooleanBlock >getBlock (channels .get (0 )).asVector ();
137- BooleanVector seen = page .<BooleanBlock >getBlock (channels .get (1 )).asVector ();
138- assert present .getPositionCount () == seen .getPositionCount ();
139127 for (int groupPosition = 0 ; groupPosition < groups .getPositionCount (); groupPosition ++) {
140128 if (groups .isNull (groupPosition )) {
141129 continue ;
142130 }
143131 int groupStart = groups .getFirstValueIndex (groupPosition );
144132 int groupEnd = groupStart + groups .getValueCount (groupPosition );
145133 for (int g = groupStart ; g < groupEnd ; g ++) {
146- int groupId = groups .getInt (g );
147- state .set (groupId , state .getOrDefault (groupId ) || present .getBoolean (groupPosition + positionOffset ));
134+ if (present .getBoolean (groupPosition + positionOffset )) {
135+ state .set (groups .getInt (g ), true );
136+ }
148137 }
149138 }
150139 }
@@ -153,19 +142,17 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page
153142 public void addIntermediateInput (int positionOffset , IntBigArrayBlock groups , Page page ) {
154143 assert channels .size () == intermediateBlockCount ();
155144 assert page .getBlockCount () >= blockIndex () + intermediateStateDesc ().size ();
156- state .enableGroupIdTracking (new SeenGroupIds .Empty ());
157145 BooleanVector present = page .<BooleanBlock >getBlock (channels .get (0 )).asVector ();
158- BooleanVector seen = page .<BooleanBlock >getBlock (channels .get (1 )).asVector ();
159- assert present .getPositionCount () == seen .getPositionCount ();
160146 for (int groupPosition = 0 ; groupPosition < groups .getPositionCount (); groupPosition ++) {
161147 if (groups .isNull (groupPosition )) {
162148 continue ;
163149 }
164150 int groupStart = groups .getFirstValueIndex (groupPosition );
165151 int groupEnd = groupStart + groups .getValueCount (groupPosition );
166152 for (int g = groupStart ; g < groupEnd ; g ++) {
167- int groupId = groups .getInt (g );
168- state .set (groupId , state .getOrDefault (groupId ) || present .getBoolean (groupPosition + positionOffset ));
153+ if (present .getBoolean (groupPosition + positionOffset )) {
154+ state .set (groups .getInt (g ), true );
155+ }
169156 }
170157 }
171158 }
@@ -174,27 +161,35 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa
174161 public void addIntermediateInput (int positionOffset , IntVector groups , Page page ) {
175162 assert channels .size () == intermediateBlockCount ();
176163 assert page .getBlockCount () >= blockIndex () + intermediateStateDesc ().size ();
177- state .enableGroupIdTracking (new SeenGroupIds .Empty ());
178164 BooleanVector present = page .<BooleanBlock >getBlock (channels .get (0 )).asVector ();
179- BooleanVector seen = page .<BooleanBlock >getBlock (channels .get (1 )).asVector ();
180- assert present .getPositionCount () == seen .getPositionCount ();
181165 for (int groupPosition = 0 ; groupPosition < groups .getPositionCount (); groupPosition ++) {
182- int groupId = groups .getInt (groupPosition );
183- state .set (groupId , state .getOrDefault (groupId ) || present .getBoolean (groupPosition + positionOffset ));
166+ if (present .getBoolean (groupPosition + positionOffset )) {
167+ state .set (groups .getInt (groupPosition ), true );
168+ }
184169 }
185170 }
186171
187172 @ Override
188173 public void evaluateIntermediate (Block [] blocks , int offset , IntVector selected ) {
189- state .toIntermediate (blocks , offset , selected , driverContext );
174+ try (var valuesBuilder = driverContext .blockFactory ().newBooleanBlockBuilder (selected .getPositionCount ())) {
175+ for (int i = 0 ; i < selected .getPositionCount (); i ++) {
176+ int group = selected .getInt (i );
177+ if (group < state .size ()) {
178+ valuesBuilder .appendBoolean (state .get (group ));
179+ } else {
180+ valuesBuilder .appendBoolean (false );
181+ }
182+ }
183+ blocks [offset ] = valuesBuilder .build ();
184+ }
190185 }
191186
192187 @ Override
193188 public void evaluateFinal (Block [] blocks , int offset , IntVector selected , GroupingAggregatorEvaluationContext evaluationContext ) {
194189 try (BooleanVector .Builder builder = evaluationContext .blockFactory ().newBooleanVectorFixedBuilder (selected .getPositionCount ())) {
195190 for (int i = 0 ; i < selected .getPositionCount (); i ++) {
196191 int si = selected .getInt (i );
197- builder .appendBoolean (state .hasValue ( si ) && state . getOrDefault (si ));
192+ builder .appendBoolean (state .get (si ));
198193 }
199194 blocks [offset ] = builder .build ().asBlock ();
200195 }
0 commit comments