Commit 691ff6a

Correct categorization analyzer in ES|QL categorize (#117695) (#117773)
* Correct categorization analyzer in ES|QL categorize
* close categorizer if constructing analyzer fails
* Rename capability CATEGORIZE_V4
* add comments
1 parent 36f886f commit 691ff6a
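In short: CategorizeRawBlockHash used to hand-build its CategorizationAnalyzer around a bare whitespace tokenizer; it now resolves ML's standard categorization analyzer through the node's AnalysisRegistry, and closes the already-created categorizer if that lookup fails so the hash doesn't leak it. A condensed sketch of the new constructor logic, lifted from the CategorizeRawBlockHash.java diff below (the surrounding class and field declarations are elided):

    // Standard ML categorization analyzer config, with no extra categorization filters.
    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG =
        CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of());

    // Inside the constructor, after super(...) has created `categorizer`:
    CategorizationAnalyzer analyzer;
    try {
        analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
    } catch (IOException e) {
        categorizer.close(); // don't leak the categorizer created by the super constructor
        throw new RuntimeException(e);
    }
    this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);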

17 files changed: +199 -113 lines

x-pack/plugin/esql/compute/build.gradle

Lines changed: 3 additions & 1 deletion
@@ -11,11 +11,13 @@ base {
 dependencies {
   compileOnly project(':server')
   compileOnly project('ann')
+  compileOnly project(xpackModule('core'))
   compileOnly project(xpackModule('ml'))
   annotationProcessor project('gen')
   implementation 'com.carrotsearch:hppc:0.8.1'

-  testImplementation project(':test:framework')
+  testImplementation(project(':modules:analysis-common'))
+  testImplementation(project(':test:framework'))
   testImplementation(project(xpackModule('esql-core')))
   testImplementation(project(xpackModule('core')))
   testImplementation(project(xpackModule('ml')))

x-pack/plugin/esql/compute/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     requires org.elasticsearch.ml;
     requires org.elasticsearch.tdigest;
     requires org.elasticsearch.geo;
+    requires org.elasticsearch.xcore;
     requires hppc;

     exports org.elasticsearch.compute;

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 8 additions & 2 deletions
@@ -25,6 +25,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.index.analysis.AnalysisRegistry;

 import java.util.Iterator;
 import java.util.List;
@@ -169,14 +170,19 @@ public static BlockHash buildPackedValuesBlockHash(List<GroupSpec> groups, Block
     /**
      * Builds a BlockHash for the Categorize grouping function.
      */
-    public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
+    public static BlockHash buildCategorizeBlockHash(
+        List<GroupSpec> groups,
+        AggregatorMode aggregatorMode,
+        BlockFactory blockFactory,
+        AnalysisRegistry analysisRegistry
+    ) {
         if (groups.size() != 1) {
             throw new IllegalArgumentException("only a single CATEGORIZE group can used");
         }

         return aggregatorMode.isInputPartial()
             ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
-            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
+            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry);
     }

     /**

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 19 additions & 15 deletions
@@ -7,7 +7,6 @@

 package org.elasticsearch.compute.aggregation.blockhash;

-import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
@@ -19,33 +18,38 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
-import org.elasticsearch.index.analysis.CharFilterFactory;
-import org.elasticsearch.index.analysis.CustomAnalyzer;
-import org.elasticsearch.index.analysis.TokenFilterFactory;
-import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

+import java.io.IOException;
+import java.util.List;
+
 /**
  * BlockHash implementation for {@code Categorize} grouping function.
  * <p>
  * This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
  * </p>
 */
 public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
+    private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
+        List.of()
+    );
+
     private final CategorizeEvaluator evaluator;

-    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) {
         super(blockFactory, channel, outputPartial);
-        CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
-            // TODO: should be the same analyzer as used in Production
-            new CustomAnalyzer(
-                TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
-                new CharFilterFactory[0],
-                new TokenFilterFactory[0]
-            ),
-            true
-        );
+
+        CategorizationAnalyzer analyzer;
+        try {
+            analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
+        } catch (IOException e) {
+            categorizer.close();
+            throw new RuntimeException(e);
+        }
+
         this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
     }

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

Lines changed: 4 additions & 2 deletions
@@ -24,6 +24,7 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.xcontent.XContentBuilder;

 import java.io.IOException;
@@ -42,14 +43,15 @@ public record HashAggregationOperatorFactory(
         List<BlockHash.GroupSpec> groups,
         AggregatorMode aggregatorMode,
         List<GroupingAggregator.Factory> aggregators,
-        int maxPageSize
+        int maxPageSize,
+        AnalysisRegistry analysisRegistry
     ) implements OperatorFactory {
         @Override
         public Operator get(DriverContext driverContext) {
             if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
                 return new HashAggregationOperator(
                     aggregators,
-                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
+                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
                     driverContext
                 );
             }
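For callers, the visible change is the extra record component: whoever builds the factory now supplies an AnalysisRegistry, and call sites that never group by CATEGORIZE can simply pass null, as the updated test factories below do. A hypothetical construction under the new signature (channel numbers, mode, and page size are illustrative; the factory is assumed to be the nested record shown in this diff):

    Operator.OperatorFactory factory = new HashAggregationOperator.HashAggregationOperatorFactory(
        List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),  // group by the long key in channel 0
        AggregatorMode.INITIAL,
        List.of(new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
        16 * 1024,                                              // max page size
        null                                                    // AnalysisRegistry; only consulted for CATEGORIZE groupings
    );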

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

Lines changed: 2 additions & 2 deletions
@@ -54,7 +54,6 @@
 import static org.elasticsearch.compute.data.BlockTestUtils.append;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
-import static org.hamcrest.Matchers.in;

 /**
  * Shared tests for testing grouped aggregations.
@@ -107,7 +106,8 @@ private Operator.OperatorFactory simpleWithMode(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
             mode,
             List.of(supplier.groupingAggregatorFactory(mode)),
-            randomPageSize()
+            randomPageSize(),
+            null
         );
     }

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 52 additions & 24 deletions
@@ -8,8 +8,10 @@
 package org.elasticsearch.compute.aggregation.blockhash;

 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.MockBigArrays;
@@ -35,7 +37,15 @@
 import org.elasticsearch.compute.operator.LocalSourceOperator;
 import org.elasticsearch.compute.operator.PageConsumerOperator;
 import org.elasticsearch.core.Releasables;
-
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.env.TestEnvironment;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.indices.analysis.AnalysisModule;
+import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.Before;
+
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -50,6 +60,19 @@

 public class CategorizeBlockHashTests extends BlockHashTestCase {

+    private AnalysisRegistry analysisRegistry;
+
+    @Before
+    private void initAnalysisRegistry() throws IOException {
+        analysisRegistry = new AnalysisModule(
+            TestEnvironment.newEnvironment(
+                Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
+            ),
+            List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()),
+            new StablePluginsRegistry()
+        ).getAnalysisRegistry();
+    }
+
     public void testCategorizeRaw() {
         final Page page;
         boolean withNull = randomBoolean();
@@ -72,7 +95,7 @@ public void testCategorizeRaw() {
             page = new Page(builder.build());
         }

-        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
+        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) {
             hash.add(page, new GroupingAggregatorFunction.AddInput() {
                 @Override
                 public void add(int positionOffset, IntBlock groupIds) {
@@ -145,8 +168,8 @@ public void testCategorizeIntermediate() {

         // Fill intermediatePages with the intermediate state from the raw hashes
         try (
-            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
-            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
+            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
+            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
         ) {
             rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
                 @Override
@@ -267,14 +290,16 @@ public void testCategorize_withDriver() {
             BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
             LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
         ) {
-            textsBuilder.appendBytesRef(new BytesRef("a"));
-            textsBuilder.appendBytesRef(new BytesRef("b"));
+            // Note that just using "a" or "aaa" doesn't work, because the ml_standard
+            // tokenizer drops numbers, including hexadecimal ones.
+            textsBuilder.appendBytesRef(new BytesRef("aaazz"));
+            textsBuilder.appendBytesRef(new BytesRef("bbbzz"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
             textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
-            textsBuilder.appendBytesRef(new BytesRef("c"));
-            textsBuilder.appendBytesRef(new BytesRef("d"));
+            textsBuilder.appendBytesRef(new BytesRef("ccczz"));
+            textsBuilder.appendBytesRef(new BytesRef("dddzz"));
             countsBuilder.appendLong(1);
             countsBuilder.appendLong(2);
             countsBuilder.appendLong(800);
@@ -293,10 +318,10 @@ public void testCategorize_withDriver() {
         ) {
             textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
             textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
-            textsBuilder.appendBytesRef(new BytesRef("c"));
+            textsBuilder.appendBytesRef(new BytesRef("ccczz"));
             textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
-            textsBuilder.appendBytesRef(new BytesRef("d"));
-            textsBuilder.appendBytesRef(new BytesRef("e"));
+            textsBuilder.appendBytesRef(new BytesRef("dddzz"));
+            textsBuilder.appendBytesRef(new BytesRef("eeezz"));
             countsBuilder.appendLong(9);
             countsBuilder.appendLong(90);
             countsBuilder.appendLong(3);
@@ -320,7 +345,8 @@ public void testCategorize_withDriver() {
                         new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                         new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                     ),
-                    16 * 1024
+                    16 * 1024,
+                    analysisRegistry
                 ).get(driverContext)
             ),
             new PageConsumerOperator(intermediateOutput::add),
@@ -339,7 +365,8 @@ public void testCategorize_withDriver() {
                         new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
                         new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
                     ),
-                    16 * 1024
+                    16 * 1024,
+                    analysisRegistry
                 ).get(driverContext)
            ),
            new PageConsumerOperator(intermediateOutput::add),
@@ -360,7 +387,8 @@ public void testCategorize_withDriver() {
                        new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
                        new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
                    ),
-                    16 * 1024
+                    16 * 1024,
+                    analysisRegistry
                ).get(driverContext)
            ),
            new PageConsumerOperator(finalOutput::add),
@@ -385,15 +413,15 @@ public void testCategorize_withDriver() {
             sums,
             equalTo(
                 Map.of(
-                    ".*?a.*?",
+                    ".*?aaazz.*?",
                     1L,
-                    ".*?b.*?",
+                    ".*?bbbzz.*?",
                     2L,
-                    ".*?c.*?",
+                    ".*?ccczz.*?",
                     33L,
-                    ".*?d.*?",
+                    ".*?dddzz.*?",
                     44L,
-                    ".*?e.*?",
+                    ".*?eeezz.*?",
                     5L,
                     ".*?words.+?words.+?words.+?goodbye.*?",
                     8888L,
@@ -406,15 +434,15 @@ public void testCategorize_withDriver() {
             maxs,
             equalTo(
                 Map.of(
-                    ".*?a.*?",
+                    ".*?aaazz.*?",
                     1L,
-                    ".*?b.*?",
+                    ".*?bbbzz.*?",
                     2L,
-                    ".*?c.*?",
+                    ".*?ccczz.*?",
                     30L,
-                    ".*?d.*?",
+                    ".*?dddzz.*?",
                     40L,
-                    ".*?e.*?",
+                    ".*?eeezz.*?",
                     5L,
                     ".*?words.+?words.+?words.+?goodbye.*?",
                     8000L,
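The @Before setup above builds a real AnalysisRegistry for these tests, registering the MachineLearning and CommonAnalysisPlugin plugins, which between them register the analysis components the standard categorization analyzer references (the ml_standard tokenizer mentioned in the test comment comes from the ML plugin). The registry is then consumed exactly as in the production path; a minimal sketch of that consumption, for orientation only:

    // The registry built in initAnalysisRegistry() feeds the same construction path
    // that CategorizeRawBlockHash now uses (see the main diff above).
    CategorizationAnalyzerConfig config = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of());
    CategorizationAnalyzer analyzer = new CategorizationAnalyzer(analysisRegistry, config); // may throw IOException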

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java

Lines changed: 2 additions & 1 deletion
@@ -59,7 +59,8 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
                 new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
                 new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
             ),
-            randomPageSize()
+            randomPageSize(),
+            null
         );
     }
