99
1010import org .apache .lucene .analysis .core .WhitespaceTokenizer ;
1111import org .apache .lucene .util .BytesRef ;
12+ import org .elasticsearch .common .breaker .CircuitBreaker ;
13+ import org .elasticsearch .common .collect .Iterators ;
14+ import org .elasticsearch .common .unit .ByteSizeValue ;
15+ import org .elasticsearch .common .util .BigArrays ;
1216import org .elasticsearch .common .util .BytesRefHash ;
17+ import org .elasticsearch .common .util .MockBigArrays ;
18+ import org .elasticsearch .common .util .PageCacheRecycler ;
19+ import org .elasticsearch .compute .aggregation .AggregatorMode ;
1320import org .elasticsearch .compute .aggregation .GroupingAggregatorFunction ;
21+ import org .elasticsearch .compute .aggregation .SumLongAggregatorFunctionSupplier ;
22+ import org .elasticsearch .compute .data .Block ;
23+ import org .elasticsearch .compute .data .BlockFactory ;
1424import org .elasticsearch .compute .data .BytesRefBlock ;
25+ import org .elasticsearch .compute .data .BytesRefVector ;
26+ import org .elasticsearch .compute .data .ElementType ;
1527import org .elasticsearch .compute .data .IntBlock ;
1628import org .elasticsearch .compute .data .IntVector ;
29+ import org .elasticsearch .compute .data .LongVector ;
1730import org .elasticsearch .compute .data .Page ;
31+ import org .elasticsearch .compute .operator .CannedSourceOperator ;
32+ import org .elasticsearch .compute .operator .Driver ;
33+ import org .elasticsearch .compute .operator .DriverContext ;
34+ import org .elasticsearch .compute .operator .HashAggregationOperator ;
35+ import org .elasticsearch .compute .operator .LocalSourceOperator ;
36+ import org .elasticsearch .compute .operator .PageConsumerOperator ;
37+ import org .elasticsearch .core .Releasables ;
1838import org .elasticsearch .index .analysis .CharFilterFactory ;
1939import org .elasticsearch .index .analysis .CustomAnalyzer ;
2040import org .elasticsearch .index .analysis .TokenFilterFactory ;
2444import org .elasticsearch .xpack .ml .aggs .categorization .TokenListCategorizer .CloseableTokenListCategorizer ;
2545import org .elasticsearch .xpack .ml .job .categorization .CategorizationAnalyzer ;
2646
47+ import java .util .ArrayList ;
48+ import java .util .List ;
2749import java .util .Set ;
2850import java .util .stream .Collectors ;
2951import java .util .stream .IntStream ;
3052
53+ import static org .elasticsearch .compute .operator .OperatorTestCase .runDriver ;
54+ import static org .hamcrest .Matchers .containsInAnyOrder ;
55+ import static org .hamcrest .Matchers .equalTo ;
56+ import static org .hamcrest .Matchers .hasSize ;
57+
3158public class CategorizeBlockHashTests extends BlockHashTestCase {
3259
3360 /**
@@ -47,7 +74,7 @@ public void testCategorizeRaw() {
4774 page = new Page (builder .build ());
4875 }
4976 // final int emitBatchSize = between(positions, 10 * 1024);
50- try (BlockHash hash = new CategorizeRawBlockHash (blockFactory , 0 , true , createAnalyzer (), createCategorizer ())) {
77+ try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true , createAnalyzer (), createCategorizer ())) {
5178 hash .add (page , new GroupingAggregatorFunction .AddInput () {
5279 @ Override
5380 public void add (int positionOffset , IntBlock groupIds ) {
@@ -107,9 +134,9 @@ public void testCategorizeIntermediate() {
107134 }
108135 // final int emitBatchSize = between(positions, 10 * 1024);
109136 try (
110- BlockHash rawHash1 = new CategorizeRawBlockHash (blockFactory , 0 , true , createAnalyzer (), createCategorizer ());
111- BlockHash rawHash2 = new CategorizeRawBlockHash (blockFactory , 0 , true , createAnalyzer (), createCategorizer ());
112- BlockHash intermediateHash = new CategorizedIntermediateBlockHash (blockFactory , 0 , true , createCategorizer ())
137+ BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true , createAnalyzer (), createCategorizer ());
138+ BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true , createAnalyzer (), createCategorizer ());
139+ BlockHash intermediateHash = new CategorizedIntermediateBlockHash (0 , blockFactory , true , createCategorizer ())
113140 ) {
114141 rawHash1 .add (page1 , new GroupingAggregatorFunction .AddInput () {
115142 @ Override
@@ -211,6 +238,103 @@ public void close() {
211238 }
212239 }
213240
241+ public void testCategorize_withDriver () {
242+ BigArrays bigArrays = new MockBigArrays (PageCacheRecycler .NON_RECYCLING_INSTANCE , ByteSizeValue .ofMb (256 )).withCircuitBreaking ();
243+ CircuitBreaker breaker = bigArrays .breakerService ().getBreaker (CircuitBreaker .REQUEST );
244+ DriverContext driverContext = new DriverContext (bigArrays , new BlockFactory (breaker , bigArrays ));
245+
246+ LocalSourceOperator .BlockSupplier input1 = () -> {
247+ try (BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 )) {
248+ textsBuilder .appendBytesRef (new BytesRef ("a" ));
249+ textsBuilder .appendBytesRef (new BytesRef ("b" ));
250+ textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye jan" ));
251+ textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye nik" ));
252+ textsBuilder .appendBytesRef (new BytesRef ("words words words hello jan" ));
253+ textsBuilder .appendBytesRef (new BytesRef ("c" ));
254+ return new Block [] { textsBuilder .build ().asBlock () };
255+ }
256+ };
257+ LocalSourceOperator .BlockSupplier input2 = () -> {
258+ try (BytesRefVector .Builder builder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 )) {
259+ builder .appendBytesRef (new BytesRef ("words words words hello nik" ));
260+ builder .appendBytesRef (new BytesRef ("c" ));
261+ builder .appendBytesRef (new BytesRef ("words words words goodbye chris" ));
262+ builder .appendBytesRef (new BytesRef ("d" ));
263+ builder .appendBytesRef (new BytesRef ("e" ));
264+ return new Block [] { builder .build ().asBlock () };
265+ }
266+ };
267+ List <Page > intermediateOutput = new ArrayList <>();
268+ List <Page > finalOutput = new ArrayList <>();
269+
270+ Driver driver = new Driver (
271+ driverContext ,
272+ new LocalSourceOperator (input1 ),
273+ List .of (
274+ new HashAggregationOperator .HashAggregationOperatorFactory (
275+ List .of (new BlockHash .GroupSpec (0 , ElementType .CATEGORY_RAW )),
276+ List .of (),
277+ 16 * 1024
278+ ).get (driverContext )
279+ ),
280+ new PageConsumerOperator (intermediateOutput ::add ),
281+ () -> {}
282+ );
283+ runDriver (driver );
284+
285+ driver = new Driver (
286+ driverContext ,
287+ new LocalSourceOperator (input2 ),
288+ List .of (
289+ new HashAggregationOperator .HashAggregationOperatorFactory (
290+ List .of (new BlockHash .GroupSpec (0 , ElementType .CATEGORY_RAW )),
291+ List .of (),
292+ 16 * 1024
293+ ).get (driverContext )
294+ ),
295+ new PageConsumerOperator (intermediateOutput ::add ),
296+ () -> {}
297+ );
298+ runDriver (driver );
299+
300+ driver = new Driver (
301+ driverContext ,
302+ new CannedSourceOperator (intermediateOutput .iterator ()),
303+ List .of (
304+ new HashAggregationOperator .HashAggregationOperatorFactory (
305+ List .of (new BlockHash .GroupSpec (0 , ElementType .CATEGORY_INTERMEDIATE )),
306+ List .of (),
307+ 16 * 1024
308+ ).get (driverContext )
309+ ),
310+ new PageConsumerOperator (finalOutput ::add ),
311+ () -> {}
312+ );
313+ runDriver (driver );
314+
315+ assertThat (finalOutput , hasSize (1 ));
316+ assertThat (finalOutput .get (0 ).getBlockCount (), equalTo (1 ));
317+ BytesRefBlock block = finalOutput .get (0 ).getBlock (0 );
318+ BytesRefVector vector = block .asVector ();
319+ List <String > values = new ArrayList <>();
320+ for (int p = 0 ; p < vector .getPositionCount (); p ++) {
321+ values .add (vector .getBytesRef (p , new BytesRef ()).utf8ToString ());
322+ }
323+ assertThat (
324+ values ,
325+ containsInAnyOrder (
326+ ".*?a.*?" ,
327+ ".*?b.*?" ,
328+ ".*?c.*?" ,
329+ ".*?d.*?" ,
330+ ".*?e.*?" ,
331+ ".*?words.+?words.+?words.+?goodbye.*?" ,
332+ ".*?words.+?words.+?words.+?hello.*?"
333+ )
334+ );
335+ Releasables .close (() -> Iterators .map (finalOutput .iterator (), (Page p ) -> p ::releaseBlocks ));
336+ }
337+
214338 private static CategorizationAnalyzer createAnalyzer () {
215339 return new CategorizationAnalyzer (
216340 // TODO: should be the same analyzer as used in Production
0 commit comments