From 6f2a058bd521e73131b9edb8ccf80a13b4fadabd Mon Sep 17 00:00:00 2001
From: Iván Cea Fontenla
Date: Thu, 28 Nov 2024 11:40:12 +0100
Subject: [PATCH] Backport 9022ccc
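
CATEGORIZE now runs as a BlockHash inside the aggregation pipeline instead
of as a per-row evaluator: data nodes categorize raw rows and serialize
their categorizer state, and the coordinator merges those states. As a
result the function returns the category pattern as a keyword instead of
an integer id, and the previously muted categorize spec tests are unmuted.

Illustrative query shape (see categorize.csv-spec for the actual tests):

    FROM sample_data | STATS count = COUNT(*) BY category = CATEGORIZE(message)
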
---
docs/changelog/114317.yaml | 5 +
.../kibana/definition/categorize.json | 4 +-
.../esql/functions/types/categorize.asciidoc | 4 +-
muted-tests.yml | 15 -
.../AbstractCategorizeBlockHash.java | 105 ++++
.../aggregation/blockhash/BlockHash.java | 28 +-
.../blockhash/CategorizeRawBlockHash.java | 137 +++++
.../CategorizedIntermediateBlockHash.java | 77 +++
.../operator/HashAggregationOperator.java | 9 +
.../GroupingAggregatorFunctionTestCase.java | 1 +
.../blockhash/BlockHashTestCase.java | 34 ++
.../aggregation/blockhash/BlockHashTests.java | 22 +-
.../blockhash/CategorizeBlockHashTests.java | 406 ++++++++++++++
.../HashAggregationOperatorTests.java | 1 +
.../xpack/esql/CsvTestsDataLoader.java | 2 +
.../src/main/resources/categorize.csv-spec | 526 +++++++++++++++++-
.../resources/mapping-mv_sample_data.json | 16 +
.../src/main/resources/mv_sample_data.csv | 8 +
.../grouping/CategorizeEvaluator.java | 145 -----
.../xpack/esql/action/EsqlCapabilities.java | 5 +-
.../function/grouping/Categorize.java | 76 +--
.../rules/logical/CombineProjections.java | 38 +-
.../optimizer/rules/logical/FoldNull.java | 2 +
...laceAggregateNestedExpressionWithEval.java | 31 +-
.../physical/local/InsertFieldExtraction.java | 17 +-
.../AbstractPhysicalOperationProviders.java | 42 +-
.../xpack/esql/analysis/VerifierTests.java | 6 +-
.../function/AbstractAggregationTestCase.java | 3 +-
.../function/AbstractFunctionTestCase.java | 19 +-
.../AbstractScalarFunctionTestCase.java | 1 +
.../expression/function/TestCaseSupplier.java | 83 ++-
.../function/grouping/CategorizeTests.java | 16 +-
.../optimizer/LogicalPlanOptimizerTests.java | 61 ++
.../rules/logical/FoldNullTests.java | 13 +
.../categorization/TokenListCategorizer.java | 24 +
35 files changed, 1660 insertions(+), 322 deletions(-)
create mode 100644 docs/changelog/114317.yaml
create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java
create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java
create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java
create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-mv_sample_data.json
create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_sample_data.csv
delete mode 100644 x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeEvaluator.java
diff --git a/docs/changelog/114317.yaml b/docs/changelog/114317.yaml
new file mode 100644
index 0000000000000..9c73fe513e197
--- /dev/null
+++ b/docs/changelog/114317.yaml
@@ -0,0 +1,5 @@
+pr: 114317
+summary: "ESQL: CATEGORIZE as a `BlockHash`"
+area: ES|QL
+type: enhancement
+issues: []
diff --git a/docs/reference/esql/functions/kibana/definition/categorize.json b/docs/reference/esql/functions/kibana/definition/categorize.json
index 386b178d3753f..ca3971a6e05a3 100644
--- a/docs/reference/esql/functions/kibana/definition/categorize.json
+++ b/docs/reference/esql/functions/kibana/definition/categorize.json
@@ -14,7 +14,7 @@
}
],
"variadic" : false,
- "returnType" : "integer"
+ "returnType" : "keyword"
},
{
"params" : [
@@ -26,7 +26,7 @@
}
],
"variadic" : false,
- "returnType" : "integer"
+ "returnType" : "keyword"
}
],
"preview" : false,
diff --git a/docs/reference/esql/functions/types/categorize.asciidoc b/docs/reference/esql/functions/types/categorize.asciidoc
index 4917ed313e6d7..5b64971cbc482 100644
--- a/docs/reference/esql/functions/types/categorize.asciidoc
+++ b/docs/reference/esql/functions/types/categorize.asciidoc
@@ -5,6 +5,6 @@
[%header.monospaced.styled,format=dsv,separator=|]
|===
field | result
-keyword | integer
-text | integer
+keyword | keyword
+text | keyword
|===
diff --git a/muted-tests.yml b/muted-tests.yml
index 336dd2696df91..ddfa55d93f4ee 100644
--- a/muted-tests.yml
+++ b/muted-tests.yml
@@ -193,12 +193,6 @@ tests:
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
method: test {p0=indices.split/40_routing_partition_size/more than 1}
issue: https://github.com/elastic/elasticsearch/issues/113841
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
- method: test {categorize.Categorize SYNC}
- issue: https://github.com/elastic/elasticsearch/issues/113722
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
- method: test {categorize.Categorize ASYNC}
- issue: https://github.com/elastic/elasticsearch/issues/116373
- class: org.elasticsearch.kibana.KibanaThreadPoolIT
method: testBlockedThreadPoolsRejectUserRequests
issue: https://github.com/elastic/elasticsearch/issues/113939
@@ -254,12 +248,6 @@ tests:
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
method: test {p0=search/380_sort_segments_on_timestamp/Test that index segments are NOT sorted on timestamp field when @timestamp field is dynamically added}
issue: https://github.com/elastic/elasticsearch/issues/116221
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
- method: test {categorize.Categorize SYNC}
- issue: https://github.com/elastic/elasticsearch/issues/113054
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
- method: test {categorize.Categorize ASYNC}
- issue: https://github.com/elastic/elasticsearch/issues/113054
- class: org.elasticsearch.ingest.common.IngestCommonClientYamlTestSuiteIT
method: test {yaml=ingest/310_reroute_processor/Test remove then add reroute processor with and without lazy rollover}
issue: https://github.com/elastic/elasticsearch/issues/116158
@@ -272,9 +260,6 @@ tests:
- class: org.elasticsearch.xpack.deprecation.DeprecationHttpIT
method: testDeprecatedSettingsReturnWarnings
issue: https://github.com/elastic/elasticsearch/issues/108628
-- class: org.elasticsearch.xpack.esql.ccq.MultiClusterSpecIT
- method: test {categorize.Categorize}
- issue: https://github.com/elastic/elasticsearch/issues/116434
- class: org.elasticsearch.xpack.apmdata.APMYamlTestSuiteIT
method: test {yaml=/10_apm/Test template reinstallation}
issue: https://github.com/elastic/elasticsearch/issues/116445
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java
new file mode 100644
index 0000000000000..22d3a10facb06
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRefBuilder;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+
+import java.io.IOException;
+
+/**
+ * Base BlockHash implementation for the {@code Categorize} grouping function.
+ */
+public abstract class AbstractCategorizeBlockHash extends BlockHash {
+ // TODO: this should probably also take an emitBatchSize
+ private final int channel;
+ private final boolean outputPartial;
+ protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
+
+ AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
+ super(blockFactory);
+ this.channel = channel;
+ this.outputPartial = outputPartial;
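+        // Assumed to match the ML plugin's categorization defaults: 2048 initial hash
+        // capacity and a 0.70 token-weight similarity threshold.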
+ this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
+ new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+ CategorizationPartOfSpeechDictionary.getInstance(),
+ 0.70f
+ );
+ }
+
+ protected int channel() {
+ return channel;
+ }
+
+ @Override
+ public Block[] getKeys() {
+ return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
+ }
+
+ @Override
+ public IntVector nonEmpty() {
+ return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
+ }
+
+ @Override
+ public BitArray seenGroupIds(BigArrays bigArrays) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
+ */
+ private Block buildIntermediateBlock() {
+ if (categorizer.getCategoryCount() == 0) {
+ return blockFactory.newConstantNullBlock(0);
+ }
+ try (BytesStreamOutput out = new BytesStreamOutput()) {
+ // TODO be more careful here.
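+            // Format: a vInt category count, then each SerializableTokenListCategory in
+            // category-id order; CategorizedIntermediateBlockHash#readIntermediate reads this back.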
+ out.writeVInt(categorizer.getCategoryCount());
+ for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+ category.writeTo(out);
+ }
+ // We're returning a block with N positions just because the Page must have all blocks with the same position count!
+ return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
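+    /**
+     * Builds the final keys block: one BytesRef per category, holding the category's regex.
+     */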
+ private Block buildFinalBlock() {
+ try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
+ BytesRefBuilder scratch = new BytesRefBuilder();
+ for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+ scratch.copyChars(category.getRegex());
+ result.appendBytesRef(scratch.get());
+ scratch.clear();
+ }
+ return result.build().asBlock();
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
index 919cb92f79260..ef0f3ceb112c4 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
@@ -14,6 +14,7 @@
import org.elasticsearch.common.util.Int3Hash;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
@@ -58,9 +59,7 @@
* leave a big gap, even if we never see {@code null}.
*
*/
-public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
- permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
- NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
+public abstract class BlockHash implements Releasable, SeenGroupIds {
protected final BlockFactory blockFactory;
@@ -107,7 +106,15 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
@Override
public abstract BitArray seenGroupIds(BigArrays bigArrays);
- public record GroupSpec(int channel, ElementType elementType) {}
+ /**
+     * @param isCategorize Whether this group is the result of a CATEGORIZE() call.
+ * May be changed in the future when more stateful grouping functions are added.
+ */
+ public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+ public GroupSpec(int channel, ElementType elementType) {
+ this(channel, elementType, false);
+ }
+ }
/**
* Creates a specialized hash table that maps one or more {@link Block}s to ids.
@@ -159,6 +166,19 @@ public static BlockHash buildPackedValuesBlockHash(List<GroupSpec> groups, Block
return new PackedValuesBlockHash(groups, blockFactory, emitBatchSize);
}
+ /**
+ * Builds a BlockHash for the Categorize grouping function.
+ */
+    public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
+ if (groups.size() != 1) {
+            throw new IllegalArgumentException("only a single CATEGORIZE group can be used");
+ }
+
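+        // Raw-input hashes (data nodes) categorize the original text rows; partial-input hashes
+        // (coordinator) merge the serialized categorizer states those produce.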
+ return aggregatorMode.isInputPartial()
+            ? new CategorizedIntermediateBlockHash(groups.get(0).channel(), blockFactory, aggregatorMode.isOutputPartial())
+            : new CategorizeRawBlockHash(groups.get(0).channel(), blockFactory, aggregatorMode.isOutputPartial());
+ }
+
/**
* Creates a specialized hash table that maps a {@link Block} of the given input element type to ids.
*/
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
new file mode 100644
index 0000000000000..bf633e0454384
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.analysis.core.WhitespaceTokenizer;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.CharFilterFactory;
+import org.elasticsearch.index.analysis.CustomAnalyzer;
+import org.elasticsearch.index.analysis.TokenFilterFactory;
+import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
+
+/**
+ * BlockHash implementation for the {@code Categorize} grouping function.
+ *
+ * This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
+ *
+ */
+public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
+ private final CategorizeEvaluator evaluator;
+
+ CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+ super(blockFactory, channel, outputPartial);
+ CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
+ // TODO: should be the same analyzer as used in Production
+ new CustomAnalyzer(
+ TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+ new CharFilterFactory[0],
+ new TokenFilterFactory[0]
+ ),
+ true
+ );
+ this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
+ }
+
+ @Override
+ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
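+        // The evaluator maps every input value to its category id; those ids are passed on as
+        // group ids.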
+ try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
+ addInput.add(0, result);
+ }
+ }
+
+ @Override
+ public void close() {
+ evaluator.close();
+ }
+
+ /**
+ * Similar implementation to an Evaluator.
+ */
+ public static final class CategorizeEvaluator implements Releasable {
+ private final CategorizationAnalyzer analyzer;
+
+ private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
+
+ private final BlockFactory blockFactory;
+
+ public CategorizeEvaluator(
+ CategorizationAnalyzer analyzer,
+ TokenListCategorizer.CloseableTokenListCategorizer categorizer,
+ BlockFactory blockFactory
+ ) {
+ this.analyzer = analyzer;
+ this.categorizer = categorizer;
+ this.blockFactory = blockFactory;
+ }
+
+ public Block eval(BytesRefBlock vBlock) {
+ BytesRefVector vVector = vBlock.asVector();
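+            // A null vector means the block can hold nulls or multi-values, so take the
+            // block path; otherwise the vector fast path avoids per-position null checks.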
+ if (vVector == null) {
+ return eval(vBlock.getPositionCount(), vBlock);
+ }
+ IntVector vector = eval(vBlock.getPositionCount(), vVector);
+ return vector.asBlock();
+ }
+
+ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
+ try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
+ BytesRef vScratch = new BytesRef();
+ for (int p = 0; p < positionCount; p++) {
+ if (vBlock.isNull(p)) {
+ result.appendNull();
+ continue;
+ }
+ int first = vBlock.getFirstValueIndex(p);
+ int count = vBlock.getValueCount(p);
+ if (count == 1) {
+ result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
+ continue;
+ }
+ int end = first + count;
+ result.beginPositionEntry();
+ for (int i = first; i < end; i++) {
+ result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
+ }
+ result.endPositionEntry();
+ }
+ return result.build();
+ }
+ }
+
+ public IntVector eval(int positionCount, BytesRefVector vVector) {
+ try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
+ BytesRef vScratch = new BytesRef();
+ for (int p = 0; p < positionCount; p++) {
+ result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
+ }
+ return result.build();
+ }
+ }
+
+ private int process(BytesRef v) {
+ return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
+ }
+
+ @Override
+ public void close() {
+ Releasables.closeExpectNoException(analyzer, categorizer);
+ }
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java
new file mode 100644
index 0000000000000..1bca34a70e5fa
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * BlockHash implementation for the {@code Categorize} grouping function.
+ *
+ * This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}.
+ *
+ */
+public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
+
+ CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+ super(blockFactory, channel, outputPartial);
+ }
+
+ @Override
+ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+ if (page.getPositionCount() == 0) {
+ // No categories
+ return;
+ }
+ BytesRefBlock categorizerState = page.getBlock(channel());
+        Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
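+        // The i-th position in the intermediate state holds old category id i, so the loop
+        // below remaps positions 0..size-1 to the merged ids.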
+ try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
+ for (int i = 0; i < idMap.size(); i++) {
+ newIdsBuilder.appendInt(idMap.get(i));
+ }
+ try (IntBlock newIds = newIdsBuilder.build()) {
+ addInput.add(0, newIds);
+ }
+ }
+ }
+
+ /**
+ * Read intermediate state from a block.
+ *
+ * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
+ */
+    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+        Map<Integer, Integer> idMap = new HashMap<>();
+ try (StreamInput in = new BytesArray(bytes).streamInput()) {
+ int count = in.readVInt();
+ for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
+ int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
+ idMap.put(oldCategoryId, newCategoryId);
+ }
+ return idMap;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void close() {
+ categorizer.close();
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
index 03a4ca2b0ad5e..a69e8ca767014 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
@@ -14,6 +14,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.Describable;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
@@ -39,11 +40,19 @@ public class HashAggregationOperator implements Operator {
public record HashAggregationOperatorFactory(
         List<BlockHash.GroupSpec> groups,
+        AggregatorMode aggregatorMode,
         List<GroupingAggregator.Factory> aggregators,
int maxPageSize
) implements OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
+ if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
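+                // CATEGORIZE needs a mode-aware hash: raw rows on data nodes, serialized
+                // intermediate states on the coordinator (see BlockHash#buildCategorizeBlockHash).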
+ return new HashAggregationOperator(
+ aggregators,
+ () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
+ driverContext
+ );
+ }
return new HashAggregationOperator(
aggregators,
() -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
index cb190dfffafb9..1e97bdf5a2e79 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
@@ -105,6 +105,7 @@ private Operator.OperatorFactory simpleWithMode(
}
return new HashAggregationOperator.HashAggregationOperatorFactory(
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+ mode,
List.of(supplier.groupingAggregatorFactory(mode)),
randomPageSize()
);
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java
new file mode 100644
index 0000000000000..fa93c0aa1c375
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.data.MockBlockFactory;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public abstract class BlockHashTestCase extends ESTestCase {
+
+ final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofGb(1));
+ final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
+ final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
+
+ // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
+ private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
+ CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
+ when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
+ return breakerService;
+ }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java
index 088e791348840..ede2d68ca2367 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java
@@ -11,11 +11,7 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.MockBigArrays;
-import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BooleanBlock;
@@ -26,7 +22,6 @@
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
-import org.elasticsearch.compute.data.MockBlockFactory;
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
import org.elasticsearch.compute.data.Page;
@@ -34,8 +29,6 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
-import org.elasticsearch.indices.breaker.CircuitBreakerService;
-import org.elasticsearch.test.ESTestCase;
import org.junit.After;
import java.util.ArrayList;
@@ -54,14 +47,8 @@
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-public class BlockHashTests extends ESTestCase {
-
- final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
- final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
- final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
+public class BlockHashTests extends BlockHashTestCase {
@ParametersFactory
public static List