diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java
deleted file mode 100644
index 0e89d77820883..0000000000000
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.compute.aggregation.blockhash;
-
-import org.apache.lucene.util.BytesRefBuilder;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
-import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.BitArray;
-import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.compute.aggregation.SeenGroupIds;
-import org.elasticsearch.compute.data.Block;
-import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.compute.data.BytesRefVector;
-import org.elasticsearch.compute.data.IntBlock;
-import org.elasticsearch.compute.data.IntVector;
-import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.core.ReleasableIterator;
-import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
-import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
-import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
-import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
-
-import java.io.IOException;
-
-/**
- * Base BlockHash implementation for {@code Categorize} grouping function.
- */
-public abstract class AbstractCategorizeBlockHash extends BlockHash {
- protected static final int NULL_ORD = 0;
-
- // TODO: this should probably also take an emitBatchSize
- private final int channel;
- private final boolean outputPartial;
- protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
-
- /**
- * Store whether we've seen any {@code null} values.
- * <p>
- * Null gets the {@link #NULL_ORD} ord.
- * </p>
- */
- protected boolean seenNull = false;
-
- AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
- super(blockFactory);
- this.channel = channel;
- this.outputPartial = outputPartial;
- this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
- new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
- CategorizationPartOfSpeechDictionary.getInstance(),
- 0.70f
- );
- }
-
- protected int channel() {
- return channel;
- }
-
- @Override
- public Block[] getKeys() {
- return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
- }
-
- @Override
- public IntVector nonEmpty() {
- return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory);
- }
-
- @Override
- public BitArray seenGroupIds(BigArrays bigArrays) {
- return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
- }
-
- @Override
- public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
- throw new UnsupportedOperationException();
- }
-
- /**
- * Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
- */
- private Block buildIntermediateBlock() {
- if (categorizer.getCategoryCount() == 0) {
- return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
- }
- try (BytesStreamOutput out = new BytesStreamOutput()) {
- // TODO be more careful here.
- out.writeBoolean(seenNull);
- out.writeVInt(categorizer.getCategoryCount());
- for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
- category.writeTo(out);
- }
- // We're returning a block with N positions just because the Page must have all blocks with the same position count!
- int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
- return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- private Block buildFinalBlock() {
- BytesRefBuilder scratch = new BytesRefBuilder();
-
- if (seenNull) {
- try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
- result.appendNull();
- for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
- scratch.copyChars(category.getRegex());
- result.appendBytesRef(scratch.get());
- scratch.clear();
- }
- return result.build();
- }
- }
-
- try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
- for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
- scratch.copyChars(category.getRegex());
- result.appendBytesRef(scratch.get());
- scratch.clear();
- }
- return result.build().asBlock();
- }
- }
-}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
index ea76c3bd0a0aa..30afa7ae3128d 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
@@ -180,9 +180,7 @@ public static BlockHash buildCategorizeBlockHash(
throw new IllegalArgumentException("only a single CATEGORIZE group can be used");
}
- return aggregatorMode.isInputPartial()
- ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
- : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry);
+ return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
}
/**
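
The factory change above is the crux of the refactor: instead of choosing between two classes, `buildCategorizeBlockHash` now returns a single `CategorizeBlockHash` that branches on `AggregatorMode`. Here is a minimal standalone sketch of that dispatch pattern, assuming `AggregatorMode`'s usual semantics (INITIAL consumes raw input; INTERMEDIATE and FINAL consume partial input). All names below are illustrative stand-ins, not Elasticsearch code:

```java
// Illustrative stand-ins only; not the Elasticsearch classes.
enum Mode {
    SINGLE(false, false), INITIAL(false, true), INTERMEDIATE(true, true), FINAL(true, false);

    private final boolean inputPartial;
    private final boolean outputPartial;

    Mode(boolean inputPartial, boolean outputPartial) {
        this.inputPartial = inputPartial;
        this.outputPartial = outputPartial;
    }

    boolean isInputPartial() { return inputPartial; }

    boolean isOutputPartial() { return outputPartial; }
}

final class ModeDispatchingHash {
    private final Mode mode;

    ModeDispatchingHash(Mode mode) {
        this.mode = mode;
    }

    // Mirrors CategorizeBlockHash#add: raw rows before the exchange,
    // serialized intermediate states after it.
    void add(String input) {
        if (mode.isInputPartial() == false) {
            System.out.println("addInitial(" + input + ")");
        } else {
            System.out.println("addIntermediate(" + input + ")");
        }
    }

    // Mirrors getKeys(): partial output serializes state, final output emits keys.
    String keys() {
        return mode.isOutputPartial() ? "intermediate state" : "category regexes";
    }
}
```

One payoff visible in the new class below: construction-time differences (the categorization analyzer is only needed for raw input) collapse into a single constructor branch instead of a class split.
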
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java
new file mode 100644
index 0000000000000..35c6faf84e623
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java
@@ -0,0 +1,309 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * BlockHash implementation for the {@code Categorize} grouping function.
+ */
+public class CategorizeBlockHash extends BlockHash {
+
+ private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
+ List.of()
+ );
+ private static final int NULL_ORD = 0;
+
+ // TODO: this should probably also take an emitBatchSize
+ private final int channel;
+ private final AggregatorMode aggregatorMode;
+ private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
+
+ private final CategorizeEvaluator evaluator;
+
+ /**
+ * Store whether we've seen any {@code null} values.
+ * <p>
+ * Null gets the {@link #NULL_ORD} ord.
+ * </p>
+ */
+ private boolean seenNull = false;
+
+ CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) {
+ super(blockFactory);
+
+ this.channel = channel;
+ this.aggregatorMode = aggregatorMode;
+
+ this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
+ new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+ CategorizationPartOfSpeechDictionary.getInstance(),
+ 0.70f
+ );
+
+ if (aggregatorMode.isInputPartial() == false) {
+ CategorizationAnalyzer analyzer;
+ try {
+ Objects.requireNonNull(analysisRegistry);
+ analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
+ } catch (Exception e) {
+ categorizer.close();
+ throw new RuntimeException(e);
+ }
+ this.evaluator = new CategorizeEvaluator(analyzer);
+ } else {
+ this.evaluator = null;
+ }
+ }
+
+ @Override
+ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+ if (aggregatorMode.isInputPartial() == false) {
+ addInitial(page, addInput);
+ } else {
+ addIntermediate(page, addInput);
+ }
+ }
+
+ @Override
+ public Block[] getKeys() {
+ return new Block[] { aggregatorMode.isOutputPartial() ? buildIntermediateBlock() : buildFinalBlock() };
+ }
+
+ @Override
+ public IntVector nonEmpty() {
+ return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory);
+ }
+
+ @Override
+ public BitArray seenGroupIds(BigArrays bigArrays) {
+ return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
+ }
+
+ @Override
+ public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() {
+ Releasables.close(evaluator, categorizer);
+ }
+
+ /**
+ * Adds initial (raw) input to the state.
+ */
+ private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) {
+ try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) {
+ addInput.add(0, result);
+ }
+ }
+
+ /**
+ * Adds a serialized intermediate state to this hash, remapping its category ids onto the merged ones.
+ */
+ private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) {
+ if (page.getPositionCount() == 0) {
+ return;
+ }
+ BytesRefBlock categorizerState = page.getBlock(channel);
+ if (categorizerState.areAllValuesNull()) {
+ seenNull = true;
+ try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
+ addInput.add(0, newIds);
+ }
+ return;
+ }
+
+ Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
+ try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
+ int fromId = idMap.containsKey(0) ? 0 : 1;
+ int toId = fromId + idMap.size();
+ for (int i = fromId; i < toId; i++) {
+ newIdsBuilder.appendInt(idMap.get(i));
+ }
+ try (IntBlock newIds = newIdsBuilder.build()) {
+ addInput.add(0, newIds);
+ }
+ }
+ }
+
+ /**
+ * Read intermediate state from a block.
+ *
+ * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
+ */
+ private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+ Map<Integer, Integer> idMap = new HashMap<>();
+ try (StreamInput in = new BytesArray(bytes).streamInput()) {
+ if (in.readBoolean()) {
+ seenNull = true;
+ idMap.put(NULL_ORD, NULL_ORD);
+ }
+ int count = in.readVInt();
+ for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
+ int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
+ // +1 because the 0 ordinal is reserved for null
+ idMap.put(oldCategoryId + 1, newCategoryId + 1);
+ }
+ return idMap;
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Serializes the intermediate state into a single BytesRef block, or a constant null block if there are no categories.
+ */
+ private Block buildIntermediateBlock() {
+ if (categorizer.getCategoryCount() == 0) {
+ return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
+ }
+ try (BytesStreamOutput out = new BytesStreamOutput()) {
+ out.writeBoolean(seenNull);
+ out.writeVInt(categorizer.getCategoryCount());
+ for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+ category.writeTo(out);
+ }
+ // We're returning a block with N positions just because the Page must have all blocks with the same position count!
+ int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
+ return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private Block buildFinalBlock() {
+ BytesRefBuilder scratch = new BytesRefBuilder();
+
+ if (seenNull) {
+ try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
+ result.appendNull();
+ for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+ scratch.copyChars(category.getRegex());
+ result.appendBytesRef(scratch.get());
+ scratch.clear();
+ }
+ return result.build();
+ }
+ }
+
+ try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
+ for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+ scratch.copyChars(category.getRegex());
+ result.appendBytesRef(scratch.get());
+ scratch.clear();
+ }
+ return result.build().asBlock();
+ }
+ }
+
+ /**
+ * An implementation similar to an Evaluator: maps input values to category ordinals.
+ */
+ private final class CategorizeEvaluator implements Releasable {
+ private final CategorizationAnalyzer analyzer;
+
+ CategorizeEvaluator(CategorizationAnalyzer analyzer) {
+ this.analyzer = analyzer;
+ }
+
+ Block eval(BytesRefBlock vBlock) {
+ BytesRefVector vVector = vBlock.asVector();
+ if (vVector == null) {
+ return eval(vBlock.getPositionCount(), vBlock);
+ }
+ IntVector vector = eval(vBlock.getPositionCount(), vVector);
+ return vector.asBlock();
+ }
+
+ IntBlock eval(int positionCount, BytesRefBlock vBlock) {
+ try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
+ BytesRef vScratch = new BytesRef();
+ for (int p = 0; p < positionCount; p++) {
+ if (vBlock.isNull(p)) {
+ seenNull = true;
+ result.appendInt(NULL_ORD);
+ continue;
+ }
+ int first = vBlock.getFirstValueIndex(p);
+ int count = vBlock.getValueCount(p);
+ if (count == 1) {
+ result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
+ continue;
+ }
+ int end = first + count;
+ result.beginPositionEntry();
+ for (int i = first; i < end; i++) {
+ result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
+ }
+ result.endPositionEntry();
+ }
+ return result.build();
+ }
+ }
+
+ IntVector eval(int positionCount, BytesRefVector vVector) {
+ try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
+ BytesRef vScratch = new BytesRef();
+ for (int p = 0; p < positionCount; p++) {
+ result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
+ }
+ return result.build();
+ }
+ }
+
+ int process(BytesRef v) {
+ var category = categorizer.computeCategory(v.utf8ToString(), analyzer);
+ if (category == null) {
+ seenNull = true;
+ return NULL_ORD;
+ }
+ return category.getId() + 1;
+ }
+
+ @Override
+ public void close() {
+ analyzer.close();
+ }
+ }
+}
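
A note on the id remapping that `addIntermediate`/`readIntermediate` perform: each incoming state numbers its categories locally from 0, ordinals are shifted by +1 so that ordinal 0 stays reserved for null, and merging assigns each distinct category one id shared across all inputs. The following self-contained sketch shows that bookkeeping, with invented names standing in for the real categorizer and `SerializableTokenListCategory`:

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

final class ToyCategoryMerger {
    static final int NULL_ORD = 0; // ordinal 0 is reserved for null, as in the diff

    private final Map<String, Integer> merged = new LinkedHashMap<>();

    /** Returns the merged 0-based id for a category, assigning the next id if unseen. */
    int merge(String category) {
        return merged.computeIfAbsent(category, k -> merged.size());
    }

    /**
     * Remaps one input's local category ids (0..n-1) onto merged ids,
     * shifting both sides by +1 to keep NULL_ORD free.
     */
    Map<Integer, Integer> remap(boolean seenNull, String[] localCategories) {
        Map<Integer, Integer> idMap = new HashMap<>();
        if (seenNull) {
            idMap.put(NULL_ORD, NULL_ORD);
        }
        for (int oldId = 0; oldId < localCategories.length; oldId++) {
            idMap.put(oldId + 1, merge(localCategories[oldId]) + 1);
        }
        return idMap;
    }
}
```

Calling `remap(false, new String[] {"connection timed out"})` for two different inputs maps both local ids to the same merged ordinal, which is what lets one hash fold the states from several upstream drivers into a single set of groups.
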
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
deleted file mode 100644
index 47dd7f650dffa..0000000000000
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.compute.aggregation.blockhash;
-
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
-import org.elasticsearch.compute.data.Block;
-import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.compute.data.BytesRefVector;
-import org.elasticsearch.compute.data.IntBlock;
-import org.elasticsearch.compute.data.IntVector;
-import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
-import org.elasticsearch.index.analysis.AnalysisRegistry;
-import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
-import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
-import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
-
-import java.io.IOException;
-import java.util.List;
-
-/**
- * BlockHash implementation for {@code Categorize} grouping function.
- * <p>
- * This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
- * </p>
- */
-public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
- private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
- List.of()
- );
-
- private final CategorizeEvaluator evaluator;
-
- CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) {
- super(blockFactory, channel, outputPartial);
-
- CategorizationAnalyzer analyzer;
- try {
- analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
- } catch (IOException e) {
- categorizer.close();
- throw new RuntimeException(e);
- }
-
- this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
- }
-
- @Override
- public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
- try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
- addInput.add(0, result);
- }
- }
-
- @Override
- public void close() {
- evaluator.close();
- }
-
- /**
- * Similar implementation to an Evaluator.
- */
- public final class CategorizeEvaluator implements Releasable {
- private final CategorizationAnalyzer analyzer;
-
- private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
-
- private final BlockFactory blockFactory;
-
- public CategorizeEvaluator(
- CategorizationAnalyzer analyzer,
- TokenListCategorizer.CloseableTokenListCategorizer categorizer,
- BlockFactory blockFactory
- ) {
- this.analyzer = analyzer;
- this.categorizer = categorizer;
- this.blockFactory = blockFactory;
- }
-
- public Block eval(BytesRefBlock vBlock) {
- BytesRefVector vVector = vBlock.asVector();
- if (vVector == null) {
- return eval(vBlock.getPositionCount(), vBlock);
- }
- IntVector vector = eval(vBlock.getPositionCount(), vVector);
- return vector.asBlock();
- }
-
- public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
- try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
- BytesRef vScratch = new BytesRef();
- for (int p = 0; p < positionCount; p++) {
- if (vBlock.isNull(p)) {
- seenNull = true;
- result.appendInt(NULL_ORD);
- continue;
- }
- int first = vBlock.getFirstValueIndex(p);
- int count = vBlock.getValueCount(p);
- if (count == 1) {
- result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
- continue;
- }
- int end = first + count;
- result.beginPositionEntry();
- for (int i = first; i < end; i++) {
- result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
- }
- result.endPositionEntry();
- }
- return result.build();
- }
- }
-
- public IntVector eval(int positionCount, BytesRefVector vVector) {
- try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
- BytesRef vScratch = new BytesRef();
- for (int p = 0; p < positionCount; p++) {
- result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
- }
- return result.build();
- }
- }
-
- private int process(BytesRef v) {
- var category = categorizer.computeCategory(v.utf8ToString(), analyzer);
- if (category == null) {
- seenNull = true;
- return NULL_ORD;
- }
- return category.getId() + 1;
- }
-
- @Override
- public void close() {
- Releasables.closeExpectNoException(analyzer, categorizer);
- }
- }
-}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java
deleted file mode 100644
index c774d3b26049d..0000000000000
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.compute.aggregation.blockhash;
-
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
-import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.compute.data.IntBlock;
-import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-
-/**
- * BlockHash implementation for {@code Categorize} grouping function.
- * <p>
- * This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}.
- * </p>
- */
-public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
-
- CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
- super(blockFactory, channel, outputPartial);
- }
-
- @Override
- public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
- if (page.getPositionCount() == 0) {
- // No categories
- return;
- }
- BytesRefBlock categorizerState = page.getBlock(channel());
- if (categorizerState.areAllValuesNull()) {
- seenNull = true;
- try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
- addInput.add(0, newIds);
- }
- return;
- }
-
- Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
- try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
- int fromId = idMap.containsKey(0) ? 0 : 1;
- int toId = fromId + idMap.size();
- for (int i = fromId; i < toId; i++) {
- newIdsBuilder.appendInt(idMap.get(i));
- }
- try (IntBlock newIds = newIdsBuilder.build()) {
- addInput.add(0, newIds);
- }
- }
- }
-
- /**
- * Read intermediate state from a block.
- *
- * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
- */
- private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
- Map<Integer, Integer> idMap = new HashMap<>();
- try (StreamInput in = new BytesArray(bytes).streamInput()) {
- if (in.readBoolean()) {
- seenNull = true;
- idMap.put(NULL_ORD, NULL_ORD);
- }
- int count = in.readVInt();
- for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
- int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
- // +1 because the 0 ordinal is reserved for null
- idMap.put(oldCategoryId + 1, newCategoryId + 1);
- }
- return idMap;
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public void close() {
- categorizer.close();
- }
-}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
index 8a3c723557151..3c47e85a4a9c8 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
@@ -95,7 +95,7 @@ public void testCategorizeRaw() {
page = new Page(builder.build());
}
- try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) {
+ try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
@@ -168,8 +168,8 @@ public void testCategorizeIntermediate() {
// Fill intermediatePages with the intermediate state from the raw hashes
try (
- BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
- BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
+ BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry);
+ BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry);
) {
rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
@Override
@@ -226,7 +226,7 @@ public void close() {
page2.releaseBlocks();
}
- try (BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true)) {
+ try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) {
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
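
The intermediate pages these tests shuttle between the raw and intermediate hashes carry the wire format written by `buildIntermediateBlock`: a boolean for whether null was seen, a vInt category count, then each serialized category. A rough round-trip sketch follows, using plain `java.io` streams as stand-ins for `BytesStreamOutput`/`StreamInput` and a UTF string per category (both are simplifications; the real format writes a variable-length vInt and `SerializableTokenListCategory` instances):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

final class ToyIntermediateState {
    static byte[] write(boolean seenNull, String[] categories) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeBoolean(seenNull);          // null presence travels with the state
            out.writeInt(categories.length);     // the real code writes a vInt
            for (String category : categories) {
                out.writeUTF(category);          // stand-in for category.writeTo(out)
            }
        }
        return bytes.toByteArray();
    }

    static void read(byte[] state) throws IOException {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(state))) {
            boolean seenNull = in.readBoolean();
            int count = in.readInt();
            System.out.println("seenNull=" + seenNull + ", categories=" + count);
            for (int i = 0; i < count; i++) {
                System.out.println("  merge: " + in.readUTF());
            }
        }
    }
}
```
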
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
index 31b603ecef889..63b5073c2217a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
@@ -32,12 +32,8 @@
* This function has no evaluators, as it works like an aggregation (Accumulates values, stores intermediate states, etc).
* </p>
* <p>
- * For the implementation, see:
+ * For the implementation, see {@link org.elasticsearch.compute.aggregation.blockhash.CategorizeBlockHash}
* </p>
- * <ul>
- * <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizedIntermediateBlockHash}</li>
- * <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizeRawBlockHash}</li>
- * </ul>
*/
public class Categorize extends GroupingFunction implements Validatable {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(