Commit 7ae295c

Correct categorization analyzer in ES|QL categorize
1 parent 5935f76 commit 7ae295c

14 files changed: +161 additions, -70 deletions

x-pack/plugin/esql/compute/build.gradle

Lines changed: 3 additions & 1 deletion

@@ -11,11 +11,13 @@ base {
 dependencies {
   compileOnly project(':server')
   compileOnly project('ann')
+  compileOnly project(xpackModule('core'))
   compileOnly project(xpackModule('ml'))
   annotationProcessor project('gen')
   implementation 'com.carrotsearch:hppc:0.8.1'
 
-  testImplementation project(':test:framework')
+  testImplementation(project(':modules:analysis-common'))
+  testImplementation(project(':test:framework'))
   testImplementation(project(xpackModule('esql-core')))
   testImplementation(project(xpackModule('core')))
   testImplementation(project(xpackModule('ml')))

x-pack/plugin/esql/compute/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
     requires org.elasticsearch.ml;
     requires org.elasticsearch.tdigest;
     requires org.elasticsearch.geo;
+    requires org.elasticsearch.xcore;
     requires hppc;
 
     exports org.elasticsearch.compute;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 8 additions & 2 deletions

@@ -25,6 +25,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
 
 import java.util.Iterator;
 import java.util.List;
@@ -169,14 +170,19 @@ public static BlockHash buildPackedValuesBlockHash(List<GroupSpec> groups, Block
     /**
      * Builds a BlockHash for the Categorize grouping function.
      */
-    public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
+    public static BlockHash buildCategorizeBlockHash(
+        List<GroupSpec> groups,
+        AggregatorMode aggregatorMode,
+        BlockFactory blockFactory,
+        AnalysisRegistry analysisRegistry
+    ) {
         if (groups.size() != 1) {
             throw new IllegalArgumentException("only a single CATEGORIZE group can used");
         }
 
         return aggregatorMode.isInputPartial()
             ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
-            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
+            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry);
     }
 
     /**
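
For context, a minimal caller sketch of the widened factory method (not part of the commit); `groups`, `driverContext`, and `analysisRegistry` are assumed to be supplied by the surrounding operator code, as HashAggregationOperatorFactory does further down:

// Hedged sketch: building the categorize hash now requires an AnalysisRegistry.
// AggregatorMode.INITIAL has raw (non-partial) input, so the raw hash is returned.
BlockHash hash = BlockHash.buildCategorizeBlockHash(
    groups,                        // must contain exactly one CATEGORIZE group
    AggregatorMode.INITIAL,
    driverContext.blockFactory(),
    analysisRegistry               // new parameter added by this commit
);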

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 29 additions & 15 deletions

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.compute.aggregation.blockhash;
 
-import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
@@ -19,33 +18,48 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
-import org.elasticsearch.index.analysis.CharFilterFactory;
-import org.elasticsearch.index.analysis.CustomAnalyzer;
-import org.elasticsearch.index.analysis.TokenFilterFactory;
-import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.index.IndexService;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
+import java.io.IOException;
+import java.util.List;
+
 /**
  * BlockHash implementation for {@code Categorize} grouping function.
  * <p>
  * This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
  * </p>
 */
 public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
+    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
+        List.of()
+    );
+
     private final CategorizeEvaluator evaluator;
 
-    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) {
         super(blockFactory, channel, outputPartial);
-        CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
-            // TODO: should be the same analyzer as used in Production
-            new CustomAnalyzer(
-                TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
-                new CharFilterFactory[0],
-                new TokenFilterFactory[0]
-            ),
-            true
-        );
+
+        CategorizationAnalyzer analyzer;
+        try {
+            analyzer = new CategorizationAnalyzer(
+                analysisRegistry.buildCustomAnalyzer(
+                    IndexService.IndexCreationContext.RELOAD_ANALYZERS,
+                    null,
+                    false,
+                    ANALYZER_CONFIG.getTokenizer(),
+                    ANALYZER_CONFIG.getCharFilters(),
+                    ANALYZER_CONFIG.getTokenFilters()
+                ),
+                true
+            );
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+
         this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
     }
 
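
The raw hash no longer hard-codes a whitespace tokenizer (the removed TODO); it now builds the ML standard categorization analyzer through the injected AnalysisRegistry, wrapping the checked IOException from buildCustomAnalyzer in a RuntimeException. A hedged standalone sketch of the same analyzer, assuming the CategorizationAnalyzer(AnalysisRegistry, CategorizationAnalyzerConfig) constructor and the tokenizeField helper used elsewhere in the ML plugin:

import java.io.IOException;
import java.util.List;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

// Hedged sketch, not part of the commit: tokenize one message with the standard
// categorization analyzer so its behaviour can be compared with the old whitespace setup.
final class CategorizationAnalyzerProbe {
    static List<String> tokens(AnalysisRegistry analysisRegistry, String message) throws IOException {
        CategorizationAnalyzerConfig config = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of());
        try (CategorizationAnalyzer analyzer = new CategorizationAnalyzer(analysisRegistry, config)) {
            return analyzer.tokenizeField("message", message);
        }
    }
}
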
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

Lines changed: 4 additions & 2 deletions

@@ -24,6 +24,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
@@ -42,14 +43,15 @@ public record HashAggregationOperatorFactory(
         List<BlockHash.GroupSpec> groups,
         AggregatorMode aggregatorMode,
         List<GroupingAggregator.Factory> aggregators,
-        int maxPageSize
+        int maxPageSize,
+        AnalysisRegistry analysisRegistry
     ) implements OperatorFactory {
         @Override
         public Operator get(DriverContext driverContext) {
             if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
                 return new HashAggregationOperator(
                     aggregators,
-                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
+                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
                     driverContext
                 );
             }
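
A hedged sketch (not part of the commit) of the wiring as a caller sees it; `groups`, `aggregators`, `analysisRegistry`, and `driverContext` are assumed to exist in the surrounding code. The registry is only consulted for CATEGORIZE groupings, which is why the tests below can pass null where categorization is not exercised:

// Hedged sketch: the factory simply carries the registry through to the block hash.
Operator.OperatorFactory factory = new HashAggregationOperator.HashAggregationOperatorFactory(
    groups,
    AggregatorMode.INITIAL,
    aggregators,
    16 * 1024,           // maxPageSize, same value the categorize test uses
    analysisRegistry     // null is acceptable when no CATEGORIZE group is present
);
Operator operator = factory.get(driverContext);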

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

Lines changed: 2 additions & 2 deletions

@@ -54,7 +54,6 @@
 import static org.elasticsearch.compute.data.BlockTestUtils.append;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
-import static org.hamcrest.Matchers.in;
 
 /**
  * Shared tests for testing grouped aggregations.
@@ -107,7 +106,8 @@ private Operator.OperatorFactory simpleWithMode(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
             mode,
             List.of(supplier.groupingAggregatorFactory(mode)),
-            randomPageSize()
+            randomPageSize(),
+            null
         );
     }

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 50 additions & 24 deletions

@@ -8,8 +8,10 @@
 package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.MockBigArrays;
@@ -35,7 +37,15 @@
 import org.elasticsearch.compute.operator.LocalSourceOperator;
 import org.elasticsearch.compute.operator.PageConsumerOperator;
 import org.elasticsearch.core.Releasables;
-
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.env.TestEnvironment;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.indices.analysis.AnalysisModule;
+import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.Before;
+
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -50,6 +60,19 @@
 
 public class CategorizeBlockHashTests extends BlockHashTestCase {
 
+    private AnalysisRegistry analysisRegistry;
+
+    @Before
+    private void initAnalysisRegistry() throws IOException {
+        analysisRegistry = new AnalysisModule(
+            TestEnvironment.newEnvironment(
+                Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
+            ),
+            List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()),
+            new StablePluginsRegistry()
+        ).getAnalysisRegistry();
+    }
+
     public void testCategorizeRaw() {
         final Page page;
         boolean withNull = randomBoolean();
@@ -72,7 +95,7 @@ public void testCategorizeRaw() {
             page = new Page(builder.build());
         }
 
-        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
+        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) {
             hash.add(page, new GroupingAggregatorFunction.AddInput() {
                 @Override
                 public void add(int positionOffset, IntBlock groupIds) {
@@ -145,8 +168,8 @@ public void testCategorizeIntermediate() {
 
         // Fill intermediatePages with the intermediate state from the raw hashes
         try (
-            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
-            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
+            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
+            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
         ) {
             rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
                 @Override
@@ -267,14 +290,14 @@ public void testCategorize_withDriver() {
            BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
            LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
        ) {
-            textsBuilder.appendBytesRef(new BytesRef("a"));
-            textsBuilder.appendBytesRef(new BytesRef("b"));
+            textsBuilder.appendBytesRef(new BytesRef("aaazz"));
+            textsBuilder.appendBytesRef(new BytesRef("bbbzz"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
             textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
-            textsBuilder.appendBytesRef(new BytesRef("c"));
-            textsBuilder.appendBytesRef(new BytesRef("d"));
+            textsBuilder.appendBytesRef(new BytesRef("ccczz"));
+            textsBuilder.appendBytesRef(new BytesRef("dddzz"));
             countsBuilder.appendLong(1);
             countsBuilder.appendLong(2);
             countsBuilder.appendLong(800);
@@ -293,10 +316,10 @@ public void testCategorize_withDriver() {
        ) {
             textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
             textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
-            textsBuilder.appendBytesRef(new BytesRef("c"));
+            textsBuilder.appendBytesRef(new BytesRef("ccczz"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
-            textsBuilder.appendBytesRef(new BytesRef("d"));
-            textsBuilder.appendBytesRef(new BytesRef("e"));
+            textsBuilder.appendBytesRef(new BytesRef("dddzz"));
+            textsBuilder.appendBytesRef(new BytesRef("eeezz"));
             countsBuilder.appendLong(9);
             countsBuilder.appendLong(90);
             countsBuilder.appendLong(3);
@@ -320,7 +343,8 @@ public void testCategorize_withDriver() {
                     new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                     new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                 ),
-                16 * 1024
+                16 * 1024,
+                analysisRegistry
            ).get(driverContext)
        ),
        new PageConsumerOperator(intermediateOutput::add),
@@ -339,7 +363,8 @@ public void testCategorize_withDriver() {
                     new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                     new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                 ),
-                16 * 1024
+                16 * 1024,
+                analysisRegistry
            ).get(driverContext)
        ),
        new PageConsumerOperator(intermediateOutput::add),
@@ -360,7 +385,8 @@ public void testCategorize_withDriver() {
                     new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
                     new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
                 ),
-                16 * 1024
+                16 * 1024,
+                analysisRegistry
            ).get(driverContext)
        ),
        new PageConsumerOperator(finalOutput::add),
@@ -385,15 +411,15 @@ public void testCategorize_withDriver() {
            sums,
            equalTo(
                Map.of(
-                    ".*?a.*?",
+                    ".*?aaazz.*?",
                     1L,
-                    ".*?b.*?",
+                    ".*?bbbzz.*?",
                     2L,
-                    ".*?c.*?",
+                    ".*?ccczz.*?",
                     33L,
-                    ".*?d.*?",
+                    ".*?dddzz.*?",
                     44L,
-                    ".*?e.*?",
+                    ".*?eeezz.*?",
                     5L,
                     ".*?words.+?words.+?words.+?goodbye.*?",
                     8888L,
@@ -406,15 +432,15 @@ public void testCategorize_withDriver() {
            maxs,
            equalTo(
                Map.of(
-                    ".*?a.*?",
+                    ".*?aaazz.*?",
                     1L,
-                    ".*?b.*?",
+                    ".*?bbbzz.*?",
                     2L,
-                    ".*?c.*?",
+                    ".*?ccczz.*?",
                     30L,
-                    ".*?d.*?",
+                    ".*?dddzz.*?",
                     40L,
-                    ".*?e.*?",
+                    ".*?eeezz.*?",
                     5L,
                     ".*?words.+?words.+?words.+?goodbye.*?",
                     8000L,
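
The test now builds a real AnalysisRegistry in a @Before method, loading both MachineLearning and CommonAnalysisPlugin, presumably because the standard categorization config references components registered by the ML plugin (such as the ml_standard tokenizer) alongside filters that ship in the analysis-common module. The fixture strings were also lengthened from single letters to distinctive tokens like "aaazz", which keeps them clearly separable after the production analysis chain (an inference from the test change, not something the commit states). A hedged way to inspect what that analyzer does to a given message, reusing the hypothetical CategorizationAnalyzerProbe sketched earlier:

// Hedged sketch, not part of the commit: a throwaway test that prints the
// production-analyzer tokens for a few fixtures. `analysisRegistry` comes from the
// @Before setup above; CategorizationAnalyzerProbe is the hypothetical helper.
public void testPrintCategorizationTokens() throws IOException {
    for (String message : List.of("aaazz", "words words words goodbye jan")) {
        System.out.println(message + " -> " + CategorizationAnalyzerProbe.tokens(analysisRegistry, message));
    }
}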

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java

Lines changed: 2 additions & 1 deletion

@@ -59,7 +59,8 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
                 new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
                 new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
             ),
-            randomPageSize()
+            randomPageSize(),
+            null
         );
     }
