2020import org .elasticsearch .compute .aggregation .ValuesBytesRefAggregatorFunctionSupplier ;
2121import org .elasticsearch .compute .data .Block ;
2222import org .elasticsearch .compute .data .BlockFactory ;
23+ import org .elasticsearch .compute .data .BlockUtils ;
2324import org .elasticsearch .compute .data .BytesRefBlock ;
24- import org .elasticsearch .compute .data .BytesRefVector ;
2525import org .elasticsearch .compute .data .ElementType ;
2626import org .elasticsearch .compute .data .IntBlock ;
27- import org .elasticsearch .compute .data .IntVector ;
2827import org .elasticsearch .compute .data .Page ;
2928import org .elasticsearch .compute .operator .CannedSourceOperator ;
3029import org .elasticsearch .compute .operator .Driver ;
@@ -72,6 +71,8 @@ public void testCategorize_withDriver() {
7271 BigArrays bigArrays = new MockBigArrays (PageCacheRecycler .NON_RECYCLING_INSTANCE , ByteSizeValue .ofMb (256 )).withCircuitBreaking ();
7372 CircuitBreaker breaker = bigArrays .breakerService ().getBreaker (CircuitBreaker .REQUEST );
7473 DriverContext driverContext = new DriverContext (bigArrays , new BlockFactory (breaker , bigArrays ));
74+ boolean withNull = randomBoolean ();
75+ boolean withMultivalues = randomBoolean ();
7576
7677 List <BlockHash .GroupSpec > groupSpecs = List .of (
7778 new BlockHash .GroupSpec (0 , ElementType .BYTES_REF , true ),
@@ -80,28 +81,42 @@ public void testCategorize_withDriver() {
8081
8182 LocalSourceOperator .BlockSupplier input1 = () -> {
8283 try (
83- BytesRefVector .Builder messagesBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
84- IntVector .Builder idsBuilder = driverContext .blockFactory ().newIntVectorBuilder (10 )
84+ BytesRefBlock .Builder messagesBuilder = driverContext .blockFactory ().newBytesRefBlockBuilder (10 );
85+ IntBlock .Builder idsBuilder = driverContext .blockFactory ().newIntBlockBuilder (10 )
8586 ) {
87+ if (withMultivalues ) {
88+ messagesBuilder .beginPositionEntry ();
89+ }
8690 messagesBuilder .appendBytesRef (new BytesRef ("connected to 1.1.1" ));
8791 messagesBuilder .appendBytesRef (new BytesRef ("connected to 1.1.2" ));
92+ if (withMultivalues ) {
93+ messagesBuilder .endPositionEntry ();
94+ }
95+ idsBuilder .appendInt (7 );
96+ if (withMultivalues == false ) {
97+ idsBuilder .appendInt (7 );
98+ }
99+
88100 messagesBuilder .appendBytesRef (new BytesRef ("connected to 1.1.3" ));
89101 messagesBuilder .appendBytesRef (new BytesRef ("connection error" ));
90102 messagesBuilder .appendBytesRef (new BytesRef ("connection error" ));
91103 messagesBuilder .appendBytesRef (new BytesRef ("connected to 1.1.4" ));
92- idsBuilder .appendInt (7 );
93- idsBuilder .appendInt (7 );
94104 idsBuilder .appendInt (42 );
95105 idsBuilder .appendInt (7 );
96106 idsBuilder .appendInt (42 );
97107 idsBuilder .appendInt (7 );
98- return new Block [] { messagesBuilder .build ().asBlock (), idsBuilder .build ().asBlock () };
108+
109+ if (withNull ) {
110+ messagesBuilder .appendNull ();
111+ idsBuilder .appendInt (43 );
112+ }
113+ return new Block [] { messagesBuilder .build (), idsBuilder .build () };
99114 }
100115 };
101116 LocalSourceOperator .BlockSupplier input2 = () -> {
102117 try (
103- BytesRefVector .Builder messagesBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
104- IntVector .Builder idsBuilder = driverContext .blockFactory ().newIntVectorBuilder (10 )
118+ BytesRefBlock .Builder messagesBuilder = driverContext .blockFactory ().newBytesRefBlockBuilder (10 );
119+ IntBlock .Builder idsBuilder = driverContext .blockFactory ().newIntBlockBuilder (10 )
105120 ) {
106121 messagesBuilder .appendBytesRef (new BytesRef ("connected to 2.1.1" ));
107122 messagesBuilder .appendBytesRef (new BytesRef ("connected to 2.1.2" ));
@@ -111,7 +126,11 @@ public void testCategorize_withDriver() {
111126 idsBuilder .appendInt (7 );
112127 idsBuilder .appendInt (7 );
113128 idsBuilder .appendInt (42 );
114- return new Block [] { messagesBuilder .build ().asBlock (), idsBuilder .build ().asBlock () };
129+ if (withNull ) {
130+ messagesBuilder .appendNull ();
131+ idsBuilder .appendNull ();
132+ }
133+ return new Block [] { messagesBuilder .build (), idsBuilder .build () };
115134 }
116135 };
117136
@@ -177,38 +196,53 @@ public void testCategorize_withDriver() {
177196 BytesRefBlock outputValues = finalOutput .get (0 ).getBlock (2 );
178197 assertThat (outputIds .getPositionCount (), equalTo (outputMessages .getPositionCount ()));
179198 assertThat (outputValues .getPositionCount (), equalTo (outputMessages .getPositionCount ()));
180- Map <String , Map <Integer , Set <String >>> values = new HashMap <>();
199+ Map <String , Map <Integer , Set <String >>> result = new HashMap <>();
181200 for (int i = 0 ; i < outputMessages .getPositionCount (); i ++) {
182- String message = outputMessages .getBytesRef (i , new BytesRef ()).utf8ToString ();
183- int id = outputIds .getInt (i );
184- int valuesFromIndex = outputValues .getFirstValueIndex (i );
185- int valuesToIndex = valuesFromIndex + outputValues .getValueCount (i );
186- for (int valueIndex = valuesFromIndex ; valueIndex < valuesToIndex ; valueIndex ++) {
187- String value = outputValues .getBytesRef (valueIndex , new BytesRef ()).utf8ToString ();
188- values .computeIfAbsent (message , key -> new HashMap <>()).computeIfAbsent (id , key -> new HashSet <>()).add (value );
201+ BytesRef messageBytesRef = ((BytesRef ) BlockUtils .toJavaObject (outputMessages , i ));
202+ String message = messageBytesRef == null ? null : messageBytesRef .utf8ToString ();
203+ result .computeIfAbsent (message , key -> new HashMap <>());
204+
205+ Integer id = (Integer ) BlockUtils .toJavaObject (outputIds , i );
206+ result .get (message ).computeIfAbsent (id , key -> new HashSet <>());
207+
208+ Object values = BlockUtils .toJavaObject (outputValues , i );
209+ if (values == null ) {
210+ result .get (message ).get (id ).add (null );
211+ } else {
212+ if ((values instanceof List ) == false ) {
213+ values = List .of (values );
214+ }
215+ for (Object valueObject : (List <?>) values ) {
216+ BytesRef value = (BytesRef ) valueObject ;
217+ result .get (message ).get (id ).add (value .utf8ToString ());
218+ }
189219 }
190220 }
191221 Releasables .close (() -> Iterators .map (finalOutput .iterator (), (Page p ) -> p ::releaseBlocks ));
192222
193- assertThat (
194- values ,
195- equalTo (
196- Map .of (
197- ".*?connected.+?to.*?" ,
198- Map .of (
199- 7 ,
200- Set .of ("connected to 1.1.1" , "connected to 1.1.2" , "connected to 1.1.4" , "connected to 2.1.2" ),
201- 42 ,
202- Set .of ("connected to 1.1.3" ),
203- 111 ,
204- Set .of ("connected to 2.1.1" )
205- ),
206- ".*?connection.+?error.*?" ,
207- Map .of (7 , Set .of ("connection error" ), 42 , Set .of ("connection error" )),
208- ".*?disconnected.*?" ,
209- Map .of (7 , Set .of ("disconnected" ))
210- )
211- )
223+ Map <String , Map <Integer , Set <String >>> expectedResult = Map .of (
224+ ".*?connected.+?to.*?" ,
225+ Map .of (
226+ 7 ,
227+ Set .of ("connected to 1.1.1" , "connected to 1.1.2" , "connected to 1.1.4" , "connected to 2.1.2" ),
228+ 42 ,
229+ Set .of ("connected to 1.1.3" ),
230+ 111 ,
231+ Set .of ("connected to 2.1.1" )
232+ ),
233+ ".*?connection.+?error.*?" ,
234+ Map .of (7 , Set .of ("connection error" ), 42 , Set .of ("connection error" )),
235+ ".*?disconnected.*?" ,
236+ Map .of (7 , Set .of ("disconnected" ))
212237 );
238+ if (withNull ) {
239+ expectedResult = new HashMap <>(expectedResult );
240+ expectedResult .put (null , new HashMap <>());
241+ expectedResult .get (null ).put (null , new HashSet <>());
242+ expectedResult .get (null ).get (null ).add (null );
243+ expectedResult .get (null ).put (43 , new HashSet <>());
244+ expectedResult .get (null ).get (43 ).add (null );
245+ }
246+ assertThat (result , equalTo (expectedResult ));
213247 }
214248}
0 commit comments