From 4bac42ae7f5143c807625bc2c41e89213fc467ff Mon Sep 17 00:00:00 2001 From: Jan Kuipers <148754765+jan-elastic@users.noreply.github.com> Date: Fri, 29 Nov 2024 11:00:54 +0100 Subject: [PATCH] Correct categorization analyzer in ES|QL categorize (#117695) * Correct categorization analyzer in ES|QL categorize * close categorizer if constructing analyzer fails * Rename capability CATEGORIZE_V4 * add comments --- x-pack/plugin/esql/compute/build.gradle | 4 +- .../compute/src/main/java/module-info.java | 1 + .../aggregation/blockhash/BlockHash.java | 10 +- .../blockhash/CategorizeRawBlockHash.java | 34 ++--- .../operator/HashAggregationOperator.java | 6 +- .../GroupingAggregatorFunctionTestCase.java | 4 +- .../blockhash/CategorizeBlockHashTests.java | 76 +++++++---- .../HashAggregationOperatorTests.java | 3 +- .../src/main/resources/categorize.csv-spec | 123 ++++++++++-------- .../xpack/esql/action/EsqlCapabilities.java | 2 +- .../AbstractPhysicalOperationProviders.java | 9 +- .../planner/EsPhysicalOperationProviders.java | 4 +- .../xpack/esql/plugin/ComputeService.java | 2 +- .../xpack/esql/analysis/VerifierTests.java | 6 +- .../optimizer/LogicalPlanOptimizerTests.java | 4 +- .../planner/LocalExecutionPlannerTests.java | 4 +- .../TestPhysicalOperationProviders.java | 20 ++- 17 files changed, 199 insertions(+), 113 deletions(-) diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 609c778df5929..8e866cec3f421 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -11,11 +11,13 @@ base { dependencies { compileOnly project(':server') compileOnly project('ann') + compileOnly project(xpackModule('core')) compileOnly project(xpackModule('ml')) annotationProcessor project('gen') implementation 'com.carrotsearch:hppc:0.8.1' - testImplementation project(':test:framework') + testImplementation(project(':modules:analysis-common')) + testImplementation(project(':test:framework')) testImplementation(project(xpackModule('esql-core'))) testImplementation(project(xpackModule('core'))) testImplementation(project(xpackModule('ml'))) diff --git a/x-pack/plugin/esql/compute/src/main/java/module-info.java b/x-pack/plugin/esql/compute/src/main/java/module-info.java index 573d9e048a4d4..1b3253694b298 100644 --- a/x-pack/plugin/esql/compute/src/main/java/module-info.java +++ b/x-pack/plugin/esql/compute/src/main/java/module-info.java @@ -19,6 +19,7 @@ requires org.elasticsearch.ml; requires org.elasticsearch.tdigest; requires org.elasticsearch.geo; + requires org.elasticsearch.xcore; requires hppc; exports org.elasticsearch.compute; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index ef0f3ceb112c4..ea76c3bd0a0aa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -25,6 +25,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.ReleasableIterator; +import org.elasticsearch.index.analysis.AnalysisRegistry; import java.util.Iterator; import java.util.List; @@ -169,14 +170,19 @@ public static BlockHash buildPackedValuesBlockHash(List groups, Block /** * Builds a BlockHash for the Categorize grouping function. */ - public static BlockHash buildCategorizeBlockHash(List groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) { + public static BlockHash buildCategorizeBlockHash( + List groups, + AggregatorMode aggregatorMode, + BlockFactory blockFactory, + AnalysisRegistry analysisRegistry + ) { if (groups.size() != 1) { throw new IllegalArgumentException("only a single CATEGORIZE group can used"); } return aggregatorMode.isInputPartial() ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial()) - : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial()); + : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry); } /** diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java index 0d0a2fef2f82b..47dd7f650dffa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java @@ -7,7 +7,6 @@ package org.elasticsearch.compute.aggregation.blockhash; -import org.apache.lucene.analysis.core.WhitespaceTokenizer; import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; import org.elasticsearch.compute.data.Block; @@ -19,13 +18,14 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import org.elasticsearch.index.analysis.CharFilterFactory; -import org.elasticsearch.index.analysis.CustomAnalyzer; -import org.elasticsearch.index.analysis.TokenFilterFactory; -import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; +import java.io.IOException; +import java.util.List; + /** * BlockHash implementation for {@code Categorize} grouping function. *

@@ -33,19 +33,23 @@ *

*/ public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash { + private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer( + List.of() + ); + private final CategorizeEvaluator evaluator; - CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) { + CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) { super(blockFactory, channel, outputPartial); - CategorizationAnalyzer analyzer = new CategorizationAnalyzer( - // TODO: should be the same analyzer as used in Production - new CustomAnalyzer( - TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new), - new CharFilterFactory[0], - new TokenFilterFactory[0] - ), - true - ); + + CategorizationAnalyzer analyzer; + try { + analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG); + } catch (IOException e) { + categorizer.close(); + throw new RuntimeException(e); + } + this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index a69e8ca767014..6f8386ec08de1 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -24,6 +24,7 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; @@ -42,14 +43,15 @@ public record HashAggregationOperatorFactory( List groups, AggregatorMode aggregatorMode, List aggregators, - int maxPageSize + int maxPageSize, + AnalysisRegistry analysisRegistry ) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) { return new HashAggregationOperator( aggregators, - () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()), + () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry), driverContext ); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index 1e97bdf5a2e79..58925a5ca36fc 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -54,7 +54,6 @@ import static org.elasticsearch.compute.data.BlockTestUtils.append; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.in; /** * Shared tests for testing grouped aggregations. @@ -107,7 +106,8 @@ private Operator.OperatorFactory simpleWithMode( List.of(new BlockHash.GroupSpec(0, ElementType.LONG)), mode, List.of(supplier.groupingAggregatorFactory(mode)), - randomPageSize() + randomPageSize(), + null ); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index dd7a87dc4a574..8a3c723557151 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -8,8 +8,10 @@ package org.elasticsearch.compute.aggregation.blockhash; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.analysis.common.CommonAnalysisPlugin; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.MockBigArrays; @@ -35,7 +37,15 @@ import org.elasticsearch.compute.operator.LocalSourceOperator; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.core.Releasables; - +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.indices.analysis.AnalysisModule; +import org.elasticsearch.plugins.scanners.StablePluginsRegistry; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.junit.Before; + +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -50,6 +60,19 @@ public class CategorizeBlockHashTests extends BlockHashTestCase { + private AnalysisRegistry analysisRegistry; + + @Before + private void initAnalysisRegistry() throws IOException { + analysisRegistry = new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()), + new StablePluginsRegistry() + ).getAnalysisRegistry(); + } + public void testCategorizeRaw() { final Page page; boolean withNull = randomBoolean(); @@ -72,7 +95,7 @@ public void testCategorizeRaw() { page = new Page(builder.build()); } - try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) { + try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) { hash.add(page, new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { @@ -145,8 +168,8 @@ public void testCategorizeIntermediate() { // Fill intermediatePages with the intermediate state from the raw hashes try ( - BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true); - BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true) + BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry); + BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry); ) { rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() { @Override @@ -267,14 +290,16 @@ public void testCategorize_withDriver() { BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10); LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10) ) { - textsBuilder.appendBytesRef(new BytesRef("a")); - textsBuilder.appendBytesRef(new BytesRef("b")); + // Note that just using "a" or "aaa" doesn't work, because the ml_standard + // tokenizer drops numbers, including hexadecimal ones. + textsBuilder.appendBytesRef(new BytesRef("aaazz")); + textsBuilder.appendBytesRef(new BytesRef("bbbzz")); textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan")); textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik")); textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom")); textsBuilder.appendBytesRef(new BytesRef("words words words hello jan")); - textsBuilder.appendBytesRef(new BytesRef("c")); - textsBuilder.appendBytesRef(new BytesRef("d")); + textsBuilder.appendBytesRef(new BytesRef("ccczz")); + textsBuilder.appendBytesRef(new BytesRef("dddzz")); countsBuilder.appendLong(1); countsBuilder.appendLong(2); countsBuilder.appendLong(800); @@ -293,10 +318,10 @@ public void testCategorize_withDriver() { ) { textsBuilder.appendBytesRef(new BytesRef("words words words hello nik")); textsBuilder.appendBytesRef(new BytesRef("words words words hello nik")); - textsBuilder.appendBytesRef(new BytesRef("c")); + textsBuilder.appendBytesRef(new BytesRef("ccczz")); textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris")); - textsBuilder.appendBytesRef(new BytesRef("d")); - textsBuilder.appendBytesRef(new BytesRef("e")); + textsBuilder.appendBytesRef(new BytesRef("dddzz")); + textsBuilder.appendBytesRef(new BytesRef("eeezz")); countsBuilder.appendLong(9); countsBuilder.appendLong(90); countsBuilder.appendLong(3); @@ -320,7 +345,8 @@ public void testCategorize_withDriver() { new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL), new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL) ), - 16 * 1024 + 16 * 1024, + analysisRegistry ).get(driverContext) ), new PageConsumerOperator(intermediateOutput::add), @@ -339,7 +365,8 @@ public void testCategorize_withDriver() { new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL), new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL) ), - 16 * 1024 + 16 * 1024, + analysisRegistry ).get(driverContext) ), new PageConsumerOperator(intermediateOutput::add), @@ -360,7 +387,8 @@ public void testCategorize_withDriver() { new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL), new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL) ), - 16 * 1024 + 16 * 1024, + analysisRegistry ).get(driverContext) ), new PageConsumerOperator(finalOutput::add), @@ -385,15 +413,15 @@ public void testCategorize_withDriver() { sums, equalTo( Map.of( - ".*?a.*?", + ".*?aaazz.*?", 1L, - ".*?b.*?", + ".*?bbbzz.*?", 2L, - ".*?c.*?", + ".*?ccczz.*?", 33L, - ".*?d.*?", + ".*?dddzz.*?", 44L, - ".*?e.*?", + ".*?eeezz.*?", 5L, ".*?words.+?words.+?words.+?goodbye.*?", 8888L, @@ -406,15 +434,15 @@ public void testCategorize_withDriver() { maxs, equalTo( Map.of( - ".*?a.*?", + ".*?aaazz.*?", 1L, - ".*?b.*?", + ".*?bbbzz.*?", 2L, - ".*?c.*?", + ".*?ccczz.*?", 30L, - ".*?d.*?", + ".*?dddzz.*?", 40L, - ".*?e.*?", + ".*?eeezz.*?", 5L, ".*?words.+?words.+?words.+?goodbye.*?", 8000L, diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index b2f4ad594936e..953c7d1c313f1 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -59,7 +59,8 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode), new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode) ), - randomPageSize() + randomPageSize(), + null ); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec index 547c430ed7518..e45b10d1aa122 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -1,5 +1,5 @@ standard aggs -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS count=COUNT(), @@ -17,7 +17,7 @@ count:long | sum:long | avg:double | count_distinct:long | category:keyw ; values aggs -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS values=MV_SORT(VALUES(message)), @@ -33,7 +33,7 @@ values:keyword | top ; mv -required_capability: categorize_v3 +required_capability: categorize_v4 FROM mv_sample_data | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message) @@ -48,7 +48,7 @@ COUNT():long | SUM(event_duration):long | category:keyword ; row mv -required_capability: categorize_v3 +required_capability: categorize_v4 ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"] | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message) @@ -60,8 +60,20 @@ COUNT():long | VALUES(str):keyword | category:keyword 1 | [a, b, c] | .*?disconnected.*? ; +skips stopwords +required_capability: categorize_v4 + +ROW message = ["Mon Tue connected to a", "Jul Aug connected to b September ", "UTC connected GMT to c UTC"] + | STATS COUNT() BY category=CATEGORIZE(message) + | SORT category +; + +COUNT():long | category:keyword + 3 | .*?connected.+?to.*? +; + with multiple indices -required_capability: categorize_v3 +required_capability: categorize_v4 required_capability: union_types FROM sample_data* @@ -76,7 +88,7 @@ COUNT():long | category:keyword ; mv with many values -required_capability: categorize_v3 +required_capability: categorize_v4 FROM employees | STATS COUNT() BY category=CATEGORIZE(job_positions) @@ -93,7 +105,7 @@ COUNT():long | category:keyword ; mv with many values and SUM -required_capability: categorize_v3 +required_capability: categorize_v4 FROM employees | STATS SUM(languages) BY category=CATEGORIZE(job_positions) @@ -108,7 +120,7 @@ SUM(languages):long | category:keyword ; mv with many values and nulls and SUM -required_capability: categorize_v3 +required_capability: categorize_v4 FROM employees | STATS SUM(languages) BY category=CATEGORIZE(job_positions) @@ -122,7 +134,7 @@ SUM(languages):long | category:keyword ; mv via eval -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL message = MV_APPEND(message, "Banana") @@ -138,7 +150,7 @@ COUNT():long | category:keyword ; mv via eval const -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL message = ["Banana", "Bread"] @@ -152,7 +164,7 @@ COUNT():long | category:keyword ; mv via eval const without aliases -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL message = ["Banana", "Bread"] @@ -166,7 +178,7 @@ COUNT():long | CATEGORIZE(message):keyword ; mv const in parameter -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"]) @@ -179,7 +191,7 @@ COUNT():long | c:keyword ; agg alias shadowing -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"]) @@ -194,7 +206,7 @@ c:keyword ; chained aggregations using categorize -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(message) @@ -203,13 +215,13 @@ FROM sample_data ; COUNT():long | category:keyword - 1 | .*?\.\*\?Connected\.\+\?to\.\*\?.*? - 1 | .*?\.\*\?Connection\.\+\?error\.\*\?.*? - 1 | .*?\.\*\?Disconnected\.\*\?.*? + 1 | .*?Connected.+?to.*? + 1 | .*?Connection.+?error.*? + 1 | .*?Disconnected.*? ; stats without aggs -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS BY category=CATEGORIZE(message) @@ -223,7 +235,7 @@ category:keyword ; text field -required_capability: categorize_v3 +required_capability: categorize_v4 FROM hosts | STATS COUNT() BY category=CATEGORIZE(host_group) @@ -231,14 +243,17 @@ FROM hosts ; COUNT():long | category:keyword - 2 | .*?DB.+?servers.*? 2 | .*?Gateway.+?instances.*? 5 | .*?Kubernetes.+?cluster.*? + 2 | .*?servers.*? 1 | null + +// Note: DB is removed from "DB servers", because the ml_standard +// tokenizer drops numbers, including hexadecimal ones. ; on TO_UPPER -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message)) @@ -252,7 +267,7 @@ COUNT():long | category:keyword ; on CONCAT -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana")) @@ -266,7 +281,7 @@ COUNT():long | category:keyword ; on CONCAT with unicode -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊")) @@ -274,13 +289,13 @@ FROM sample_data ; COUNT():long | category:keyword - 3 | .*?Connected.+?to.+?👍🏽😊.*? - 3 | .*?Connection.+?error.+?👍🏽😊.*? - 1 | .*?Disconnected.+?👍🏽😊.*? + 3 | .*?Connected.+?to.*? + 3 | .*?Connection.+?error.*? + 1 | .*?Disconnected.*? ; on REVERSE(CONCAT()) -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊"))) @@ -288,13 +303,13 @@ FROM sample_data ; COUNT():long | category:keyword - 1 | .*?😊👍🏽.+?detcennocsiD.*? - 3 | .*?😊👍🏽.+?ot.+?detcennoC.*? - 3 | .*?😊👍🏽.+?rorre.+?noitcennoC.*? + 1 | .*?detcennocsiD.*? + 3 | .*?ot.+?detcennoC.*? + 3 | .*?rorre.+?noitcennoC.*? ; and then TO_LOWER -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(message) @@ -309,7 +324,7 @@ COUNT():long | category:keyword ; on const empty string -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE("") @@ -321,7 +336,7 @@ COUNT():long | category:keyword ; on const empty string from eval -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL x = "" @@ -334,7 +349,7 @@ COUNT():long | category:keyword ; on null -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL x = null @@ -347,7 +362,7 @@ COUNT():long | SUM(event_duration):long | category:keyword ; on null string -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL x = null::string @@ -360,7 +375,7 @@ COUNT():long | category:keyword ; filtering out all data -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | WHERE @timestamp < "2023-10-23T00:00:00Z" @@ -372,7 +387,7 @@ COUNT():long | category:keyword ; filtering out all data with constant -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS COUNT() BY category=CATEGORIZE(message) @@ -383,7 +398,7 @@ COUNT():long | category:keyword ; drop output columns -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS count=COUNT() BY category=CATEGORIZE(message) @@ -398,7 +413,7 @@ x:integer ; category value processing -required_capability: categorize_v3 +required_capability: categorize_v4 ROW message = ["connected to a", "connected to b", "disconnected"] | STATS COUNT() BY category=CATEGORIZE(message) @@ -412,21 +427,21 @@ COUNT():long | category:keyword ; row aliases -required_capability: categorize_v3 +required_capability: categorize_v4 -ROW message = "connected to a" +ROW message = "connected to xyz" | EVAL x = message | STATS COUNT() BY category=CATEGORIZE(x) | EVAL y = category | SORT y ; -COUNT():long | category:keyword | y:keyword - 1 | .*?connected.+?to.+?a.*? | .*?connected.+?to.+?a.*? +COUNT():long | category:keyword | y:keyword + 1 | .*?connected.+?to.+?xyz.*? | .*?connected.+?to.+?xyz.*? ; from aliases -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL x = message @@ -442,9 +457,9 @@ COUNT():long | category:keyword | y:keyword ; row aliases with keep -required_capability: categorize_v3 +required_capability: categorize_v4 -ROW message = "connected to a" +ROW message = "connected to xyz" | EVAL x = message | KEEP x | STATS COUNT() BY category=CATEGORIZE(x) @@ -454,11 +469,11 @@ ROW message = "connected to a" ; COUNT():long | y:keyword - 1 | .*?connected.+?to.+?a.*? + 1 | .*?connected.+?to.+?xyz.*? ; from aliases with keep -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | EVAL x = message @@ -476,9 +491,9 @@ COUNT():long | y:keyword ; row rename -required_capability: categorize_v3 +required_capability: categorize_v4 -ROW message = "connected to a" +ROW message = "connected to xyz" | RENAME message as x | STATS COUNT() BY category=CATEGORIZE(x) | RENAME category as y @@ -486,11 +501,11 @@ ROW message = "connected to a" ; COUNT():long | y:keyword - 1 | .*?connected.+?to.+?a.*? + 1 | .*?connected.+?to.+?xyz.*? ; from rename -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | RENAME message as x @@ -506,7 +521,7 @@ COUNT():long | y:keyword ; row drop -required_capability: categorize_v3 +required_capability: categorize_v4 ROW message = "connected to a" | STATS c = COUNT() BY category=CATEGORIZE(message) @@ -519,7 +534,7 @@ c:long ; from drop -required_capability: categorize_v3 +required_capability: categorize_v4 FROM sample_data | STATS c = COUNT() BY category=CATEGORIZE(message) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 5c99f8f502d73..0d272d1da46c4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -396,7 +396,7 @@ public enum Cap { /** * Supported the text categorization function "CATEGORIZE". */ - CATEGORIZE_V3(Build.current().isSnapshot()), + CATEGORIZE_V4(Build.current().isSnapshot()), /** * QSTR function diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index a7418654f6b0e..69e2d1c45aa3c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -18,6 +18,7 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Alias; @@ -46,6 +47,11 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders { private final AggregateMapper aggregateMapper = new AggregateMapper(); + private final AnalysisRegistry analysisRegistry; + + AbstractPhysicalOperationProviders(AnalysisRegistry analysisRegistry) { + this.analysisRegistry = analysisRegistry; + } @Override public final PhysicalOperation groupingPhysicalOperation( @@ -173,7 +179,8 @@ else if (aggregatorMode.isOutputPartial()) { groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(), aggregatorMode, aggregatorFactories, - context.pageSize(aggregateExec.estimatedRowSize()) + context.pageSize(aggregateExec.estimatedRowSize()), + analysisRegistry ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 15f5b6579098d..7bf7d0e2d08eb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -34,6 +34,7 @@ import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.FieldNamesFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; @@ -98,7 +99,8 @@ public interface ShardContext extends org.elasticsearch.compute.lucene.ShardCont private final List shardContexts; - public EsPhysicalOperationProviders(List shardContexts) { + public EsPhysicalOperationProviders(List shardContexts, AnalysisRegistry analysisRegistry) { + super(analysisRegistry); this.shardContexts = shardContexts; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 73266551f169c..b06dd3cdb64d3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -452,7 +452,7 @@ void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, context.exchangeSink(), enrichLookupService, lookupFromIndexService, - new EsPhysicalOperationProviders(contexts) + new EsPhysicalOperationProviders(contexts, searchService.getIndicesService().getAnalysis()) ); LOGGER.debug("Received physical plan:\n{}", plan); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index dd14e8dd82123..d4fca2a0a2540 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1846,7 +1846,7 @@ public void testIntervalAsString() { } public void testCategorizeSingleGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)"); query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)"); @@ -1875,7 +1875,7 @@ public void testCategorizeSingleGrouping() { } public void testCategorizeNestedGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)"); @@ -1890,7 +1890,7 @@ public void testCategorizeNestedGrouping() { } public void testCategorizeWithinAggregations() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 76641802160c4..ec02995978d97 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -1212,7 +1212,7 @@ public void testCombineProjectionWithAggregationFirstAndAliasedGroupingUsedInAgg * \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..] */ public void testCombineProjectionWithCategorizeGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); var plan = plan(""" from test @@ -3949,7 +3949,7 @@ public void testNestedExpressionsInGroups() { * \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..] */ public void testNestedExpressionsInGroupsWithCategorize() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled()); + assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V4.isEnabled()); var plan = optimizedPlan(""" from test diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index ff9e45a9f9233..5d8da21c6faad 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -156,7 +156,7 @@ private Configuration config() { randomZone(), randomLocale(random()), "test_user", - "test_cluser", + "test_cluster", pragmas, EsqlPlugin.QUERY_RESULT_TRUNCATION_MAX_SIZE.getDefault(null), EsqlPlugin.QUERY_RESULT_TRUNCATION_DEFAULT_SIZE.getDefault(null), @@ -187,7 +187,7 @@ private EsPhysicalOperationProviders esPhysicalOperationProviders() throws IOExc ); } releasables.add(searcher); - return new EsPhysicalOperationProviders(shardContexts); + return new EsPhysicalOperationProviders(shardContexts, null); } private IndexReader reader() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index c811643c8daea..e91fc6e49312d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -8,7 +8,9 @@ package org.elasticsearch.xpack.esql.planner; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.analysis.common.CommonAnalysisPlugin; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.aggregation.GroupingAggregator; @@ -28,7 +30,11 @@ import org.elasticsearch.compute.operator.OrdinalsGroupingOperator; import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.SourceOperator.SourceOperatorFactory; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.indices.analysis.AnalysisModule; +import org.elasticsearch.plugins.scanners.StablePluginsRegistry; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.core.expression.Attribute; @@ -39,7 +45,9 @@ import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation; +import org.elasticsearch.xpack.ml.MachineLearning; +import java.io.IOException; import java.util.List; import java.util.Random; import java.util.function.Function; @@ -48,6 +56,7 @@ import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween; import static java.util.stream.Collectors.joining; +import static org.apache.lucene.tests.util.LuceneTestCase.createTempDir; import static org.elasticsearch.index.mapper.MappedFieldType.FieldExtractPreference.DOC_VALUES; import static org.elasticsearch.index.mapper.MappedFieldType.FieldExtractPreference.NONE; @@ -56,7 +65,16 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationPro private final Page testData; private final List columnNames; - public TestPhysicalOperationProviders(Page testData, List columnNames) { + public TestPhysicalOperationProviders(Page testData, List columnNames) throws IOException { + super( + new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()), + new StablePluginsRegistry() + ).getAnalysisRegistry() + ); this.testData = testData; this.columnNames = columnNames; }