
Commit 2f62a9c

Correct categorization analyzer in ES|QL categorize
1 parent d4bcd97 commit 2f62a9c

14 files changed (+161 lines, -71 lines)

x-pack/plugin/esql/compute/build.gradle

Lines changed: 3 additions & 1 deletion
@@ -11,11 +11,13 @@ base {
 dependencies {
   compileOnly project(':server')
   compileOnly project('ann')
+  compileOnly project(xpackModule('core'))
   compileOnly project(xpackModule('ml'))
   annotationProcessor project('gen')
   implementation 'com.carrotsearch:hppc:0.8.1'
 
-  testImplementation project(':test:framework')
+  testImplementation(project(':modules:analysis-common'))
+  testImplementation(project(':test:framework'))
   testImplementation(project(xpackModule('esql-core')))
   testImplementation(project(xpackModule('core')))
   testImplementation(project(xpackModule('ml')))

x-pack/plugin/esql/compute/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     requires org.elasticsearch.ml;
     requires org.elasticsearch.tdigest;
     requires org.elasticsearch.geo;
+    requires org.elasticsearch.xcore;
     requires hppc;
 
     exports org.elasticsearch.compute;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 8 additions & 2 deletions
@@ -25,6 +25,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
 
 import java.util.Iterator;
 import java.util.List;
@@ -169,14 +170,19 @@ public static BlockHash buildPackedValuesBlockHash(List<GroupSpec> groups, Block
     /**
      * Builds a BlockHash for the Categorize grouping function.
      */
-    public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
+    public static BlockHash buildCategorizeBlockHash(
+        List<GroupSpec> groups,
+        AggregatorMode aggregatorMode,
+        BlockFactory blockFactory,
+        AnalysisRegistry analysisRegistry
+    ) {
         if (groups.size() != 1) {
             throw new IllegalArgumentException("only a single CATEGORIZE group can used");
         }
 
         return aggregatorMode.isInputPartial()
             ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
-            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
+            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry);
     }
 
     /**
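
Editor's note: the only behavioral change here is the extra AnalysisRegistry parameter, which is threaded through to the raw block hash. A minimal caller sketch against the new signature (variable names are illustrative; the real call site is in HashAggregationOperator below) might look like:

    // Sketch only: assumes `groups`, `aggregatorMode`, `driverContext` and an injected
    // `analysisRegistry` are in scope, as they are at the operator-factory call site below.
    BlockHash categorizeHash = BlockHash.buildCategorizeBlockHash(
        groups,                        // exactly one CATEGORIZE GroupSpec is allowed
        aggregatorMode,                // partial input selects the intermediate hash instead
        driverContext.blockFactory(),
        analysisRegistry               // used to build the ML categorization analyzer
    );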

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 29 additions & 15 deletions
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.compute.aggregation.blockhash;
 
-import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
@@ -19,33 +18,48 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
-import org.elasticsearch.index.analysis.CharFilterFactory;
-import org.elasticsearch.index.analysis.CustomAnalyzer;
-import org.elasticsearch.index.analysis.TokenFilterFactory;
-import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.index.IndexService;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
+import java.io.IOException;
+import java.util.List;
+
 /**
  * BlockHash implementation for {@code Categorize} grouping function.
  * <p>
  * This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
  * </p>
  */
 public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
+    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
+        List.of()
+    );
+
     private final CategorizeEvaluator evaluator;
 
-    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) {
         super(blockFactory, channel, outputPartial);
-        CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
-            // TODO: should be the same analyzer as used in Production
-            new CustomAnalyzer(
-                TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
-                new CharFilterFactory[0],
-                new TokenFilterFactory[0]
-            ),
-            true
-        );
+
+        CategorizationAnalyzer analyzer;
+        try {
+            analyzer = new CategorizationAnalyzer(
+                analysisRegistry.buildCustomAnalyzer(
+                    IndexService.IndexCreationContext.RELOAD_ANALYZERS,
+                    null,
+                    false,
+                    ANALYZER_CONFIG.getTokenizer(),
+                    ANALYZER_CONFIG.getCharFilters(),
+                    ANALYZER_CONFIG.getTokenFilters()
+                ),
+                true
+            );
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+
        this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
     }
 
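
Editor's note: this is the core of the fix. The ad-hoc whitespace CustomAnalyzer (marked with a TODO) is replaced by the standard ML categorization analyzer, resolved through the AnalysisRegistry. A hedged, stand-alone sketch of that build step, pulled out into a hypothetical helper that is not part of this commit, could read:

    // Hypothetical helper mirroring the constructor body above; the method name is illustrative.
    static CategorizationAnalyzer buildStandardCategorizationAnalyzer(AnalysisRegistry analysisRegistry) throws IOException {
        // Same standard config used by ML categorization, with no extra categorization filters.
        CategorizationAnalyzerConfig config = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of());
        return new CategorizationAnalyzer(
            analysisRegistry.buildCustomAnalyzer(
                IndexService.IndexCreationContext.RELOAD_ANALYZERS,
                null,    // no IndexSettings: the analyzer is not tied to an index (as in the constructor above)
                false,   // build an analyzer, not a normalizer
                config.getTokenizer(),
                config.getCharFilters(),
                config.getTokenFilters()
            ),
            true // presumably: close the underlying analyzer together with the wrapper
        );
    }

The practical effect is that ES|QL's CATEGORIZE now tokenizes messages the same way as the production ML categorization analyzer, instead of splitting on whitespace only.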

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

Lines changed: 4 additions & 2 deletions
@@ -24,6 +24,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
@@ -42,14 +43,15 @@ public record HashAggregationOperatorFactory(
     List<BlockHash.GroupSpec> groups,
     AggregatorMode aggregatorMode,
     List<GroupingAggregator.Factory> aggregators,
-    int maxPageSize
+    int maxPageSize,
+    AnalysisRegistry analysisRegistry
 ) implements OperatorFactory {
     @Override
     public Operator get(DriverContext driverContext) {
         if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
             return new HashAggregationOperator(
                 aggregators,
-                () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
+                () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
                 driverContext
             );
         }
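
Editor's note: every constructor of this record now has to supply the registry as the last component. A construction sketch (placeholder arguments; the real usages are in the test changes below, which pass null when no CATEGORIZE group is involved) might be:

    // Sketch: `groups`, `aggregatorMode`, `aggregators` and `analysisRegistry` are placeholders.
    var factory = new HashAggregationOperatorFactory(   // the record declared above
        groups,            // List<BlockHash.GroupSpec>; may contain a CATEGORIZE group
        aggregatorMode,
        aggregators,       // List<GroupingAggregator.Factory>
        16 * 1024,         // maxPageSize
        analysisRegistry   // tests without CATEGORIZE pass null here, as shown below
    );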

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,6 @@
 import static org.elasticsearch.compute.data.BlockTestUtils.append;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
-import static org.hamcrest.Matchers.in;
 
 /**
  * Shared tests for testing grouped aggregations.
@@ -107,7 +106,8 @@ private Operator.OperatorFactory simpleWithMode(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
             mode,
             List.of(supplier.groupingAggregatorFactory(mode)),
-            randomPageSize()
+            randomPageSize(),
+            null
         );
     }
 

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 50 additions & 24 deletions
@@ -8,8 +8,10 @@
 package org.elasticsearch.compute.aggregation.blockhash;
 
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.MockBigArrays;
@@ -35,7 +37,15 @@
 import org.elasticsearch.compute.operator.LocalSourceOperator;
 import org.elasticsearch.compute.operator.PageConsumerOperator;
 import org.elasticsearch.core.Releasables;
-
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.env.TestEnvironment;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.indices.analysis.AnalysisModule;
+import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.Before;
+
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -50,6 +60,19 @@
 
 public class CategorizeBlockHashTests extends BlockHashTestCase {
 
+    private AnalysisRegistry analysisRegistry;
+
+    @Before
+    private void initAnalysisRegistry() throws IOException {
+        analysisRegistry = new AnalysisModule(
+            TestEnvironment.newEnvironment(
+                Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
+            ),
+            List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()),
+            new StablePluginsRegistry()
+        ).getAnalysisRegistry();
+    }
+
     public void testCategorizeRaw() {
         final Page page;
         final int positions = 7;
@@ -64,7 +87,7 @@ public void testCategorizeRaw() {
             page = new Page(builder.build());
         }
 
-        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
+        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) {
             hash.add(page, new GroupingAggregatorFunction.AddInput() {
                 @Override
                 public void add(int positionOffset, IntBlock groupIds) {
@@ -126,8 +149,8 @@ public void testCategorizeIntermediate() {
 
         // Fill intermediatePages with the intermediate state from the raw hashes
        try (
-            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
-            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
+            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
+            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
        ) {
            rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
                @Override
@@ -241,14 +264,14 @@ public void testCategorize_withDriver() {
            BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
            LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
        ) {
-            textsBuilder.appendBytesRef(new BytesRef("a"));
-            textsBuilder.appendBytesRef(new BytesRef("b"));
+            textsBuilder.appendBytesRef(new BytesRef("aaazz"));
+            textsBuilder.appendBytesRef(new BytesRef("bbbzz"));
            textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
            textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
            textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
            textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
-            textsBuilder.appendBytesRef(new BytesRef("c"));
-            textsBuilder.appendBytesRef(new BytesRef("d"));
+            textsBuilder.appendBytesRef(new BytesRef("ccczz"));
+            textsBuilder.appendBytesRef(new BytesRef("dddzz"));
            countsBuilder.appendLong(1);
            countsBuilder.appendLong(2);
            countsBuilder.appendLong(800);
@@ -267,10 +290,10 @@ public void testCategorize_withDriver() {
        ) {
            textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
            textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
-            textsBuilder.appendBytesRef(new BytesRef("c"));
+            textsBuilder.appendBytesRef(new BytesRef("ccczz"));
            textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
-            textsBuilder.appendBytesRef(new BytesRef("d"));
-            textsBuilder.appendBytesRef(new BytesRef("e"));
+            textsBuilder.appendBytesRef(new BytesRef("dddzz"));
+            textsBuilder.appendBytesRef(new BytesRef("eeezz"));
            countsBuilder.appendLong(9);
            countsBuilder.appendLong(90);
            countsBuilder.appendLong(3);
@@ -294,7 +317,8 @@ public void testCategorize_withDriver() {
                    new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                    new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                ),
-                16 * 1024
+                16 * 1024,
+                analysisRegistry
            ).get(driverContext)
        ),
        new PageConsumerOperator(intermediateOutput::add),
@@ -313,7 +337,8 @@ public void testCategorize_withDriver() {
                    new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                    new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                ),
-                16 * 1024
+                16 * 1024,
+                analysisRegistry
            ).get(driverContext)
        ),
        new PageConsumerOperator(intermediateOutput::add),
@@ -334,7 +359,8 @@ public void testCategorize_withDriver() {
                    new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
                    new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
                ),
-                16 * 1024
+                16 * 1024,
+                analysisRegistry
            ).get(driverContext)
        ),
        new PageConsumerOperator(finalOutput::add),
@@ -359,15 +385,15 @@ public void testCategorize_withDriver() {
            sums,
            equalTo(
                Map.of(
-                    ".*?a.*?",
+                    ".*?aaazz.*?",
                    1L,
-                    ".*?b.*?",
+                    ".*?bbbzz.*?",
                    2L,
-                    ".*?c.*?",
+                    ".*?ccczz.*?",
                    33L,
-                    ".*?d.*?",
+                    ".*?dddzz.*?",
                    44L,
-                    ".*?e.*?",
+                    ".*?eeezz.*?",
                    5L,
                    ".*?words.+?words.+?words.+?goodbye.*?",
                    8888L,
@@ -380,15 +406,15 @@ public void testCategorize_withDriver() {
            maxs,
            equalTo(
                Map.of(
-                    ".*?a.*?",
+                    ".*?aaazz.*?",
                    1L,
-                    ".*?b.*?",
+                    ".*?bbbzz.*?",
                    2L,
-                    ".*?c.*?",
+                    ".*?ccczz.*?",
                    30L,
-                    ".*?d.*?",
+                    ".*?dddzz.*?",
                    40L,
-                    ".*?e.*?",
+                    ".*?eeezz.*?",
                    5L,
                    ".*?words.+?words.+?words.+?goodbye.*?",
                    8000L,
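
Editor's note: the new @Before setup is what makes the standard categorization analyzer resolvable in a unit test: the registry is built from both the MachineLearning plugin and the CommonAnalysisPlugin (hence the new analysis-common test dependency in build.gradle), presumably because the standard categorization analyzer config references analysis components those plugins provide. A condensed sketch of that test-only setup, reusable in similar tests, would be:

    // Condensed from the @Before method above; the home path is a throwaway temp directory.
    Settings nodeSettings = Settings.builder()
        .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
        .build();
    AnalysisRegistry registry = new AnalysisModule(
        TestEnvironment.newEnvironment(nodeSettings),
        List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()), // ML + common analysis components
        new StablePluginsRegistry()
    ).getAnalysisRegistry();

The single-letter test messages ("a", "b", …) are also replaced with longer tokens ("aaazz", "bbbzz", …), which keeps them as distinct categories under the stricter standard analyzer.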

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java

Lines changed: 2 additions & 1 deletion
@@ -59,7 +59,8 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
             new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
             new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
         ),
-        randomPageSize()
+        randomPageSize(),
+        null
     );
 }
 
