88package org .elasticsearch .compute .aggregation .blockhash ;
99
1010import org .apache .lucene .util .BytesRef ;
11+ import org .elasticsearch .analysis .common .CommonAnalysisPlugin ;
1112import org .elasticsearch .common .breaker .CircuitBreaker ;
1213import org .elasticsearch .common .collect .Iterators ;
14+ import org .elasticsearch .common .settings .Settings ;
1315import org .elasticsearch .common .unit .ByteSizeValue ;
1416import org .elasticsearch .common .util .BigArrays ;
1517import org .elasticsearch .common .util .MockBigArrays ;
3537import org .elasticsearch .compute .operator .LocalSourceOperator ;
3638import org .elasticsearch .compute .operator .PageConsumerOperator ;
3739import org .elasticsearch .core .Releasables ;
38-
40+ import org .elasticsearch .env .Environment ;
41+ import org .elasticsearch .env .TestEnvironment ;
42+ import org .elasticsearch .index .analysis .AnalysisRegistry ;
43+ import org .elasticsearch .indices .analysis .AnalysisModule ;
44+ import org .elasticsearch .plugins .scanners .StablePluginsRegistry ;
45+ import org .elasticsearch .xpack .ml .MachineLearning ;
46+ import org .junit .Before ;
47+
48+ import java .io .IOException ;
3949import java .util .ArrayList ;
4050import java .util .HashMap ;
4151import java .util .List ;
5060
5161public class CategorizeBlockHashTests extends BlockHashTestCase {
5262
63+ private AnalysisRegistry analysisRegistry ;
64+
65+ @ Before
66+ private void initAnalysisRegistry () throws IOException {
67+ analysisRegistry = new AnalysisModule (
68+ TestEnvironment .newEnvironment (
69+ Settings .builder ().put (Environment .PATH_HOME_SETTING .getKey (), createTempDir ().toString ()).build ()
70+ ),
71+ List .of (new MachineLearning (Settings .EMPTY ), new CommonAnalysisPlugin ()),
72+ new StablePluginsRegistry ()
73+ ).getAnalysisRegistry ();
74+ }
75+
5376 public void testCategorizeRaw () {
5477 final Page page ;
5578 boolean withNull = randomBoolean ();
@@ -72,7 +95,7 @@ public void testCategorizeRaw() {
7295 page = new Page (builder .build ());
7396 }
7497
75- try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true )) {
98+ try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry )) {
7699 hash .add (page , new GroupingAggregatorFunction .AddInput () {
77100 @ Override
78101 public void add (int positionOffset , IntBlock groupIds ) {
@@ -145,8 +168,8 @@ public void testCategorizeIntermediate() {
145168
146169 // Fill intermediatePages with the intermediate state from the raw hashes
147170 try (
148- BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true );
149- BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true )
171+ BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry );
172+ BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry );
150173 ) {
151174 rawHash1 .add (page1 , new GroupingAggregatorFunction .AddInput () {
152175 @ Override
@@ -267,14 +290,14 @@ public void testCategorize_withDriver() {
267290 BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
268291 LongVector .Builder countsBuilder = driverContext .blockFactory ().newLongVectorBuilder (10 )
269292 ) {
270- textsBuilder .appendBytesRef (new BytesRef ("a " ));
271- textsBuilder .appendBytesRef (new BytesRef ("b " ));
293+ textsBuilder .appendBytesRef (new BytesRef ("aaazz " ));
294+ textsBuilder .appendBytesRef (new BytesRef ("bbbzz " ));
272295 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye jan" ));
273296 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye nik" ));
274297 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye tom" ));
275298 textsBuilder .appendBytesRef (new BytesRef ("words words words hello jan" ));
276- textsBuilder .appendBytesRef (new BytesRef ("c " ));
277- textsBuilder .appendBytesRef (new BytesRef ("d " ));
299+ textsBuilder .appendBytesRef (new BytesRef ("ccczz " ));
300+ textsBuilder .appendBytesRef (new BytesRef ("dddzz " ));
278301 countsBuilder .appendLong (1 );
279302 countsBuilder .appendLong (2 );
280303 countsBuilder .appendLong (800 );
@@ -293,10 +316,10 @@ public void testCategorize_withDriver() {
293316 ) {
294317 textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
295318 textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
296- textsBuilder .appendBytesRef (new BytesRef ("c " ));
319+ textsBuilder .appendBytesRef (new BytesRef ("ccczz " ));
297320 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye chris" ));
298- textsBuilder .appendBytesRef (new BytesRef ("d " ));
299- textsBuilder .appendBytesRef (new BytesRef ("e " ));
321+ textsBuilder .appendBytesRef (new BytesRef ("dddzz " ));
322+ textsBuilder .appendBytesRef (new BytesRef ("eeezz " ));
300323 countsBuilder .appendLong (9 );
301324 countsBuilder .appendLong (90 );
302325 countsBuilder .appendLong (3 );
@@ -320,7 +343,8 @@ public void testCategorize_withDriver() {
320343 new SumLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL ),
321344 new MaxLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL )
322345 ),
323- 16 * 1024
346+ 16 * 1024 ,
347+ analysisRegistry
324348 ).get (driverContext )
325349 ),
326350 new PageConsumerOperator (intermediateOutput ::add ),
@@ -339,7 +363,8 @@ public void testCategorize_withDriver() {
339363 new SumLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL ),
340364 new MaxLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL )
341365 ),
342- 16 * 1024
366+ 16 * 1024 ,
367+ analysisRegistry
343368 ).get (driverContext )
344369 ),
345370 new PageConsumerOperator (intermediateOutput ::add ),
@@ -360,7 +385,8 @@ public void testCategorize_withDriver() {
360385 new SumLongAggregatorFunctionSupplier (List .of (1 , 2 )).groupingAggregatorFactory (AggregatorMode .FINAL ),
361386 new MaxLongAggregatorFunctionSupplier (List .of (3 , 4 )).groupingAggregatorFactory (AggregatorMode .FINAL )
362387 ),
363- 16 * 1024
388+ 16 * 1024 ,
389+ analysisRegistry
364390 ).get (driverContext )
365391 ),
366392 new PageConsumerOperator (finalOutput ::add ),
@@ -385,15 +411,15 @@ public void testCategorize_withDriver() {
385411 sums ,
386412 equalTo (
387413 Map .of (
388- ".*?a .*?" ,
414+ ".*?aaazz .*?" ,
389415 1L ,
390- ".*?b .*?" ,
416+ ".*?bbbzz .*?" ,
391417 2L ,
392- ".*?c .*?" ,
418+ ".*?ccczz .*?" ,
393419 33L ,
394- ".*?d .*?" ,
420+ ".*?dddzz .*?" ,
395421 44L ,
396- ".*?e .*?" ,
422+ ".*?eeezz .*?" ,
397423 5L ,
398424 ".*?words.+?words.+?words.+?goodbye.*?" ,
399425 8888L ,
@@ -406,15 +432,15 @@ public void testCategorize_withDriver() {
406432 maxs ,
407433 equalTo (
408434 Map .of (
409- ".*?a .*?" ,
435+ ".*?aaazz .*?" ,
410436 1L ,
411- ".*?b .*?" ,
437+ ".*?bbbzz .*?" ,
412438 2L ,
413- ".*?c .*?" ,
439+ ".*?ccczz .*?" ,
414440 30L ,
415- ".*?d .*?" ,
441+ ".*?dddzz .*?" ,
416442 40L ,
417- ".*?e .*?" ,
443+ ".*?eeezz .*?" ,
418444 5L ,
419445 ".*?words.+?words.+?words.+?goodbye.*?" ,
420446 8000L ,
0 commit comments