88package org .elasticsearch .compute .aggregation .blockhash ;
99
1010import org .apache .lucene .util .BytesRef ;
11+ import org .elasticsearch .analysis .common .CommonAnalysisPlugin ;
1112import org .elasticsearch .common .breaker .CircuitBreaker ;
1213import org .elasticsearch .common .collect .Iterators ;
14+ import org .elasticsearch .common .settings .Settings ;
1315import org .elasticsearch .common .unit .ByteSizeValue ;
1416import org .elasticsearch .common .util .BigArrays ;
1517import org .elasticsearch .common .util .MockBigArrays ;
3537import org .elasticsearch .compute .operator .LocalSourceOperator ;
3638import org .elasticsearch .compute .operator .PageConsumerOperator ;
3739import org .elasticsearch .core .Releasables ;
38-
40+ import org .elasticsearch .env .Environment ;
41+ import org .elasticsearch .env .TestEnvironment ;
42+ import org .elasticsearch .index .analysis .AnalysisRegistry ;
43+ import org .elasticsearch .indices .analysis .AnalysisModule ;
44+ import org .elasticsearch .plugins .scanners .StablePluginsRegistry ;
45+ import org .elasticsearch .xpack .ml .MachineLearning ;
46+ import org .junit .Before ;
47+
48+ import java .io .IOException ;
3949import java .util .ArrayList ;
4050import java .util .HashMap ;
4151import java .util .List ;
5060
5161public class CategorizeBlockHashTests extends BlockHashTestCase {
5262
63+ private AnalysisRegistry analysisRegistry ;
64+
65+ @ Before
66+ private void initAnalysisRegistry () throws IOException {
67+ analysisRegistry = new AnalysisModule (
68+ TestEnvironment .newEnvironment (
69+ Settings .builder ().put (Environment .PATH_HOME_SETTING .getKey (), createTempDir ().toString ()).build ()
70+ ),
71+ List .of (new MachineLearning (Settings .EMPTY ), new CommonAnalysisPlugin ()),
72+ new StablePluginsRegistry ()
73+ ).getAnalysisRegistry ();
74+ }
75+
5376 public void testCategorizeRaw () {
5477 final Page page ;
5578 boolean withNull = randomBoolean ();
@@ -72,7 +95,7 @@ public void testCategorizeRaw() {
7295 page = new Page (builder .build ());
7396 }
7497
75- try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true )) {
98+ try (BlockHash hash = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry )) {
7699 hash .add (page , new GroupingAggregatorFunction .AddInput () {
77100 @ Override
78101 public void add (int positionOffset , IntBlock groupIds ) {
@@ -145,8 +168,8 @@ public void testCategorizeIntermediate() {
145168
146169 // Fill intermediatePages with the intermediate state from the raw hashes
147170 try (
148- BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true );
149- BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true )
171+ BlockHash rawHash1 = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry );
172+ BlockHash rawHash2 = new CategorizeRawBlockHash (0 , blockFactory , true , analysisRegistry );
150173 ) {
151174 rawHash1 .add (page1 , new GroupingAggregatorFunction .AddInput () {
152175 @ Override
@@ -267,14 +290,16 @@ public void testCategorize_withDriver() {
267290 BytesRefVector .Builder textsBuilder = driverContext .blockFactory ().newBytesRefVectorBuilder (10 );
268291 LongVector .Builder countsBuilder = driverContext .blockFactory ().newLongVectorBuilder (10 )
269292 ) {
270- textsBuilder .appendBytesRef (new BytesRef ("a" ));
271- textsBuilder .appendBytesRef (new BytesRef ("b" ));
293+ // Note that just using "a" or "aaa" doesn't work, because the ml_standard
294+ // tokenizer drops numbers, including hexadecimal ones.
295+ textsBuilder .appendBytesRef (new BytesRef ("aaazz" ));
296+ textsBuilder .appendBytesRef (new BytesRef ("bbbzz" ));
272297 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye jan" ));
273298 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye nik" ));
274299 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye tom" ));
275300 textsBuilder .appendBytesRef (new BytesRef ("words words words hello jan" ));
276- textsBuilder .appendBytesRef (new BytesRef ("c " ));
277- textsBuilder .appendBytesRef (new BytesRef ("d " ));
301+ textsBuilder .appendBytesRef (new BytesRef ("ccczz " ));
302+ textsBuilder .appendBytesRef (new BytesRef ("dddzz " ));
278303 countsBuilder .appendLong (1 );
279304 countsBuilder .appendLong (2 );
280305 countsBuilder .appendLong (800 );
@@ -293,10 +318,10 @@ public void testCategorize_withDriver() {
293318 ) {
294319 textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
295320 textsBuilder .appendBytesRef (new BytesRef ("words words words hello nik" ));
296- textsBuilder .appendBytesRef (new BytesRef ("c " ));
321+ textsBuilder .appendBytesRef (new BytesRef ("ccczz " ));
297322 textsBuilder .appendBytesRef (new BytesRef ("words words words goodbye chris" ));
298- textsBuilder .appendBytesRef (new BytesRef ("d " ));
299- textsBuilder .appendBytesRef (new BytesRef ("e " ));
323+ textsBuilder .appendBytesRef (new BytesRef ("dddzz " ));
324+ textsBuilder .appendBytesRef (new BytesRef ("eeezz " ));
300325 countsBuilder .appendLong (9 );
301326 countsBuilder .appendLong (90 );
302327 countsBuilder .appendLong (3 );
@@ -320,7 +345,8 @@ public void testCategorize_withDriver() {
320345 new SumLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL ),
321346 new MaxLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL )
322347 ),
323- 16 * 1024
348+ 16 * 1024 ,
349+ analysisRegistry
324350 ).get (driverContext )
325351 ),
326352 new PageConsumerOperator (intermediateOutput ::add ),
@@ -339,7 +365,8 @@ public void testCategorize_withDriver() {
339365 new SumLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL ),
340366 new MaxLongAggregatorFunctionSupplier (List .of (1 )).groupingAggregatorFactory (AggregatorMode .INITIAL )
341367 ),
342- 16 * 1024
368+ 16 * 1024 ,
369+ analysisRegistry
343370 ).get (driverContext )
344371 ),
345372 new PageConsumerOperator (intermediateOutput ::add ),
@@ -360,7 +387,8 @@ public void testCategorize_withDriver() {
360387 new SumLongAggregatorFunctionSupplier (List .of (1 , 2 )).groupingAggregatorFactory (AggregatorMode .FINAL ),
361388 new MaxLongAggregatorFunctionSupplier (List .of (3 , 4 )).groupingAggregatorFactory (AggregatorMode .FINAL )
362389 ),
363- 16 * 1024
390+ 16 * 1024 ,
391+ analysisRegistry
364392 ).get (driverContext )
365393 ),
366394 new PageConsumerOperator (finalOutput ::add ),
@@ -385,15 +413,15 @@ public void testCategorize_withDriver() {
385413 sums ,
386414 equalTo (
387415 Map .of (
388- ".*?a .*?" ,
416+ ".*?aaazz .*?" ,
389417 1L ,
390- ".*?b .*?" ,
418+ ".*?bbbzz .*?" ,
391419 2L ,
392- ".*?c .*?" ,
420+ ".*?ccczz .*?" ,
393421 33L ,
394- ".*?d .*?" ,
422+ ".*?dddzz .*?" ,
395423 44L ,
396- ".*?e .*?" ,
424+ ".*?eeezz .*?" ,
397425 5L ,
398426 ".*?words.+?words.+?words.+?goodbye.*?" ,
399427 8888L ,
@@ -406,15 +434,15 @@ public void testCategorize_withDriver() {
406434 maxs ,
407435 equalTo (
408436 Map .of (
409- ".*?a .*?" ,
437+ ".*?aaazz .*?" ,
410438 1L ,
411- ".*?b .*?" ,
439+ ".*?bbbzz .*?" ,
412440 2L ,
413- ".*?c .*?" ,
441+ ".*?ccczz .*?" ,
414442 30L ,
415- ".*?d .*?" ,
443+ ".*?dddzz .*?" ,
416444 40L ,
417- ".*?e .*?" ,
445+ ".*?eeezz .*?" ,
418446 5L ,
419447 ".*?words.+?words.+?words.+?goodbye.*?" ,
420448 8000L ,
0 commit comments