77
88package org .elasticsearch .compute .aggregation ;
99
10+ import com .carrotsearch .randomizedtesting .annotations .Name ;
1011import com .carrotsearch .randomizedtesting .annotations .ParametersFactory ;
1112
1213import org .apache .lucene .document .InetAddressPoint ;
1314import org .apache .lucene .util .BytesRef ;
1415import org .elasticsearch .common .Randomness ;
1516import org .elasticsearch .common .util .BigArrays ;
17+ import org .elasticsearch .compute .data .Block ;
1618import org .elasticsearch .compute .data .BlockTestUtils ;
19+ import org .elasticsearch .compute .data .BlockUtils ;
1720import org .elasticsearch .compute .data .ElementType ;
21+ import org .elasticsearch .compute .data .IntVector ;
22+ import org .elasticsearch .compute .data .TestBlockFactory ;
23+ import org .elasticsearch .compute .operator .DriverContext ;
24+ import org .elasticsearch .core .Releasables ;
1825import org .elasticsearch .test .ESTestCase ;
1926import org .elasticsearch .xpack .esql .core .type .DataType ;
2027
2128import java .util .ArrayList ;
2229import java .util .List ;
30+ import java .util .Locale ;
31+ import java .util .function .IntSupplier ;
2332
2433import static org .hamcrest .Matchers .equalTo ;
2534
@@ -29,21 +38,37 @@ public static List<Object[]> params() {
2938 List <Object []> params = new ArrayList <>();
3039
3140 for (boolean inOrder : new boolean [] { true , false }) {
32- params .add (new Object [] { DataType .INTEGER , 1000 , inOrder });
33- params .add (new Object [] { DataType .LONG , 1000 , inOrder });
34- params .add (new Object [] { DataType .FLOAT , 1000 , inOrder });
35- params .add (new Object [] { DataType .DOUBLE , 1000 , inOrder });
36- params .add (new Object [] { DataType .IP , 1000 , inOrder });
41+ for (IntSupplier count : new IntSupplier [] { new Fixed (100 ), new Fixed (1000 ), new Random (100 , 5000 ) }) {
42+ params .add (new Object [] { DataType .INTEGER , count , inOrder });
43+ params .add (new Object [] { DataType .LONG , count , inOrder });
44+ params .add (new Object [] { DataType .FLOAT , count , inOrder });
45+ params .add (new Object [] { DataType .DOUBLE , count , inOrder });
46+ params .add (new Object [] { DataType .IP , count , inOrder });
47+ }
3748 }
3849 return params ;
3950 }
4051
52+ private record Fixed (int i ) implements IntSupplier {
53+ @ Override
54+ public int getAsInt () {
55+ return i ;
56+ }
57+ }
58+
59+ private record Random (int min , int max ) implements IntSupplier {
60+ @ Override
61+ public int getAsInt () {
62+ return randomIntBetween (min , max );
63+ }
64+ }
65+
4166 private final DataType type ;
4267 private final ElementType elementType ;
4368 private final int valueCount ;
4469 private final boolean inOrder ;
4570
46- public ArrayStateTests (DataType type , int valueCount , boolean inOrder ) {
71+ public ArrayStateTests (@ Name ( "type" ) DataType type , @ Name ( "valueCount" ) IntSupplier valueCount , @ Name ( "inOrder" ) boolean inOrder ) {
4772 this .type = type ;
4873 this .elementType = switch (type ) {
4974 case INTEGER -> ElementType .INT ;
@@ -54,8 +79,9 @@ public ArrayStateTests(DataType type, int valueCount, boolean inOrder) {
5479 case IP -> ElementType .BYTES_REF ;
5580 default -> throw new IllegalArgumentException ();
5681 };
57- this .valueCount = valueCount ;
82+ this .valueCount = valueCount . getAsInt () ;
5883 this .inOrder = inOrder ;
84+ logger .info ("value count is {}" , this .valueCount );
5985 }
6086
6187 public void testSetNoTracking () {
@@ -146,6 +172,68 @@ public void testSetNullableThenOverwriteNullable() {
146172 }
147173 }
148174
175+ public void testToIntermediate () {
176+ AbstractArrayState state = newState ();
177+ List <Object > values = randomList (valueCount , valueCount , this ::randomValue );
178+ setAll (state , values , 0 );
179+ Block [] intermediate = new Block [2 ];
180+ DriverContext ctx = new DriverContext (BigArrays .NON_RECYCLING_INSTANCE , TestBlockFactory .getNonBreakingInstance ());
181+ state .toIntermediate (intermediate , 0 , IntVector .range (0 , valueCount , ctx .blockFactory ()), ctx );
182+ try {
183+ assertThat (intermediate [0 ].elementType (), equalTo (elementType ));
184+ assertThat (intermediate [1 ].elementType (), equalTo (ElementType .BOOLEAN ));
185+ assertThat (intermediate [0 ].getPositionCount (), equalTo (values .size ()));
186+ assertThat (intermediate [1 ].getPositionCount (), equalTo (values .size ()));
187+ for (int i = 0 ; i < values .size (); i ++) {
188+ Object v = values .get (i );
189+ assertThat (
190+ String .format (Locale .ROOT , "%05d: %s" , i , v != null ? v : "init" ),
191+ BlockUtils .toJavaObject (intermediate [0 ], i ),
192+ equalTo (v != null ? v : initialValue ())
193+ );
194+ assertThat (BlockUtils .toJavaObject (intermediate [1 ], i ), equalTo (true ));
195+ }
196+ } finally {
197+ Releasables .close (intermediate );
198+ }
199+ }
200+
201+ /**
202+ * Calls {@link GroupingAggregatorState#toIntermediate} with a range that's greater than
203+ * any collected values. This is acceptable if {@link AbstractArrayState#enableGroupIdTracking}
204+ * is called, so we do that.
205+ */
206+ public void testToIntermediatePastEnd () {
207+ int end = valueCount + between (1 , 10000 );
208+ AbstractArrayState state = newState ();
209+ state .enableGroupIdTracking (new SeenGroupIds .Empty ());
210+ List <Object > values = randomList (valueCount , valueCount , this ::randomValue );
211+ setAll (state , values , 0 );
212+ Block [] intermediate = new Block [2 ];
213+ DriverContext ctx = new DriverContext (BigArrays .NON_RECYCLING_INSTANCE , TestBlockFactory .getNonBreakingInstance ());
214+ state .toIntermediate (intermediate , 0 , IntVector .range (0 , end , ctx .blockFactory ()), ctx );
215+ try {
216+ assertThat (intermediate [0 ].elementType (), equalTo (elementType ));
217+ assertThat (intermediate [1 ].elementType (), equalTo (ElementType .BOOLEAN ));
218+ assertThat (intermediate [0 ].getPositionCount (), equalTo (end ));
219+ assertThat (intermediate [1 ].getPositionCount (), equalTo (end ));
220+ for (int i = 0 ; i < values .size (); i ++) {
221+ Object v = values .get (i );
222+ assertThat (
223+ String .format (Locale .ROOT , "%05d: %s" , i , v != null ? v : "init" ),
224+ BlockUtils .toJavaObject (intermediate [0 ], i ),
225+ equalTo (v != null ? v : initialValue ())
226+ );
227+ assertThat (BlockUtils .toJavaObject (intermediate [1 ], i ), equalTo (v != null ));
228+ }
229+ for (int i = values .size (); i < end ; i ++) {
230+ assertThat (BlockUtils .toJavaObject (intermediate [1 ], i ), equalTo (false ));
231+ }
232+ } finally {
233+ Releasables .close (intermediate );
234+ }
235+ }
236+
149237 private record ValueAndIndex (int index , Object value ) {}
150238
151239 private void setAll (AbstractArrayState state , List <Object > values , int offset ) {
@@ -181,6 +269,18 @@ private AbstractArrayState newState() {
181269 };
182270 }
183271
272+ private Object initialValue () {
273+ return switch (type ) {
274+ case INTEGER -> 1 ;
275+ case LONG -> 1L ;
276+ case FLOAT -> 1F ;
277+ case DOUBLE -> 1d ;
278+ case BOOLEAN -> false ;
279+ case IP -> new BytesRef (new byte [16 ]);
280+ default -> throw new IllegalArgumentException ();
281+ };
282+ }
283+
184284 private void set (AbstractArrayState state , int groupId , Object value ) {
185285 switch (type ) {
186286 case INTEGER -> ((IntArrayState ) state ).set (groupId , (Integer ) value );
0 commit comments