2626import org .elasticsearch .compute .data .ElementType ;
2727import org .elasticsearch .compute .data .IntBlock ;
2828import org .elasticsearch .compute .data .IntVector ;
29+ import org .elasticsearch .compute .data .LongBlock ;
2930import org .elasticsearch .compute .data .LongVector ;
3031import org .elasticsearch .compute .data .Page ;
3132import org .elasticsearch .compute .operator .CannedSourceOperator ;
4546import org .elasticsearch .xpack .ml .job .categorization .CategorizationAnalyzer ;
4647
4748import java .util .ArrayList ;
49+ import java .util .HashMap ;
4850import java .util .List ;
51+ import java .util .Map ;
4952import java .util .Set ;
5053import java .util .stream .Collectors ;
5154import java .util .stream .IntStream ;
5255
5356import static org .elasticsearch .compute .operator .OperatorTestCase .runDriver ;
54- import static org .hamcrest .Matchers .containsInAnyOrder ;
5557import static org .hamcrest .Matchers .equalTo ;
5658import static org .hamcrest .Matchers .hasSize ;
5759
@@ -244,24 +246,41 @@ public void testCategorize_withDriver() {
244246 DriverContext driverContext = new DriverContext (bigArrays , new BlockFactory (breaker , bigArrays ));
245247
246248 LocalSourceOperator .BlockSupplier input1 = () -> {
247- try (BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 )) {
249+ try (
250+ BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
251+ LongVector .Builder countsBuilder = driverContext .blockFactory ().newLongVectorBuilder (10 )
252+ ) {
248253 textsBuilder .appendBytesRef (new BytesRef ("a" ));
249254 textsBuilder .appendBytesRef (new BytesRef ("b" ));
250255 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye jan" ));
251256 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye nik" ));
252257 textsBuilder .appendBytesRef (new BytesRef ("words words words hello jan" ));
253258 textsBuilder .appendBytesRef (new BytesRef ("c" ));
254- return new Block [] { textsBuilder .build ().asBlock () };
259+ countsBuilder .appendLong (11 );
260+ countsBuilder .appendLong (22 );
261+ countsBuilder .appendLong (800 );
262+ countsBuilder .appendLong (80 );
263+ countsBuilder .appendLong (900 );
264+ countsBuilder .appendLong (30 );
265+ return new Block [] { textsBuilder .build ().asBlock (), countsBuilder .build ().asBlock () };
255266 }
256267 };
257268 LocalSourceOperator .BlockSupplier input2 = () -> {
258- try (BytesRefVector .Builder builder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 )) {
259- builder .appendBytesRef (new BytesRef ("words words words hello nik" ));
260- builder .appendBytesRef (new BytesRef ("c" ));
261- builder .appendBytesRef (new BytesRef ("words words words goodbye chris" ));
262- builder .appendBytesRef (new BytesRef ("d" ));
263- builder .appendBytesRef (new BytesRef ("e" ));
264- return new Block [] { builder .build ().asBlock () };
269+ try (
270+ BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
271+ LongVector .Builder countsBuilder = driverContext .blockFactory ().newLongVectorBuilder (10 )
272+ ) {
273+ textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
274+ textsBuilder .appendBytesRef (new BytesRef ("c" ));
275+ textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye chris" ));
276+ textsBuilder .appendBytesRef (new BytesRef ("d" ));
277+ textsBuilder .appendBytesRef (new BytesRef ("e" ));
278+ countsBuilder .appendLong (99 );
279+ countsBuilder .appendLong (3 );
280+ countsBuilder .appendLong (8 );
281+ countsBuilder .appendLong (44 );
282+ countsBuilder .appendLong (55 );
283+ return new Block [] { textsBuilder .build ().asBlock (), countsBuilder .build ().asBlock () };
265284 }
266285 };
267286 List <Page > intermediateOutput = new ArrayList <>();
@@ -273,7 +292,7 @@ public void testCategorize_withDriver() {
273292 List .of (
274293 new HashAggregationOperator .HashAggregationOperatorFactory (
275294 List .of (new BlockHash .GroupSpec (0 , ElementType .CATEGORY_RAW )),
276- List .of (),
295+ List .of (new SumLongAggregatorFunctionSupplier ( List . of ( 1 )). groupingAggregatorFactory ( AggregatorMode . INITIAL ) ),
277296 16 * 1024
278297 ).get (driverContext )
279298 ),
@@ -288,7 +307,7 @@ public void testCategorize_withDriver() {
288307 List .of (
289308 new HashAggregationOperator .HashAggregationOperatorFactory (
290309 List .of (new BlockHash .GroupSpec (0 , ElementType .CATEGORY_RAW )),
291- List .of (),
310+ List .of (new SumLongAggregatorFunctionSupplier ( List . of ( 1 )). groupingAggregatorFactory ( AggregatorMode . INITIAL ) ),
292311 16 * 1024
293312 ).get (driverContext )
294313 ),
@@ -303,7 +322,7 @@ public void testCategorize_withDriver() {
303322 List .of (
304323 new HashAggregationOperator .HashAggregationOperatorFactory (
305324 List .of (new BlockHash .GroupSpec (0 , ElementType .CATEGORY_INTERMEDIATE )),
306- List .of (),
325+ List .of (new SumLongAggregatorFunctionSupplier ( List . of ( 1 )). groupingAggregatorFactory ( AggregatorMode . INITIAL ) ),
307326 16 * 1024
308327 ).get (driverContext )
309328 ),
@@ -313,23 +332,32 @@ public void testCategorize_withDriver() {
313332 runDriver (driver );
314333
315334 assertThat (finalOutput , hasSize (1 ));
316- assertThat (finalOutput .get (0 ).getBlockCount (), equalTo (1 ));
317- BytesRefBlock block = finalOutput .get (0 ).getBlock (0 );
318- BytesRefVector vector = block .asVector ();
319- List <String > values = new ArrayList <>();
320- for (int p = 0 ; p < vector .getPositionCount (); p ++) {
321- values . add ( vector .getBytesRef (p , new BytesRef ()).utf8ToString ());
335+ assertThat (finalOutput .get (0 ).getBlockCount (), equalTo (3 ));
336+ BytesRefVector textsVector = (( BytesRefBlock ) finalOutput .get (0 ).getBlock (0 )). asVector ( );
337+ LongVector countsVector = (( LongBlock ) finalOutput . get ( 0 ). getBlock ( 1 )) .asVector ();
338+ Map <String , Long > counts = new HashMap <>();
339+ for (int i = 0 ; i < countsVector .getPositionCount (); i ++) {
340+ counts . put ( textsVector .getBytesRef (i , new BytesRef ()).utf8ToString (), countsVector . getLong ( i ));
322341 }
323342 assertThat (
324- values ,
325- containsInAnyOrder (
326- ".*?a.*?" ,
327- ".*?b.*?" ,
328- ".*?c.*?" ,
329- ".*?d.*?" ,
330- ".*?e.*?" ,
331- ".*?words.+?words.+?words.+?goodbye.*?" ,
332- ".*?words.+?words.+?words.+?hello.*?"
343+ counts ,
344+ equalTo (
345+ Map .of (
346+ ".*?a.*?" ,
347+ 11 ,
348+ ".*?b.*?" ,
349+ 22 ,
350+ ".*?c.*?" ,
351+ 33 ,
352+ ".*?d.*?" ,
353+ 44 ,
354+ ".*?e.*?" ,
355+ 55 ,
356+ ".*?words.+?words.+?words.+?goodbye.*?" ,
357+ 888 ,
358+ ".*?words.+?words.+?words.+?hello.*?" ,
359+ 999
360+ )
333361 )
334362 );
335363 Releasables .close (() -> Iterators .map (finalOutput .iterator (), (Page p ) -> p ::releaseBlocks ));
0 commit comments