88package org .elasticsearch .compute .aggregation .blockhash ;
99
1010import org .apache .lucene .util .BytesRef ;
11+ import org .elasticsearch .analysis .common .CommonAnalysisPlugin ;
1112import org .elasticsearch .common .breaker .CircuitBreaker ;
1213import org .elasticsearch .common .collect .Iterators ;
14+ import org .elasticsearch .common .settings .Settings ;
1315import org .elasticsearch .common .unit .ByteSizeValue ;
1416import org .elasticsearch .common .util .BigArrays ;
1517import org .elasticsearch .common .util .MockBigArrays ;
3537import org .elasticsearch .compute .operator .LocalSourceOperator ;
3638import org .elasticsearch .compute .operator .PageConsumerOperator ;
3739import org .elasticsearch .core .Releasables ;
38-
40+ import org .elasticsearch .env .Environment ;
41+ import org .elasticsearch .env .TestEnvironment ;
42+ import org .elasticsearch .index .analysis .AnalysisRegistry ;
43+ import org .elasticsearch .indices .analysis .AnalysisModule ;
44+ import org .elasticsearch .plugins .scanners .StablePluginsRegistry ;
45+ import org .elasticsearch .xpack .ml .MachineLearning ;
46+ import org .junit .Before ;
47+
48+ import java .io .IOException ;
3949import java .util .ArrayList ;
4050import java .util .HashMap ;
4151import java .util .List ;
5060
5161public class CategorizeBlockHashTests extends BlockHashTestCase {
5262
63+ private AnalysisRegistry analysisRegistry ;
64+
65+ @ Before
66+ private void initAnalysisRegistry () throws IOException {
67+ analysisRegistry = new AnalysisModule (
68+ TestEnvironment .newEnvironment (
69+ Settings .builder ().put (Environment .PATH_HOME_SETTING .getKey (), createTempDir ().toString ()).build ()
70+ ),
71+ List .of (new MachineLearning (Settings .EMPTY ), new CommonAnalysisPlugin ()),
72+ new StablePluginsRegistry ()
73+ ).getAnalysisRegistry ();
74+ }
75+
5376 public void testCategorizeRaw () {
5477 final Page page ;
5578 final int positions = 7 ;
@@ -64,7 +87,7 @@ public void testCategorizeRaw() {
6487 page = new Page (builder .build ());
6588 }
6689
67- try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true )) {
90+ try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry )) {
6891 hash .add (page , new GroupingAggregatorFunction .AddInput () {
6992 @ Override
7093 public void add (int positionOffset , IntBlock groupIds ) {
@@ -126,8 +149,8 @@ public void testCategorizeIntermediate() {
126149
127150 // Fill intermediatePages with the intermediate state from the raw hashes
128151 try (
129- BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true );
130- BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true )
152+ BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry );
153+ BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry );
131154 ) {
132155 rawHash1 .add (page1 , new GroupingAggregatorFunction .AddInput () {
133156 @ Override
@@ -241,14 +264,14 @@ public void testCategorize_withDriver() {
241264 BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
242265 LongVector .Builder countsBuilder = driverContext .blockFactory ().newLongVectorBuilder (10 )
243266 ) {
244- textsBuilder .appendBytesRef (new BytesRef ("a " ));
245- textsBuilder .appendBytesRef (new BytesRef ("b " ));
267+ textsBuilder .appendBytesRef (new BytesRef ("aaazz " ));
268+ textsBuilder .appendBytesRef (new BytesRef ("bbbzz " ));
246269 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye jan" ));
247270 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye nik" ));
248271 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye tom" ));
249272 textsBuilder .appendBytesRef (new BytesRef ("words words words hello jan" ));
250- textsBuilder .appendBytesRef (new BytesRef ("c " ));
251- textsBuilder .appendBytesRef (new BytesRef ("d " ));
273+ textsBuilder .appendBytesRef (new BytesRef ("ccczz " ));
274+ textsBuilder .appendBytesRef (new BytesRef ("dddzz " ));
252275 countsBuilder .appendLong (1 );
253276 countsBuilder .appendLong (2 );
254277 countsBuilder .appendLong (800 );
@@ -267,10 +290,10 @@ public void testCategorize_withDriver() {
267290 ) {
268291 textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
269292 textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
270- textsBuilder .appendBytesRef (new BytesRef ("c " ));
293+ textsBuilder .appendBytesRef (new BytesRef ("ccczz " ));
271294 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye chris" ));
272- textsBuilder .appendBytesRef (new BytesRef ("d " ));
273- textsBuilder .appendBytesRef (new BytesRef ("e " ));
295+ textsBuilder .appendBytesRef (new BytesRef ("dddzz " ));
296+ textsBuilder .appendBytesRef (new BytesRef ("eeezz " ));
274297 countsBuilder .appendLong (9 );
275298 countsBuilder .appendLong (90 );
276299 countsBuilder .appendLong (3 );
@@ -294,7 +317,8 @@ public void testCategorize_withDriver() {
294317 new SumLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL ),
295318 new MaxLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL )
296319 ),
297- 16 * 1024
320+ 16 * 1024 ,
321+ analysisRegistry
298322 ).get (driverContext )
299323 ),
300324 new PageConsumerOperator (intermediateOutput ::add ),
@@ -313,7 +337,8 @@ public void testCategorize_withDriver() {
313337 new SumLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL ),
314338 new MaxLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL )
315339 ),
316- 16 * 1024
340+ 16 * 1024 ,
341+ analysisRegistry
317342 ).get (driverContext )
318343 ),
319344 new PageConsumerOperator (intermediateOutput ::add ),
@@ -334,7 +359,8 @@ public void testCategorize_withDriver() {
334359 new SumLongAggregatorFunctionSupplier (List .of (1 , 2 )).groupingAggregatorFactory (AggregatorMode .FINAL ),
335360 new MaxLongAggregatorFunctionSupplier (List .of (3 , 4 )).groupingAggregatorFactory (AggregatorMode .FINAL )
336361 ),
337- 16 * 1024
362+ 16 * 1024 ,
363+ analysisRegistry
338364 ).get (driverContext )
339365 ),
340366 new PageConsumerOperator (finalOutput ::add ),
@@ -359,15 +385,15 @@ public void testCategorize_withDriver() {
359385 sums ,
360386 equalTo (
361387 Map .of (
362- ".*?a .*?" ,
388+ ".*?aaazz .*?" ,
363389 1L ,
364- ".*?b .*?" ,
390+ ".*?bbbzz .*?" ,
365391 2L ,
366- ".*?c .*?" ,
392+ ".*?ccczz .*?" ,
367393 33L ,
368- ".*?d .*?" ,
394+ ".*?dddzz .*?" ,
369395 44L ,
370- ".*?e .*?" ,
396+ ".*?eeezz .*?" ,
371397 5L ,
372398 ".*?words.+?words.+?words.+?goodbye.*?" ,
373399 8888L ,
@@ -380,15 +406,15 @@ public void testCategorize_withDriver() {
380406 maxs ,
381407 equalTo (
382408 Map .of (
383- ".*?a .*?" ,
409+ ".*?aaazz .*?" ,
384410 1L ,
385- ".*?b .*?" ,
411+ ".*?bbbzz .*?" ,
386412 2L ,
387- ".*?c .*?" ,
413+ ".*?ccczz .*?" ,
388414 30L ,
389- ".*?d .*?" ,
415+ ".*?dddzz .*?" ,
390416 40L ,
391- ".*?e .*?" ,
417+ ".*?eeezz .*?" ,
392418 5L ,
393419 ".*?words.+?words.+?words.+?goodbye.*?" ,
394420 8000L ,
0 commit comments