5 changes: 5 additions & 0 deletions docs/changelog/114317.yaml
@@ -0,0 +1,5 @@
pr: 114317
summary: "ESQL: CATEGORIZE as a `BlockHash`"
area: ES|QL
type: enhancement
issues: []


4 changes: 2 additions & 2 deletions docs/reference/esql/functions/types/categorize.asciidoc

(Generated file; diff not rendered.)

15 changes: 0 additions & 15 deletions muted-tests.yml
@@ -193,12 +193,6 @@ tests:
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
method: test {p0=indices.split/40_routing_partition_size/more than 1}
issue: https://github.com/elastic/elasticsearch/issues/113841
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
method: test {categorize.Categorize SYNC}
issue: https://github.com/elastic/elasticsearch/issues/113722
- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
method: test {categorize.Categorize ASYNC}
issue: https://github.com/elastic/elasticsearch/issues/116373
- class: org.elasticsearch.kibana.KibanaThreadPoolIT
method: testBlockedThreadPoolsRejectUserRequests
issue: https://github.com/elastic/elasticsearch/issues/113939
@@ -254,12 +248,6 @@ tests:
- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
method: test {p0=search/380_sort_segments_on_timestamp/Test that index segments are NOT sorted on timestamp field when @timestamp field is dynamically added}
issue: https://github.com/elastic/elasticsearch/issues/116221
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
method: test {categorize.Categorize SYNC}
issue: https://github.com/elastic/elasticsearch/issues/113054
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
method: test {categorize.Categorize ASYNC}
issue: https://github.com/elastic/elasticsearch/issues/113054
- class: org.elasticsearch.ingest.common.IngestCommonClientYamlTestSuiteIT
method: test {yaml=ingest/310_reroute_processor/Test remove then add reroute processor with and without lazy rollover}
issue: https://github.com/elastic/elasticsearch/issues/116158
@@ -272,9 +260,6 @@ tests:
- class: org.elasticsearch.xpack.deprecation.DeprecationHttpIT
method: testDeprecatedSettingsReturnWarnings
issue: https://github.com/elastic/elasticsearch/issues/108628
- class: org.elasticsearch.xpack.esql.ccq.MultiClusterSpecIT
method: test {categorize.Categorize}
issue: https://github.com/elastic/elasticsearch/issues/116434
- class: org.elasticsearch.xpack.apmdata.APMYamlTestSuiteIT
method: test {yaml=/10_apm/Test template reinstallation}
issue: https://github.com/elastic/elasticsearch/issues/116445
@@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRefBuilder;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;

import java.io.IOException;

/**
 * Base BlockHash implementation for the {@code Categorize} grouping function.
 */
public abstract class AbstractCategorizeBlockHash extends BlockHash {
// TODO: this should probably also take an emitBatchSize
    /** Input channel containing the text values to categorize. */
    private final int channel;
    /** Whether to emit intermediate (serialized) state instead of final category keys. */
    private final boolean outputPartial;
protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
super(blockFactory);
this.channel = channel;
this.outputPartial = outputPartial;
this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
CategorizationPartOfSpeechDictionary.getInstance(),
0.70f
);
}

protected int channel() {
return channel;
}

@Override
public Block[] getKeys() {
return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
}

@Override
public IntVector nonEmpty() {
        // Category ids are dense, so every id in [0, categoryCount) is a non-empty group.
        return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
}

@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
throw new UnsupportedOperationException();
}

@Override
public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
throw new UnsupportedOperationException();
}

/**
* Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
*/
private Block buildIntermediateBlock() {
if (categorizer.getCategoryCount() == 0) {
return blockFactory.newConstantNullBlock(0);
}
try (BytesStreamOutput out = new BytesStreamOutput()) {
// TODO be more careful here.
out.writeVInt(categorizer.getCategoryCount());
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
category.writeTo(out);
}
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
} catch (IOException e) {
throw new RuntimeException(e);
}
}

    private Block buildFinalBlock() {
        // The final key for each category is the regex generated for it, rendered as UTF-8.
try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
BytesRefBuilder scratch = new BytesRefBuilder();
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
scratch.copyChars(category.getRegex());
result.appendBytesRef(scratch.get());
scratch.clear();
}
return result.build().asBlock();
}
}
}
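
The intermediate state written by buildIntermediateBlock() above is consumed by the sibling CategorizedIntermediateBlockHash, which buildCategorizeBlockHash() below references but this diff doesn't show. A minimal sketch of the read side, assuming SerializableTokenListCategory has a StreamInput constructor mirroring writeTo(out) and that the categorizer exposes a mergeWireCategory(...) hook for remote categories (both assumptions here, not confirmed by this diff):

    // Hypothetical read-side counterpart of buildIntermediateBlock().
    // Assumes: new SerializableTokenListCategory(StreamInput) mirrors category.writeTo(out),
    // and categorizer.mergeWireCategory(...) folds a remote category into local state.
    private void readIntermediate(BytesRef bytes) {
        try (StreamInput in = new BytesArray(bytes).streamInput()) {
            int categoryCount = in.readVInt(); // mirrors out.writeVInt(categorizer.getCategoryCount())
            for (int i = 0; i < categoryCount; i++) {
                categorizer.mergeWireCategory(new SerializableTokenListCategory(in));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }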
@@ -14,6 +14,7 @@
import org.elasticsearch.common.util.Int3Hash;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.common.util.LongLongHash;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.aggregation.SeenGroupIds;
import org.elasticsearch.compute.data.Block;
@@ -58,9 +59,7 @@
* leave a big gap, even if we never see {@code null}.
* </p>
*/
public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
public abstract class BlockHash implements Releasable, SeenGroupIds {

protected final BlockFactory blockFactory;

@@ -107,7 +106,15 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
@Override
public abstract BitArray seenGroupIds(BigArrays bigArrays);

public record GroupSpec(int channel, ElementType elementType) {}
    /**
     * @param isCategorize Whether this group is produced by a CATEGORIZE() call.
     *                     This flag may be generalized when more stateful grouping functions are added.
     */
public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
public GroupSpec(int channel, ElementType elementType) {
this(channel, elementType, false);
}
}

/**
* Creates a specialized hash table that maps one or more {@link Block}s to ids.
@@ -159,6 +166,19 @@ public static BlockHash buildPackedValuesBlockHash(List<GroupSpec> groups, Block
return new PackedValuesBlockHash(groups, blockFactory, emitBatchSize);
}

/**
* Builds a BlockHash for the Categorize grouping function.
*/
public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
if (groups.size() != 1) {
            throw new IllegalArgumentException("only a single CATEGORIZE group can be used");
}

        return aggregatorMode.isInputPartial()
            ? new CategorizedIntermediateBlockHash(groups.get(0).channel(), blockFactory, aggregatorMode.isOutputPartial())
            : new CategorizeRawBlockHash(groups.get(0).channel(), blockFactory, aggregatorMode.isOutputPartial());
}

/**
* Creates a specialized hash table that maps a {@link Block} of the given input element type to ids.
*/
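
To illustrate how the pieces above fit together: the new boolean on GroupSpec marks a group as CATEGORIZE, and the AggregatorMode decides which concrete hash buildCategorizeBlockHash() returns (INITIAL reads raw rows, FINAL reads partial state). A sketch, where the channel number, element type, and blockFactory are illustrative names rather than anything from this diff:

    // Illustrative only: which hash each AggregatorMode selects.
    static void categorizeHashSelection(BlockFactory blockFactory) {
        GroupSpec ordinary = new GroupSpec(0, ElementType.BYTES_REF);         // isCategorize == false
        GroupSpec categorize = new GroupSpec(0, ElementType.BYTES_REF, true); // flag set for CATEGORIZE

        // INITIAL: raw input, partial output -> CategorizeRawBlockHash.
        BlockHash raw = BlockHash.buildCategorizeBlockHash(List.of(categorize), AggregatorMode.INITIAL, blockFactory);

        // FINAL: partial input, final output -> CategorizedIntermediateBlockHash.
        BlockHash finalHash = BlockHash.buildCategorizeBlockHash(List.of(categorize), AggregatorMode.FINAL, blockFactory);

        Releasables.close(raw, finalHash);
    }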
@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.BytesRefVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

/**
 * BlockHash implementation for the {@code Categorize} grouping function.
 * <p>
 * This implementation expects raw rows and can't deserialize intermediate states coming from other nodes.
 * </p>
 */
public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
private final CategorizeEvaluator evaluator;

CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
super(blockFactory, channel, outputPartial);
CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
// TODO: should be the same analyzer as used in Production
new CustomAnalyzer(
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
new CharFilterFactory[0],
new TokenFilterFactory[0]
),
true
);
this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
}

@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        // Map every value on the input channel to a category id; those ids become the group ids.
        try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
addInput.add(0, result);
}
}

@Override
public void close() {
evaluator.close();
}

    /**
     * Maps text values to category ids, implemented in the style of a generated Evaluator.
     */
public static final class CategorizeEvaluator implements Releasable {
private final CategorizationAnalyzer analyzer;

private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;

private final BlockFactory blockFactory;

public CategorizeEvaluator(
CategorizationAnalyzer analyzer,
TokenListCategorizer.CloseableTokenListCategorizer categorizer,
BlockFactory blockFactory
) {
this.analyzer = analyzer;
this.categorizer = categorizer;
this.blockFactory = blockFactory;
}

public Block eval(BytesRefBlock vBlock) {
BytesRefVector vVector = vBlock.asVector();
if (vVector == null) {
return eval(vBlock.getPositionCount(), vBlock);
}
IntVector vector = eval(vBlock.getPositionCount(), vVector);
return vector.asBlock();
}

public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
BytesRef vScratch = new BytesRef();
for (int p = 0; p < positionCount; p++) {
if (vBlock.isNull(p)) {
result.appendNull();
continue;
}
int first = vBlock.getFirstValueIndex(p);
int count = vBlock.getValueCount(p);
if (count == 1) {
result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
continue;
}
int end = first + count;
result.beginPositionEntry();
for (int i = first; i < end; i++) {
result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
}
result.endPositionEntry();
}
return result.build();
}
}

public IntVector eval(int positionCount, BytesRefVector vVector) {
try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
BytesRef vScratch = new BytesRef();
for (int p = 0; p < positionCount; p++) {
result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
}
return result.build();
}
}

        /** Categorizes a single value and returns the id of its category. */
        private int process(BytesRef v) {
return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
}

@Override
public void close() {
Releasables.closeExpectNoException(analyzer, categorizer);
}
}
}
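
For flavor, here is the evaluator's block-level behavior, exercising the null and multivalue handling of eval(int, BytesRefBlock) above. Everything besides CategorizeEvaluator itself (the messages, the blockFactory, and an already-built evaluator) is assumed for illustration:

    // Illustrative only: run a few log messages through an existing evaluator.
    static void exampleCategorization(BlockFactory blockFactory, CategorizeEvaluator evaluator) {
        try (BytesRefBlock.Builder messages = blockFactory.newBytesRefBlockBuilder(3)) {
            messages.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
            messages.appendNull();           // null rows stay null in the output
            messages.beginPositionEntry();   // multivalued rows become multivalued category entries
            messages.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
            messages.appendBytesRef(new BytesRef("Connection error"));
            messages.endPositionEntry();
            try (BytesRefBlock block = messages.build(); IntBlock categoryIds = (IntBlock) evaluator.eval(block)) {
                // Messages with the same token structure ("Connected to <token>") share one category id.
            }
        }
    }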