4 changes: 3 additions & 1 deletion x-pack/plugin/esql/compute/build.gradle
@@ -11,11 +11,13 @@ base {
dependencies {
compileOnly project(':server')
compileOnly project('ann')
compileOnly project(xpackModule('core'))
compileOnly project(xpackModule('ml'))
annotationProcessor project('gen')
implementation 'com.carrotsearch:hppc:0.8.1'

testImplementation project(':test:framework')
testImplementation(project(':modules:analysis-common'))
testImplementation(project(':test:framework'))
testImplementation(project(xpackModule('esql-core')))
testImplementation(project(xpackModule('core')))
testImplementation(project(xpackModule('ml')))
1 change: 1 addition & 0 deletions x-pack/plugin/esql/compute/src/main/java/module-info.java
@@ -19,6 +19,7 @@
requires org.elasticsearch.ml;
requires org.elasticsearch.tdigest;
requires org.elasticsearch.geo;
requires org.elasticsearch.xcore;
requires hppc;

exports org.elasticsearch.compute;
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
@@ -25,6 +25,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.index.analysis.AnalysisRegistry;

import java.util.Iterator;
import java.util.List;
@@ -169,14 +170,19 @@ public static BlockHash buildPackedValuesBlockHash(List<GroupSpec> groups, Block
/**
* Builds a BlockHash for the Categorize grouping function.
*/
public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
public static BlockHash buildCategorizeBlockHash(
List<GroupSpec> groups,
AggregatorMode aggregatorMode,
BlockFactory blockFactory,
AnalysisRegistry analysisRegistry
) {
if (groups.size() != 1) {
throw new IllegalArgumentException("only a single CATEGORIZE group can be used");
}

return aggregatorMode.isInputPartial()
? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
: new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
: new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial(), analysisRegistry);
}

/**
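
For orientation, this is roughly what a call site looks like once the registry is threaded through. A minimal sketch; the helper name, the BYTES_REF group, and the in-scope blockFactory/analysisRegistry are illustrative assumptions, not code from this PR:

    // Hypothetical helper: build a categorize hash for a single BYTES_REF group on channel 0.
    static BlockHash categorizeHashFor(BlockFactory blockFactory, AnalysisRegistry analysisRegistry) {
        List<BlockHash.GroupSpec> groups = List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
        // INITIAL mode consumes raw input, so this resolves to a CategorizeRawBlockHash.
        return BlockHash.buildCategorizeBlockHash(groups, AggregatorMode.INITIAL, blockFactory, analysisRegistry);
    }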
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
@@ -7,7 +7,6 @@

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
@@ -19,33 +18,38 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.CustomAnalyzer;
import org.elasticsearch.index.analysis.TokenFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

import java.io.IOException;
import java.util.List;

/**
* BlockHash implementation for {@code Categorize} grouping function.
* <p>
* This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
* </p>
*/
public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
private static final CategorizationAnalyzerConfig ANALYZER_CONFIG = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(
List.of()
);

private final CategorizeEvaluator evaluator;

CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial, AnalysisRegistry analysisRegistry) {
super(blockFactory, channel, outputPartial);
CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
// TODO: should be the same analyzer as used in Production
new CustomAnalyzer(
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
new CharFilterFactory[0],
new TokenFilterFactory[0]
),
true
);

CategorizationAnalyzer analyzer;
try {
analyzer = new CategorizationAnalyzer(analysisRegistry, ANALYZER_CONFIG);
} catch (IOException e) {
categorizer.close();
throw new RuntimeException(e);
}

this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
}

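
The substance of this file's change: the analyzer is no longer a hard-coded whitespace CustomAnalyzer but is resolved from the AnalysisRegistry using the standard ML categorization configuration. A minimal sketch of that construction, assuming CategorizationAnalyzer is Closeable and exposes a tokenizeField helper (treat both as assumptions here):

    CategorizationAnalyzerConfig config = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(List.of());
    try (CategorizationAnalyzer analyzer = new CategorizationAnalyzer(analysisRegistry, config)) {
        // The ml_standard tokenizer drops number-like tokens (including hex ones),
        // which is why the tests below categorize "aaazz" rather than "a".
        List<String> tokens = analyzer.tokenizeField("message", "words words words goodbye jan");
    } catch (IOException e) {
        throw new UncheckedIOException(e);
    }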
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
@@ -24,6 +24,7 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
@@ -42,14 +43,15 @@ public record HashAggregationOperatorFactory(
List<BlockHash.GroupSpec> groups,
AggregatorMode aggregatorMode,
List<GroupingAggregator.Factory> aggregators,
int maxPageSize
int maxPageSize,
AnalysisRegistry analysisRegistry
) implements OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
return new HashAggregationOperator(
aggregators,
() -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
() -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
driverContext
);
}
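
HashAggregationOperatorFactory grows a trailing analysisRegistry component, consumed only when a group is a CATEGORIZE group; other callers may pass null, as the updated tests below do. A hedged sketch of constructing the factory (channels and aggregators are illustrative):

    Operator.OperatorFactory factory = new HashAggregationOperator.HashAggregationOperatorFactory(
        List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF)),
        AggregatorMode.INITIAL,
        List.of(new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
        16 * 1024, // max page size
        analysisRegistry // null is acceptable when no CATEGORIZE group is present
    );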
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
@@ -54,7 +54,6 @@
import static org.elasticsearch.compute.data.BlockTestUtils.append;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.in;

/**
* Shared tests for testing grouped aggregations.
@@ -107,7 +106,8 @@ private Operator.OperatorFactory simpleWithMode(
List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
mode,
List.of(supplier.groupingAggregatorFactory(mode)),
randomPageSize()
randomPageSize(),
null
);
}

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
@@ -8,8 +8,10 @@
package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockBigArrays;
@@ -35,7 +37,15 @@
import org.elasticsearch.compute.operator.LocalSourceOperator;
import org.elasticsearch.compute.operator.PageConsumerOperator;
import org.elasticsearch.core.Releasables;

import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.indices.analysis.AnalysisModule;
import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -50,6 +60,19 @@

public class CategorizeBlockHashTests extends BlockHashTestCase {

private AnalysisRegistry analysisRegistry;

@Before
private void initAnalysisRegistry() throws IOException {
analysisRegistry = new AnalysisModule(
TestEnvironment.newEnvironment(
Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
),
List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()),
new StablePluginsRegistry()
).getAnalysisRegistry();
}

public void testCategorizeRaw() {
final Page page;
boolean withNull = randomBoolean();
@@ -72,7 +95,7 @@ public void testCategorizeRaw() {
page = new Page(builder.build());
}

try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry)) {
hash.add(page, new GroupingAggregatorFunction.AddInput() {
@Override
public void add(int positionOffset, IntBlock groupIds) {
@@ -145,8 +168,8 @@ public void testCategorizeIntermediate() {

// Fill intermediatePages with the intermediate state from the raw hashes
try (
BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, analysisRegistry);
) {
rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
@Override
@@ -267,14 +290,16 @@ public void testCategorize_withDriver() {
BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
) {
textsBuilder.appendBytesRef(new BytesRef("a"));
textsBuilder.appendBytesRef(new BytesRef("b"));
// Note that just using "a" or "aaa" doesn't work, because the ml_standard
// tokenizer drops numbers, including hexadecimal ones.
textsBuilder.appendBytesRef(new BytesRef("aaazz"));
textsBuilder.appendBytesRef(new BytesRef("bbbzz"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
textsBuilder.appendBytesRef(new BytesRef("c"));
textsBuilder.appendBytesRef(new BytesRef("d"));
textsBuilder.appendBytesRef(new BytesRef("ccczz"));
textsBuilder.appendBytesRef(new BytesRef("dddzz"));
countsBuilder.appendLong(1);
countsBuilder.appendLong(2);
countsBuilder.appendLong(800);
@@ -293,10 +318,10 @@
) {
textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
textsBuilder.appendBytesRef(new BytesRef("c"));
textsBuilder.appendBytesRef(new BytesRef("ccczz"));
textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
textsBuilder.appendBytesRef(new BytesRef("d"));
textsBuilder.appendBytesRef(new BytesRef("e"));
textsBuilder.appendBytesRef(new BytesRef("dddzz"));
textsBuilder.appendBytesRef(new BytesRef("eeezz"));
countsBuilder.appendLong(9);
countsBuilder.appendLong(90);
countsBuilder.appendLong(3);
@@ -320,7 +345,8 @@
new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
),
16 * 1024
16 * 1024,
analysisRegistry
).get(driverContext)
),
new PageConsumerOperator(intermediateOutput::add),
@@ -339,7 +365,8 @@
new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
),
16 * 1024
16 * 1024,
analysisRegistry
).get(driverContext)
),
new PageConsumerOperator(intermediateOutput::add),
@@ -360,7 +387,8 @@
new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
),
16 * 1024
16 * 1024,
analysisRegistry
).get(driverContext)
),
new PageConsumerOperator(finalOutput::add),
@@ -385,15 +413,15 @@
sums,
equalTo(
Map.of(
".*?a.*?",
".*?aaazz.*?",
1L,
".*?b.*?",
".*?bbbzz.*?",
2L,
".*?c.*?",
".*?ccczz.*?",
33L,
".*?d.*?",
".*?dddzz.*?",
44L,
".*?e.*?",
".*?eeezz.*?",
5L,
".*?words.+?words.+?words.+?goodbye.*?",
8888L,
@@ -406,15 +434,15 @@
maxs,
equalTo(
Map.of(
".*?a.*?",
".*?aaazz.*?",
1L,
".*?b.*?",
".*?bbbzz.*?",
2L,
".*?c.*?",
".*?ccczz.*?",
30L,
".*?d.*?",
".*?dddzz.*?",
40L,
".*?e.*?",
".*?eeezz.*?",
5L,
".*?words.+?words.+?words.+?goodbye.*?",
8000L,
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
@@ -59,7 +59,8 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
),
randomPageSize()
randomPageSize(),
null
);
}

Expand Down