diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java deleted file mode 100644 index 0e89d77820883..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.aggregation.blockhash; - -import org.apache.lucene.util.BytesRefBuilder; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.BitArray; -import org.elasticsearch.common.util.BytesRefHash; -import org.elasticsearch.compute.aggregation.SeenGroupIds; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.BytesRefVector; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.core.ReleasableIterator; -import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash; -import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary; -import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory; -import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; - -import java.io.IOException; - -/** - * Base BlockHash implementation for {@code Categorize} grouping function. - */ -public abstract class AbstractCategorizeBlockHash extends BlockHash { - protected static final int NULL_ORD = 0; - - // TODO: this should probably also take an emitBatchSize - private final int channel; - private final boolean outputPartial; - protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer; - - /** - * Store whether we've seen any {@code null} values. - *

- * Null gets the {@link #NULL_ORD} ord. - *

- */ - protected boolean seenNull = false; - - AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) { - super(blockFactory); - this.channel = channel; - this.outputPartial = outputPartial; - this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer( - new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())), - CategorizationPartOfSpeechDictionary.getInstance(), - 0.70f - ); - } - - protected int channel() { - return channel; - } - - @Override - public Block[] getKeys() { - return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() }; - } - - @Override - public IntVector nonEmpty() { - return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory); - } - - @Override - public BitArray seenGroupIds(BigArrays bigArrays) { - return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays); - } - - @Override - public final ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { - throw new UnsupportedOperationException(); - } - - /** - * Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories. - */ - private Block buildIntermediateBlock() { - if (categorizer.getCategoryCount() == 0) { - return blockFactory.newConstantNullBlock(seenNull ? 1 : 0); - } - try (BytesStreamOutput out = new BytesStreamOutput()) { - // TODO be more careful here. - out.writeBoolean(seenNull); - out.writeVInt(categorizer.getCategoryCount()); - for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - category.writeTo(out); - } - // We're returning a block with N positions just because the Page must have all blocks with the same position count! - int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0); - return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private Block buildFinalBlock() { - BytesRefBuilder scratch = new BytesRefBuilder(); - - if (seenNull) { - try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) { - result.appendNull(); - for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - scratch.copyChars(category.getRegex()); - result.appendBytesRef(scratch.get()); - scratch.clear(); - } - return result.build(); - } - } - - try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) { - for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { - scratch.copyChars(category.getRegex()); - result.appendBytesRef(scratch.get()); - scratch.clear(); - } - return result.build().asBlock(); - } - } -} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index ea76c3bd0a0aa..30afa7ae3128d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -180,9 +180,7 @@ public static BlockHash buildCategorizeBlockHash( throw new IllegalArgumentException("only a single CATEGORIZE group can used"); } - return aggregatorMode.isInputPartial() - ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial()) - : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry); + return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry); } /** diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java new file mode 100644 index 0000000000000..35c6faf84e623 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java @@ -0,0 +1,309 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation.blockhash; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.SeenGroupIds; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.ReleasableIterator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary; +import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory; +import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; +import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Base BlockHash implementation for {@code Categorize} grouping function. + */ +public class CategorizeBlockHash extends BlockHash { + + private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer( + List.of() + ); + private static final int NULL_ORD = 0; + + // TODO: this should probably also take an emitBatchSize + private final int channel; + private final AggregatorMode aggregatorMode; + private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; + + private final CategorizeEvaluator evaluator; + + /** + * Store whether we've seen any {@code null} values. + *

+ * Null gets the {@link #NULL_ORD} ord. + *

+ */ + private boolean seenNull = false; + + CategorizeBlockHash(BlockFactory blockFactory, int channel, AggregatorMode aggregatorMode, AnalysisRegistry analysisRegistry) { + super(blockFactory); + + this.channel = channel; + this.aggregatorMode = aggregatorMode; + + this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer( + new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())), + CategorizationPartOfSpeechDictionary.getInstance(), + 0.70f + ); + + if (aggregatorMode.isInputPartial() == false) { + CategorizationAnalyzer analyzer; + try { + Objects.requireNonNull(analysisRegistry); + analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG); + } catch (Exception e) { + categorizer.close(); + throw new RuntimeException(e); + } + this.evaluator = new CategorizeEvaluator(analyzer); + } else { + this.evaluator = null; + } + } + + @Override + public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { + if (aggregatorMode.isInputPartial() == false) { + addInitial(page, addInput); + } else { + addIntermediate(page, addInput); + } + } + + @Override + public Block[] getKeys() { + return new Block[] { aggregatorMode.isOutputPartial() ? buildIntermediateBlock() : buildFinalBlock() }; + } + + @Override + public IntVector nonEmpty() { + return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory); + } + + @Override + public BitArray seenGroupIds(BigArrays bigArrays) { + return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays); + } + + @Override + public final ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + Releasables.close(evaluator, categorizer); + } + + /** + * Adds initial (raw) input to the state. + */ + private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) { + try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) { + addInput.add(0, result); + } + } + + /** + * Adds intermediate state to the state. + */ + private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) { + if (page.getPositionCount() == 0) { + return; + } + BytesRefBlock categorizerState = page.getBlock(channel); + if (categorizerState.areAllValuesNull()) { + seenNull = true; + try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) { + addInput.add(0, newIds); + } + return; + } + + Map idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef())); + try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) { + int fromId = idMap.containsKey(0) ? 0 : 1; + int toId = fromId + idMap.size(); + for (int i = fromId; i < toId; i++) { + newIdsBuilder.appendInt(idMap.get(i)); + } + try (IntBlock newIds = newIdsBuilder.build()) { + addInput.add(0, newIds); + } + } + } + + /** + * Read intermediate state from a block. + * + * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}. + */ + private Map readIntermediate(BytesRef bytes) { + Map idMap = new HashMap<>(); + try (StreamInput in = new BytesArray(bytes).streamInput()) { + if (in.readBoolean()) { + seenNull = true; + idMap.put(NULL_ORD, NULL_ORD); + } + int count = in.readVInt(); + for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) { + int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId(); + // +1 because the 0 ordinal is reserved for null + idMap.put(oldCategoryId + 1, newCategoryId + 1); + } + return idMap; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories. + */ + private Block buildIntermediateBlock() { + if (categorizer.getCategoryCount() == 0) { + return blockFactory.newConstantNullBlock(seenNull ? 1 : 0); + } + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.writeBoolean(seenNull); + out.writeVInt(categorizer.getCategoryCount()); + for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { + category.writeTo(out); + } + // We're returning a block with N positions just because the Page must have all blocks with the same position count! + int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0); + return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private Block buildFinalBlock() { + BytesRefBuilder scratch = new BytesRefBuilder(); + + if (seenNull) { + try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) { + result.appendNull(); + for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { + scratch.copyChars(category.getRegex()); + result.appendBytesRef(scratch.get()); + scratch.clear(); + } + return result.build(); + } + } + + try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) { + for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { + scratch.copyChars(category.getRegex()); + result.appendBytesRef(scratch.get()); + scratch.clear(); + } + return result.build().asBlock(); + } + } + + /** + * Similar implementation to an Evaluator. + */ + private final class CategorizeEvaluator implements Releasable { + private final CategorizationAnalyzer analyzer; + + CategorizeEvaluator(CategorizationAnalyzer analyzer) { + this.analyzer = analyzer; + } + + Block eval(BytesRefBlock vBlock) { + BytesRefVector vVector = vBlock.asVector(); + if (vVector == null) { + return eval(vBlock.getPositionCount(), vBlock); + } + IntVector vector = eval(vBlock.getPositionCount(), vVector); + return vector.asBlock(); + } + + IntBlock eval(int positionCount, BytesRefBlock vBlock) { + try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) { + BytesRef vScratch = new BytesRef(); + for (int p = 0; p < positionCount; p++) { + if (vBlock.isNull(p)) { + seenNull = true; + result.appendInt(NULL_ORD); + continue; + } + int first = vBlock.getFirstValueIndex(p); + int count = vBlock.getValueCount(p); + if (count == 1) { + result.appendInt(process(vBlock.getBytesRef(first, vScratch))); + continue; + } + int end = first + count; + result.beginPositionEntry(); + for (int i = first; i < end; i++) { + result.appendInt(process(vBlock.getBytesRef(i, vScratch))); + } + result.endPositionEntry(); + } + return result.build(); + } + } + + IntVector eval(int positionCount, BytesRefVector vVector) { + try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) { + BytesRef vScratch = new BytesRef(); + for (int p = 0; p < positionCount; p++) { + result.appendInt(p, process(vVector.getBytesRef(p, vScratch))); + } + return result.build(); + } + } + + int process(BytesRef v) { + var category = categorizer.computeCategory(v.utf8ToString(), analyzer); + if (category == null) { + seenNull = true; + return NULL_ORD; + } + return category.getId() + 1; + } + + @Override + public void close() { + analyzer.close(); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java deleted file mode 100644 index 47dd7f650dffa..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.aggregation.blockhash; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.BytesRefVector; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.index.analysis.AnalysisRegistry; -import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; -import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; -import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; - -import java.io.IOException; -import java.util.List; - -/** - * BlockHash implementation for {@code Categorize} grouping function. - *

- * This implementation expects rows, and can't deserialize intermediate states coming from other nodes. - *

- */ -public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash { - private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer( - List.of() - ); - - private final CategorizeEvaluator evaluator; - - CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) { - super(blockFactory, channel, outputPartial); - - CategorizationAnalyzer analyzer; - try { - analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG); - } catch (IOException e) { - categorizer.close(); - throw new RuntimeException(e); - } - - this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory); - } - - @Override - public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { - try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) { - addInput.add(0, result); - } - } - - @Override - public void close() { - evaluator.close(); - } - - /** - * Similar implementation to an Evaluator. - */ - public final class CategorizeEvaluator implements Releasable { - private final CategorizationAnalyzer analyzer; - - private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; - - private final BlockFactory blockFactory; - - public CategorizeEvaluator( - CategorizationAnalyzer analyzer, - TokenListCategorizer.CloseableTokenListCategorizer categorizer, - BlockFactory blockFactory - ) { - this.analyzer = analyzer; - this.categorizer = categorizer; - this.blockFactory = blockFactory; - } - - public Block eval(BytesRefBlock vBlock) { - BytesRefVector vVector = vBlock.asVector(); - if (vVector == null) { - return eval(vBlock.getPositionCount(), vBlock); - } - IntVector vector = eval(vBlock.getPositionCount(), vVector); - return vector.asBlock(); - } - - public IntBlock eval(int positionCount, BytesRefBlock vBlock) { - try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) { - BytesRef vScratch = new BytesRef(); - for (int p = 0; p < positionCount; p++) { - if (vBlock.isNull(p)) { - seenNull = true; - result.appendInt(NULL_ORD); - continue; - } - int first = vBlock.getFirstValueIndex(p); - int count = vBlock.getValueCount(p); - if (count == 1) { - result.appendInt(process(vBlock.getBytesRef(first, vScratch))); - continue; - } - int end = first + count; - result.beginPositionEntry(); - for (int i = first; i < end; i++) { - result.appendInt(process(vBlock.getBytesRef(i, vScratch))); - } - result.endPositionEntry(); - } - return result.build(); - } - } - - public IntVector eval(int positionCount, BytesRefVector vVector) { - try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) { - BytesRef vScratch = new BytesRef(); - for (int p = 0; p < positionCount; p++) { - result.appendInt(p, process(vVector.getBytesRef(p, vScratch))); - } - return result.build(); - } - } - - private int process(BytesRef v) { - var category = categorizer.computeCategory(v.utf8ToString(), analyzer); - if (category == null) { - seenNull = true; - return NULL_ORD; - } - return category.getId() + 1; - } - - @Override - public void close() { - Releasables.closeExpectNoException(analyzer, categorizer); - } - } -} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java deleted file mode 100644 index c774d3b26049d..0000000000000 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.compute.aggregation.blockhash; - -import org.apache.lucene.util.BytesRef; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BytesRefBlock; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -/** - * BlockHash implementation for {@code Categorize} grouping function. - *

- * This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}. - *

- */ -public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash { - - CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) { - super(blockFactory, channel, outputPartial); - } - - @Override - public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { - if (page.getPositionCount() == 0) { - // No categories - return; - } - BytesRefBlock categorizerState = page.getBlock(channel()); - if (categorizerState.areAllValuesNull()) { - seenNull = true; - try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) { - addInput.add(0, newIds); - } - return; - } - - Map idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef())); - try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) { - int fromId = idMap.containsKey(0) ? 0 : 1; - int toId = fromId + idMap.size(); - for (int i = fromId; i < toId; i++) { - newIdsBuilder.appendInt(idMap.get(i)); - } - try (IntBlock newIds = newIdsBuilder.build()) { - addInput.add(0, newIds); - } - } - } - - /** - * Read intermediate state from a block. - * - * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}. - */ - private Map readIntermediate(BytesRef bytes) { - Map idMap = new HashMap<>(); - try (StreamInput in = new BytesArray(bytes).streamInput()) { - if (in.readBoolean()) { - seenNull = true; - idMap.put(NULL_ORD, NULL_ORD); - } - int count = in.readVInt(); - for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) { - int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId(); - // +1 because the 0 ordinal is reserved for null - idMap.put(oldCategoryId + 1, newCategoryId + 1); - } - return idMap; - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public void close() { - categorizer.close(); - } -} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index 8a3c723557151..3c47e85a4a9c8 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -95,7 +95,7 @@ public void testCategorizeRaw() { page = new Page(builder.build()); } - try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) { + try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) { hash.add(page, new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { @@ -168,8 +168,8 @@ public void testCategorizeIntermediate() { // Fill intermediatePages with the intermediate state from the raw hashes try ( - BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry); - BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry); + BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); + BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry); ) { rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() { @Override @@ -226,7 +226,7 @@ public void close() { page2.releaseBlocks(); } - try (BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true)) { + try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) { intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() { @Override public void add(int positionOffset, IntBlock groupIds) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 31b603ecef889..63b5073c2217a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -32,12 +32,8 @@ * This function has no evaluators, as it works like an aggregation (Accumulates values, stores intermediate states, etc). *

*

- * For the implementation, see: + * For the implementation, see {@link org.elasticsearch.compute.aggregation.blockhash.CategorizeBlockHash} *

- *
    - *
  • {@link org.elasticsearch.compute.aggregation.blockhash.CategorizedIntermediateBlockHash}
  • - *
  • {@link org.elasticsearch.compute.aggregation.blockhash.CategorizeRawBlockHash}
  • - *
*/ public class Categorize extends GroupingFunction implements Validatable { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(