diff --git a/docs/changelog/112491.yaml b/docs/changelog/112491.yaml
new file mode 100644
index 0000000000000..ee77a2cee79bd
--- /dev/null
+++ b/docs/changelog/112491.yaml
@@ -0,0 +1,5 @@
+pr: 112491
+summary: "[DON'T MERGE] ES|QL Categorize Text proof of concept"
+area: "ES|QL"
+type: feature
+issues: []
diff --git a/x-pack/plugin/esql/build.gradle b/x-pack/plugin/esql/build.gradle
index 26cf53b334b1e..0225664918b7b 100644
--- a/x-pack/plugin/esql/build.gradle
+++ b/x-pack/plugin/esql/build.gradle
@@ -11,7 +11,7 @@ esplugin {
   name 'x-pack-esql'
   description 'The plugin that powers ESQL for Elasticsearch'
   classname 'org.elasticsearch.xpack.esql.plugin.EsqlPlugin'
-  extendedPlugins = ['x-pack-esql-core', 'lang-painless']
+  extendedPlugins = ['x-pack-esql-core', 'lang-painless', 'x-pack-ml']
 }
 
 base {
@@ -22,6 +22,7 @@ dependencies {
   compileOnly project(path: xpackModule('core'))
   compileOnly project(':modules:lang-painless:spi')
   compileOnly project(xpackModule('esql-core'))
+  compileOnly project(xpackModule('ml'))
   implementation project('compute')
   implementation project('compute:ann')
   implementation project(':libs:elasticsearch-dissect')
diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle
index 971bfd39c231f..81d1a6f5360ca 100644
--- a/x-pack/plugin/esql/compute/build.gradle
+++ b/x-pack/plugin/esql/compute/build.gradle
@@ -11,11 +11,14 @@ base {
 
 dependencies {
   compileOnly project(':server')
   compileOnly project('ann')
+  compileOnly project(xpackModule('ml'))
   annotationProcessor project('gen')
   implementation 'com.carrotsearch:hppc:0.8.1'
 
   testImplementation project(':test:framework')
   testImplementation(project(xpackModule('esql-core')))
+  testImplementation(project(xpackModule('core')))
+  testImplementation(project(xpackModule('ml')))
 }
 
 def projectDirectory = project.layout.projectDirectory
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunction.java
new file mode 100644
index 0000000000000..8d3cd8f429f79
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunction.java
@@ -0,0 +1,124 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunction} implementation for {@link CategorizeBytesRefAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class CategorizeBytesRefAggregatorFunction implements AggregatorFunction {
+  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
+      new IntermediateStateDesc("categorize", ElementType.BYTES_REF)  );
+
+  private final DriverContext driverContext;
+
+  private final CategorizeBytesRefAggregator.SingleState state;
+
+  private final List<Integer> channels;
+
+  public CategorizeBytesRefAggregatorFunction(DriverContext driverContext, List<Integer> channels,
+      CategorizeBytesRefAggregator.SingleState state) {
+    this.driverContext = driverContext;
+    this.channels = channels;
+    this.state = state;
+  }
+
+  public static CategorizeBytesRefAggregatorFunction create(DriverContext driverContext,
+      List<Integer> channels) {
+    return new CategorizeBytesRefAggregatorFunction(driverContext, channels, CategorizeBytesRefAggregator.initSingle(driverContext.bigArrays()));
+  }
+
+  public static List<IntermediateStateDesc> intermediateStateDesc() {
+    return INTERMEDIATE_STATE_DESC;
+  }
+
+  @Override
+  public int intermediateBlockCount() {
+    return INTERMEDIATE_STATE_DESC.size();
+  }
+
+  @Override
+  public void addRawInput(Page page) {
+    BytesRefBlock block = page.getBlock(channels.get(0));
+    BytesRefVector vector = block.asVector();
+    if (vector != null) {
+      addRawVector(vector);
+    } else {
+      addRawBlock(block);
+    }
+  }
+
+  private void addRawVector(BytesRefVector vector) {
+    BytesRef scratch = new BytesRef();
+    for (int i = 0; i < vector.getPositionCount(); i++) {
+      CategorizeBytesRefAggregator.combine(state, vector.getBytesRef(i, scratch));
+    }
+  }
+
+  private void addRawBlock(BytesRefBlock block) {
+    BytesRef scratch = new BytesRef();
+    for (int p = 0; p < block.getPositionCount(); p++) {
+      if (block.isNull(p)) {
+        continue;
+      }
+      int start = block.getFirstValueIndex(p);
+      int end = start + block.getValueCount(p);
+      for (int i = start; i < end; i++) {
+        CategorizeBytesRefAggregator.combine(state, block.getBytesRef(i, scratch));
+      }
+    }
+  }
+
+  @Override
+  public void addIntermediateInput(Page page) {
+    assert channels.size() == intermediateBlockCount();
+    assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+    Block categorizeUncast = page.getBlock(channels.get(0));
+    if (categorizeUncast.areAllValuesNull()) {
+      return;
+    }
+    BytesRefBlock categorize = (BytesRefBlock) categorizeUncast;
+    assert categorize.getPositionCount() == 1;
+    BytesRef scratch = new BytesRef();
+    CategorizeBytesRefAggregator.combineIntermediate(state, categorize);
+  }
+
+  @Override
+  public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+    state.toIntermediate(blocks, offset, driverContext);
+  }
+
+  @Override
+  public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
+    blocks[offset] = CategorizeBytesRefAggregator.evaluateFinal(state, driverContext);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    sb.append(getClass().getSimpleName()).append("[");
+    sb.append("channels=").append(channels);
+    sb.append("]");
+    return sb.toString();
+  }
+
+  @Override
+  public void close() {
+    state.close();
+  }
+}
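The generated class above drives a three-step lifecycle: raw input pages are folded into a `SingleState`, serialized intermediate state can be merged via `addIntermediateInput`, and `evaluateFinal` emits the result block. A minimal sketch of that flow in terms of the hand-written aggregator's static API (defined later in this diff); the `driverContext` variable is assumed to come from a compute test fixture and is not part of this change:

```java
// Sketch only: `driverContext` is assumed to be provided by a test fixture.
CategorizeBytesRefAggregator.SingleState state =
    CategorizeBytesRefAggregator.initSingle(driverContext.bigArrays());
// Messages that differ only in detail tokens should fall into one category.
CategorizeBytesRefAggregator.combine(state, new BytesRef("Connected to 10.1.0.1"));
CategorizeBytesRefAggregator.combine(state, new BytesRef("Connected to 10.1.0.2"));
CategorizeBytesRefAggregator.combine(state, new BytesRef("Connection error"));
// One multi-valued position with one "key;regex;count" entry per category.
Block result = CategorizeBytesRefAggregator.evaluateFinal(state, driverContext);
state.close();
```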
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunctionSupplier.java
new file mode 100644
index 0000000000000..ad2e8a64ca41f
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunctionSupplier.java
@@ -0,0 +1,39 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.util.List;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunctionSupplier} implementation for {@link CategorizeBytesRefAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class CategorizeBytesRefAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+  private final List<Integer> channels;
+
+  public CategorizeBytesRefAggregatorFunctionSupplier(List<Integer> channels) {
+    this.channels = channels;
+  }
+
+  @Override
+  public CategorizeBytesRefAggregatorFunction aggregator(DriverContext driverContext) {
+    return CategorizeBytesRefAggregatorFunction.create(driverContext, channels);
+  }
+
+  @Override
+  public CategorizeBytesRefGroupingAggregatorFunction groupingAggregator(
+      DriverContext driverContext) {
+    return CategorizeBytesRefGroupingAggregatorFunction.create(channels, driverContext);
+  }
+
+  @Override
+  public String describe() {
+    return "categorize of bytes";
+  }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefGroupingAggregatorFunction.java
new file mode 100644
index 0000000000000..3dfd7bdcf11c4
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/CategorizeBytesRefGroupingAggregatorFunction.java
@@ -0,0 +1,201 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link CategorizeBytesRefAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class CategorizeBytesRefGroupingAggregatorFunction implements GroupingAggregatorFunction {
+  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
+      new IntermediateStateDesc("categorize", ElementType.BYTES_REF)  );
+
+  private final CategorizeBytesRefAggregator.GroupingState state;
+
+  private final List<Integer> channels;
+
+  private final DriverContext driverContext;
+
+  public CategorizeBytesRefGroupingAggregatorFunction(List<Integer> channels,
+      CategorizeBytesRefAggregator.GroupingState state, DriverContext driverContext) {
+    this.channels = channels;
+    this.state = state;
+    this.driverContext = driverContext;
+  }
+
+  public static CategorizeBytesRefGroupingAggregatorFunction create(List<Integer> channels,
+      DriverContext driverContext) {
+    return new CategorizeBytesRefGroupingAggregatorFunction(channels, CategorizeBytesRefAggregator.initGrouping(driverContext.bigArrays()), driverContext);
+  }
+
+  public static List<IntermediateStateDesc> intermediateStateDesc() {
+    return INTERMEDIATE_STATE_DESC;
+  }
+
+  @Override
+  public int intermediateBlockCount() {
+    return INTERMEDIATE_STATE_DESC.size();
+  }
+
+  @Override
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
+    BytesRefBlock valuesBlock = page.getBlock(channels.get(0));
+    BytesRefVector valuesVector = valuesBlock.asVector();
+    if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
+      return new GroupingAggregatorFunction.AddInput() {
+        @Override
+        public void add(int positionOffset, IntBlock groupIds) {
+          addRawInput(positionOffset, groupIds, valuesBlock);
+        }
+
+        @Override
+        public void add(int positionOffset, IntVector groupIds) {
+          addRawInput(positionOffset, groupIds, valuesBlock);
+        }
+      };
+    }
+    return new GroupingAggregatorFunction.AddInput() {
+      @Override
+      public void add(int positionOffset, IntBlock groupIds) {
+        addRawInput(positionOffset, groupIds, valuesVector);
+      }
+
+      @Override
+      public void add(int positionOffset, IntVector groupIds) {
+        addRawInput(positionOffset, groupIds, valuesVector);
+      }
+    };
+  }
+
+  private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      int groupId = groups.getInt(groupPosition);
+      if (values.isNull(groupPosition + positionOffset)) {
+        continue;
+      }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        CategorizeBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
+      }
+    }
+  }
+
+  private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) {
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      int groupId = groups.getInt(groupPosition);
+      CategorizeBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
+    }
+  }
+
+  private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) {
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        if (values.isNull(groupPosition + positionOffset)) {
+          continue;
+        }
+        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+        for (int v = valuesStart; v < valuesEnd; v++) {
+          CategorizeBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch));
+        }
+      }
+    }
+  }
+
+  private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) {
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = groups.getInt(g);
+        CategorizeBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch));
+      }
+    }
+  }
+
+  @Override
+  public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block categorizeUncast = page.getBlock(channels.get(0));
+    if (categorizeUncast.areAllValuesNull()) {
+      return;
+    }
+    BytesRefBlock categorize = (BytesRefBlock) categorizeUncast;
+    BytesRef scratch = new BytesRef();
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      int groupId = groups.getInt(groupPosition);
+      CategorizeBytesRefAggregator.combineIntermediate(state, groupId, categorize, groupPosition + positionOffset);
+    }
+  }
+
+  @Override
+  public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
+    if (input.getClass() != getClass()) {
+      throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
+    }
+    CategorizeBytesRefAggregator.GroupingState inState = ((CategorizeBytesRefGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    CategorizeBytesRefAggregator.combineStates(state, groupId, inState, position);
+  }
+
+  @Override
+  public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+    state.toIntermediate(blocks, offset, selected, driverContext);
+  }
+
+  @Override
+  public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
+      DriverContext driverContext) {
+    blocks[offset] = CategorizeBytesRefAggregator.evaluateFinal(state, selected, driverContext);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    sb.append(getClass().getSimpleName()).append("[");
+    sb.append("channels=").append(channels);
+    sb.append("]");
+    return sb.toString();
+  }
+
+  @Override
+  public void close() {
+    state.close();
+  }
+}
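The grouping variant keeps one categorizer per group id (see `GroupingState` in the hand-written aggregator below), so categories from different groups never merge. A hedged sketch of the per-group combine path, again assuming a test-provided `driverContext`:

```java
// Sketch only: group ids would come from the hash operator in a real driver.
CategorizeBytesRefAggregator.GroupingState state =
    CategorizeBytesRefAggregator.initGrouping(driverContext.bigArrays());
CategorizeBytesRefAggregator.combine(state, 0, new BytesRef("Connected to 10.1.0.1"));
CategorizeBytesRefAggregator.combine(state, 0, new BytesRef("Connected to 10.1.0.2"));
CategorizeBytesRefAggregator.combine(state, 1, new BytesRef("Disconnected"));
// Group 0 ends up with one "Connected to" category; group 1 with "Disconnected".
state.close();
```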
diff --git a/x-pack/plugin/esql/compute/src/main/java/module-info.java b/x-pack/plugin/esql/compute/src/main/java/module-info.java
index dc8cda0fbe3c8..1739c90467c2c 100644
--- a/x-pack/plugin/esql/compute/src/main/java/module-info.java
+++ b/x-pack/plugin/esql/compute/src/main/java/module-info.java
@@ -7,6 +7,7 @@
 module org.elasticsearch.compute {
+    requires org.apache.lucene.analysis.common;
     requires org.apache.lucene.core;
     requires org.elasticsearch.base;
     requires org.elasticsearch.server;
@@ -15,6 +16,7 @@
     // required due to dependency on org.elasticsearch.common.util.concurrent.AbstractAsyncTask
     requires org.apache.logging.log4j;
     requires org.elasticsearch.logging;
+    requires org.elasticsearch.ml;
     requires org.elasticsearch.tdigest;
     requires org.elasticsearch.geo;
     requires hppc;
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregator.java
new file mode 100644
index 0000000000000..fd9b9d0ade898
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregator.java
@@ -0,0 +1,276 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.core.WhitespaceTokenizer;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
+import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.ann.Aggregator;
+import org.elasticsearch.compute.ann.GroupingAggregator;
+import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.CharFilterFactory;
+import org.elasticsearch.index.analysis.CustomAnalyzer;
+import org.elasticsearch.index.analysis.TokenFilterFactory;
+import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Categorizes text strings.
+ */
+@Aggregator({ @IntermediateState(name = "categorize", type = "BYTES_REF_BLOCK") })
+@GroupingAggregator
+class CategorizeBytesRefAggregator {
+    public static SingleState initSingle(BigArrays bigArrays) {
+        return new SingleState(bigArrays, createAnalyzer());
+    }
+
+    public static void combine(SingleState state, BytesRef v) {
+        state.add(v);
+    }
+
+    public static void combineIntermediate(SingleState state, BytesRefBlock values) {
+        combineIntermediate(state, values, 0);
+    }
+
+    public static void combineIntermediate(SingleState state, BytesRefBlock values, int valuesPosition) {
+        BytesRef scratch = new BytesRef();
+        int start = values.getFirstValueIndex(valuesPosition);
+        int end = start + values.getValueCount(valuesPosition);
+        ByteArrayStreamInput in = new ByteArrayStreamInput();
+        for (int i = start; i < end; i++) {
+            values.getBytesRef(i, scratch);
+            if (scratch.length == 0) {
+                continue;
+            }
+            in.reset(scratch.bytes, scratch.offset, scratch.length);
+            try {
+                state.categorizer.mergeWireCategory(new SerializableTokenListCategory(in));
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
+        return state.toFinal(driverContext.blockFactory());
+    }
+
+    public static GroupingState initGrouping(BigArrays bigArrays) {
+        return new GroupingState(bigArrays, createAnalyzer());
+    }
+
+    public static void combine(GroupingState state, int groupId, BytesRef v) {
+        state.getState(groupId).add(v);
+    }
+
+    public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) {
+        combineIntermediate(state.getState(groupId), values, valuesPosition);
+    }
+
+    public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) {
+        TokenListCategorizer currentCategorizer = current.getState(currentGroupId).categorizer;
+        TokenListCategorizer stateCategorizer = state.getState(statePosition).categorizer;
+        for (InternalCategorizationAggregation.Bucket bucket : stateCategorizer.toOrderedBuckets(stateCategorizer.getCategoryCount())) {
+            currentCategorizer.mergeWireCategory(bucket.getSerializableCategory());
+        }
+    }
+
+    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
+        return state.toFinal(driverContext.blockFactory(), selected);
+    }
+
+    private static CategorizationAnalyzer createAnalyzer() {
+        // TODO: add correct analyzer, see also: CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
+        return new CategorizationAnalyzer(
+            new CustomAnalyzer(
+                TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+                new CharFilterFactory[0],
+                new TokenFilterFactory[0]
+            ),
+            true
+        );
+    }
+
+    public static class SingleState implements Releasable {
+
+        private final CategorizationAnalyzer analyzer;
+        private final CategorizationBytesRefHash bytesRefHash;
+        private final TokenListCategorizer categorizer;
+
+        private SingleState(BigArrays bigArrays, CategorizationAnalyzer analyzer) {
+            bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays));
+            categorizer = new TokenListCategorizer(bytesRefHash, CategorizationPartOfSpeechDictionary.getInstance(), 0.70f);
+            this.analyzer = analyzer;
+        }
+
+        void add(BytesRef v) {
+            if (v == null || v.length == 0) {
+                return;
+            }
+            String s = v.utf8ToString();
+            try (TokenStream ts = analyzer.tokenStream("text", s)) {
+                categorizer.computeCategory(ts, s.length(), 1);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+            blocks[offset] = toBlock(driverContext.blockFactory());
+        }
+
+        Block toBlock(BlockFactory blockFactory) {
+            if (categorizer.getCategoryCount() == 0) {
+                return blockFactory.newConstantNullBlock(1);
+            }
+            try (BytesRefBlock.Builder block = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
+                addToBlockIntermediate(block);
+                return block.build();
+            }
+        }
+
+        void addToBlockIntermediate(BytesRefBlock.Builder block) {
+            if (categorizer.getCategoryCount() == 0) {
+                block.appendNull();
+                return;
+            }
+            block.beginPositionEntry();
+            for (InternalCategorizationAggregation.Bucket bucket : categorizer.toOrderedBuckets(categorizer.getCategoryCount())) {
+                ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                OutputStreamStreamOutput out = new OutputStreamStreamOutput(baos);
+                try {
+                    bucket.writeTo(out);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+                block.appendBytesRef(new BytesRef(baos.toByteArray()));
+            }
+            block.endPositionEntry();
+        }
+
+        Block toFinal(BlockFactory blockFactory) {
+            if (categorizer.getCategoryCount() == 0) {
+                return blockFactory.newConstantNullBlock(1);
+            }
+            try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
+                addToBlockFinal(builder);
+                return builder.build();
+            }
+        }
+
+        void addToBlockFinal(BytesRefBlock.Builder block) {
+            if (categorizer.getCategoryCount() == 0) {
+                block.appendNull();
+                return;
+            }
+            block.beginPositionEntry();
+            for (InternalCategorizationAggregation.Bucket bucket : categorizer.toOrderedBuckets(categorizer.getCategoryCount())) {
+                // TODO: find something better for this semi-colon-separated string.
+                String result = String.join(
+                    ";",
+                    bucket.getKeyAsString(),
+                    bucket.getSerializableCategory().getRegex(),
+                    Long.toString(bucket.getDocCount())
+                );
+                block.appendBytesRef(new BytesRef(result.getBytes(StandardCharsets.UTF_8)));
+            }
+            block.endPositionEntry();
+        }
+
+        @Override
+        public void close() {
+            Releasables.close(bytesRefHash);
+        }
+    }
+
+    public static class GroupingState implements Releasable {
+
+        private final BigArrays bigArrays;
+        private final CategorizationAnalyzer analyzer;
+        private final Map<Integer, SingleState> states;
+
+        private GroupingState(BigArrays bigArrays, CategorizationAnalyzer analyzer) {
+            this.bigArrays = bigArrays;
+            this.analyzer = analyzer;
+            states = new HashMap<>();
+        }
+
+        SingleState getState(int groupId) {
+            return states.computeIfAbsent(groupId, key -> new SingleState(bigArrays, analyzer));
+        }
+
+        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
+        }
+
+        Block toBlock(BlockFactory blockFactory, IntVector selected) {
+            if (states.isEmpty()) {
+                return blockFactory.newConstantNullBlock(selected.getPositionCount());
+            }
+            try (BytesRefBlock.Builder block = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) {
+                for (int s = 0; s < selected.getPositionCount(); s++) {
+                    SingleState state = states.get(selected.getInt(s));
+                    if (state == null) {
+                        block.appendNull();
+                    } else {
+                        state.addToBlockIntermediate(block);
+                    }
+                }
+                return block.build();
+            }
+        }
+
+        Block toFinal(BlockFactory blockFactory, IntVector selected) {
+            if (states.isEmpty()) {
+                return blockFactory.newConstantNullBlock(selected.getPositionCount());
+            }
+            try (BytesRefBlock.Builder block = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) {
+                for (int s = 0; s < selected.getPositionCount(); s++) {
+                    SingleState state = states.get(selected.getInt(s));
+                    if (state == null) {
+                        block.appendNull();
+                    } else {
+                        state.addToBlockFinal(block);
+                    }
+                }
+                return block.build();
+            }
+        }
+
+        void enableGroupIdTracking(SeenGroupIds seen) {}
+
+        @Override
+        public void close() {
+            for (SingleState state : states.values()) {
+                Releasables.closeExpectNoException(state);
+            }
+        }
+    }
+}
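`addToBlockFinal` above encodes each category as a single `key;regex;count` string, which the csv-spec test later splits with ES|QL's `SPLIT`. Since the in-code TODO flags this format as provisional, here is a hypothetical decoder, relying on the count following the last `;` and the key containing none:

```java
// Hypothetical helper for the provisional "key;regex;count" encoding.
static String[] decodeCategory(BytesRef encoded) {
    String s = encoded.utf8ToString();
    int firstSep = s.indexOf(';');
    int lastSep = s.lastIndexOf(';');
    return new String[] {
        s.substring(0, firstSep),            // key, e.g. "Connection error"
        s.substring(firstSep + 1, lastSep),  // regex, e.g. ".*?Connection.+?error.*?"
        s.substring(lastSep + 1)             // doc count, e.g. "3"
    };
}
```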
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunctionTests.java
new file mode 100644
index 0000000000000..b69b90e2e0c52
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefAggregatorFunctionTests.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+import org.hamcrest.core.IsEqual;
+
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.nullValue;
+
+public class CategorizeBytesRefAggregatorFunctionTests extends AggregatorFunctionTestCase {
+
+    private static final int NUM_PREFIXES = 10;
+
+    @Override
+    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+        List<String> prefixes = IntStream.range(0, NUM_PREFIXES)
+            .mapToObj(i -> randomAlphaOfLength(5) + " " + randomAlphaOfLength(5) + " " + randomAlphaOfLength(5) + " ")
+            .toList();
+        return new SequenceBytesRefBlockSourceOperator(
+            blockFactory,
+            IntStream.range(0, size).mapToObj(i -> new BytesRef((prefixes.get(i % NUM_PREFIXES) + i).getBytes(StandardCharsets.UTF_8)))
+        );
+    }
+
+    @Override
+    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
+        return new CategorizeBytesRefAggregatorFunctionSupplier(inputChannels);
+    }
+
+    @Override
+    protected String expectedDescriptionOfAggregator() {
+        return "categorize of bytes";
+    }
+
+    @Override
+    public void assertSimpleOutput(List<Block> input, Block result) {
+        int inputSize = (int) input.stream()
+            .flatMap(AggregatorFunctionTestCase::allBytesRefs)
+            .filter(Objects::nonNull)
+            .map(b -> b.utf8ToString().replaceAll("[0-9]", ""))
+            .distinct()
+            .count();
+        Object resultValue = BlockUtils.toJavaObject(result, 0);
+        switch (inputSize) {
+            case 0 -> assertThat(resultValue, nullValue());
+            case 1 -> assertThat(resultValue, instanceOf(BytesRef.class));
+            default -> {
+                assertThat(resultValue, instanceOf(List.class));
+                assertThat(((List<?>) resultValue).size(), IsEqual.equalTo(inputSize));
+            }
+        }
+    }
+}
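The assertion above derives the expected category count by stripping digits, mirroring how the random input appends a unique trailing number to otherwise identical prefixed messages. In other words, two inputs are expected to land in the same category exactly when they agree after digit removal:

```java
// Illustration of the test's notion of "same category": digits are noise.
String a = "qwert yuiop asdfg 17";
String b = "qwert yuiop asdfg 42";
assert a.replaceAll("[0-9]", "").equals(b.replaceAll("[0-9]", ""));
```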
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefGroupingAggregatorFunctionTests.java
new file mode 100644
index 0000000000000..fc8ad89e942b4
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/CategorizeBytesRefGroupingAggregatorFunctionTests.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.LongBytesRefTupleBlockSourceOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+import org.elasticsearch.core.Tuple;
+
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.core.IsEqual.equalTo;
+
+public class CategorizeBytesRefGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase {
+
+    private static final int NUM_PREFIXES = 10;
+
+    @Override
+    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+        List<String> prefixes = IntStream.range(0, NUM_PREFIXES)
+            .mapToObj(i -> randomAlphaOfLength(5) + " " + randomAlphaOfLength(5) + " " + randomAlphaOfLength(5) + " ")
+            .toList();
+
+        return new LongBytesRefTupleBlockSourceOperator(
+            blockFactory,
+            IntStream.range(0, size)
+                .mapToObj(
+                    i -> Tuple.tuple(
+                        randomGroupId(size),
+                        new BytesRef((prefixes.get(i % NUM_PREFIXES) + i).getBytes(StandardCharsets.UTF_8))
+                    )
+                )
+        );
+    }
+
+    @Override
+    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
+        return new CategorizeBytesRefAggregatorFunctionSupplier(inputChannels);
+    }
+
+    @Override
+    protected String expectedDescriptionOfAggregator() {
+        return "categorize of bytes";
+    }
+
+    @Override
+    protected void assertSimpleGroup(List<Page> input, Block result, int position, Long group) {
+        int inputSize = (int) input.stream()
+            .flatMap(p -> GroupingAggregatorFunctionTestCase.allBytesRefs(p, group))
+            .filter(Objects::nonNull)
+            .map(b -> b.utf8ToString().replaceAll("[0-9]", ""))
+            .distinct()
+            .count();
+        Object resultValue = BlockUtils.toJavaObject(result, position);
+        switch (inputSize) {
+            case 0 -> assertThat(resultValue, nullValue());
+            case 1 -> assertThat(resultValue, instanceOf(BytesRef.class));
+            default -> {
+                assertThat(resultValue, instanceOf(List.class));
+                assertThat(((List<?>) resultValue).size(), equalTo(inputSize));
+            }
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
index f6558d54b2779..72c480e4e29b7 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
@@ -286,7 +286,7 @@ public final void testMultivalued() {
         assertSimpleOutput(origInput, results);
     }
 
-    public final void testMulitvaluedNullGroupsAndValues() {
+    public final void testMultivaluedNullGroupsAndValues() {
         DriverContext driverContext = driverContext();
         BlockFactory blockFactory = driverContext.blockFactory();
         int end = between(50, 60);
@@ -298,7 +298,7 @@ public final void testMulitvaluedNullGroupsAndValues() {
         assertSimpleOutput(origInput, results);
     }
 
-    public final void testMulitvaluedNullGroup() {
+    public final void testMultivaluedNullGroup() {
         DriverContext driverContext = driverContext();
         BlockFactory blockFactory = driverContext.blockFactory();
         int end = between(1, 2); // TODO revert
@@ -309,7 +309,7 @@ public final void testMulitvaluedNullGroup() {
         assertSimpleOutput(origInput, results);
     }
 
-    public final void testMulitvaluedNullValues() {
+    public final void testMultivaluedNullValues() {
         DriverContext driverContext = driverContext();
         BlockFactory blockFactory = driverContext.blockFactory();
         int end = between(50, 60);
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
new file mode 100644
index 0000000000000..0a5207f80e6a5
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
@@ -0,0 +1,16 @@
+categorize
+required_capability: categorize
+
+FROM sample_data
+  | STATS result = CATEGORIZE(message)
+  | MV_EXPAND result
+  | EVAL result = SPLIT(result, ";")
+  | EVAL key = MV_SLICE(result, 0), regex = MV_SLICE(result, 1), count = TO_LONG(MV_SLICE(result, 2))
+  | DROP result
+;
+
+key:keyword | regex:keyword | count:long
+Connection error | .*?Connection.+?error.*? | 3
+Connected to | .*?Connected.+?to.*? | 3
+Disconnected | .*?Disconnected.*? | 1
+;
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
index 325b984c36d34..60bebf66bb202 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
@@ -12,6 +12,7 @@ synopsis:keyword
 "double|date bin(field:integer|long|double|date, buckets:integer|long|double|date_period|time_duration, ?from:integer|long|double|date|keyword|text, ?to:integer|long|double|date|keyword|text)"
 "double|date bucket(field:integer|long|double|date, buckets:integer|long|double|date_period|time_duration, ?from:integer|long|double|date|keyword|text, ?to:integer|long|double|date|keyword|text)"
 "boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version case(condition:boolean, trueValue...:boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version)"
+"text categorize(field:keyword|text)"
 "double cbrt(number:double|integer|long|unsigned_long)"
 "double|integer|long|unsigned_long ceil(number:double|integer|long|unsigned_long)"
 "boolean cidr_match(ip:ip, blockX...:keyword|text)"
@@ -136,6 +137,7 @@ avg |number |"double|integer|long"
 bin |[field, buckets, from, to] |["integer|long|double|date", "integer|long|double|date_period|time_duration", "integer|long|double|date|keyword|text", "integer|long|double|date|keyword|text"] |[Numeric or date expression from which to derive buckets., Target number of buckets\, or desired bucket size if `from` and `to` parameters are omitted., Start of the range. Can be a number\, a date or a date expressed as a string., End of the range. Can be a number\, a date or a date expressed as a string.]
 bucket |[field, buckets, from, to] |["integer|long|double|date", "integer|long|double|date_period|time_duration", "integer|long|double|date|keyword|text", "integer|long|double|date|keyword|text"] |[Numeric or date expression from which to derive buckets., Target number of buckets\, or desired bucket size if `from` and `to` parameters are omitted., Start of the range. Can be a number\, a date or a date expressed as a string., End of the range. Can be a number\, a date or a date expressed as a string.]
 case |[condition, trueValue] |[boolean, "boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version"] |[A condition., The value that's returned when the corresponding condition is the first to evaluate to `true`. The default value is returned when no condition matches.]
+categorize |field |"keyword|text" |"Name of the column to categorize."
 cbrt |number |"double|integer|long|unsigned_long" |"Numeric expression. If `null`, the function returns `null`."
 ceil |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`.
 cidr_match |[ip, blockX] |[ip, "keyword|text"] |[IP address of type `ip` (both IPv4 and IPv6 are supported)., CIDR block to test the IP against.]
@@ -260,6 +262,7 @@ avg |The average of a numeric field.
 bin |Creates groups of values - buckets - out of a datetime or numeric input. The size of the buckets can either be provided directly, or chosen based on a recommended count and values range.
 bucket |Creates groups of values - buckets - out of a datetime or numeric input. The size of the buckets can either be provided directly, or chosen based on a recommended count and values range.
 case |Accepts pairs of conditions and values. The function returns the value that belongs to the first condition that evaluates to `true`. If the number of arguments is odd, the last argument is the default value which is returned when no condition matches. If the number of arguments is even, and no condition matches, the function returns `null`.
+categorize |The categorization of a text field.
 cbrt |Returns the cube root of a number. The input can be any numeric value, the return value is always a double. Cube roots of infinities are null.
 ceil |Round a number up to the nearest integer.
 cidr_match |Returns true if the provided IP is contained in one of the provided CIDR blocks.
@@ -386,6 +389,7 @@ avg |double
 bin |"double|date" |[false, false, true, true] |false |false
 bucket |"double|date" |[false, false, true, true] |false |false
 case |"boolean|cartesian_point|date|double|geo_point|integer|ip|keyword|long|text|unsigned_long|version" |[false, false] |true |false
+categorize |"text" |false |false |true
 cbrt |double |false |false |false
 ceil |"double|integer|long|unsigned_long" |false |false |false
 cidr_match |boolean |[false, false] |true |false
@@ -508,5 +512,5 @@ countFunctions#[skip:-8.15.99]
 meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c;
 
 a:long | b:long | c:long
-115 | 115 | 115
+116 | 116 | 116
 ;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index 0d50623fe77eb..1a3f135424aa1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -19,6 +19,7 @@
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.core.util.Check;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
@@ -252,6 +253,7 @@ private FunctionDefinition[][] functions() {
         // aggregate functions
         new FunctionDefinition[] {
             def(Avg.class, Avg::new, "avg"),
+            def(Categorize.class, Categorize::new, "categorize"),
             def(Count.class, Count::new, "count"),
             def(CountDistinct.class, CountDistinct::new, "count_distinct"),
             def(Max.class, Max::new, "max"),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java
index f0acac0e9744e..361d38bac34b2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java
@@ -31,6 +31,7 @@ public abstract class AggregateFunction extends Function {
     public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         return List.of(
             Avg.ENTRY,
+            Categorize.ENTRY,
             Count.ENTRY,
             CountDistinct.ENTRY,
             Max.ENTRY,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Categorize.java
new file mode 100644
index 0000000000000..798af785de322
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Categorize.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.aggregate;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.CategorizeBytesRefAggregatorFunctionSupplier;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.planner.ToAggregator;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString;
+
+public class Categorize extends AggregateFunction implements ToAggregator {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        Expression.class,
+        "Categorize",
+        Categorize::new
+    );
+
+    @FunctionInfo(
+        returnType = { "text" },
+        preview = true,
+        description = "The categorization of a text field.",
+        isAggregation = true,
+        examples = @Example(file = "string", tag = "categorize")
+    )
+    public Categorize(
+        Source source,
+        @Param(name = "field", type = { "keyword", "text" }, description = "Name of the column to categorize.") Expression v
+    ) {
+        super(source, v);
+    }
+
+    private Categorize(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    @Override
+    protected NodeInfo<Categorize> info() {
+        return NodeInfo.create(this, Categorize::new, field());
+    }
+
+    @Override
+    public Categorize replaceChildren(List<Expression> newChildren) {
+        return new Categorize(source(), newChildren.get(0));
+    }
+
+    @Override
+    public DataType dataType() {
+        return field().dataType();
+    }
+
+    @Override
+    protected TypeResolution resolveType() {
+        return isString(field(), sourceText(), DEFAULT);
+    }
+
+    @Override
+    public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
+        return new CategorizeBytesRefAggregatorFunctionSupplier(inputChannels);
+    }
+}
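`Categorize` plugs into planning through `ToAggregator`: the planner asks the function for its supplier, which in turn builds the single or grouping aggregator shown earlier. A rough sketch of that wiring, where `messageField` is a placeholder for a resolved field attribute and is not literal code from this PR:

```java
// Sketch only: `messageField` stands in for a resolved Expression.
Categorize categorize = new Categorize(Source.EMPTY, messageField);
AggregatorFunctionSupplier supplier = categorize.supplier(List.of(0));
assert supplier.describe().equals("categorize of bytes");
```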
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
index 60bf4be1d2b03..f89caa7b4861d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
@@ -23,6 +23,7 @@
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial;
@@ -71,6 +72,7 @@ final class AggregateMapper {
 
     /** List of all mappable ESQL agg functions (excludes surrogates like AVG = SUM/COUNT). */
     private static final List<Class<? extends Function>> AGG_FUNCTIONS = List.of(
+        Categorize.class,
         Count.class,
         CountDistinct.class,
         Max.class,
@@ -169,6 +171,8 @@ private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class<?> clazz) {
         } else if (Values.class.isAssignableFrom(clazz)) {
             // TODO can't we figure this out from the function itself?
             types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
+        } else if (Categorize.class.isAssignableFrom(clazz)) {
+            types = List.of("BytesRef");
         } else if (Top.class.isAssignableFrom(clazz)) {
             types = List.of("Boolean", "Int", "Long", "Double", "Ip");
         } else if (Rate.class.isAssignableFrom(clazz)) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java
index 4b6a38a3e8762..e06cddd10fc0f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java
@@ -179,6 +179,11 @@ public class EsqlFeatures implements FeatureSpecification {
      */
     public static final NodeFeature RESOLVE_FIELDS_API = new NodeFeature("esql.resolve_fields_api");
 
+    /**
+     * Support categorize
+     */
+    public static final NodeFeature CATEGORIZE = new NodeFeature("esql.categorize");
+
    private Set<NodeFeature> snapshotBuildFeatures() {
         assert Build.current().isSnapshot() : Build.current();
         return Set.of(METRICS_SYNTAX);
@@ -208,7 +213,8 @@ public Set<NodeFeature> getFeatures() {
             METADATA_FIELDS,
             TIMESPAN_ABBREVIATIONS,
             COUNTER_TYPES,
-            RESOLVE_FIELDS_API
+            RESOLVE_FIELDS_API,
+            CATEGORIZE
         );
         if (Build.current().isSnapshot()) {
             return Collections.unmodifiableSet(Sets.union(features, snapshotBuildFeatures()));
diff --git a/x-pack/plugin/ml/src/main/java/module-info.java b/x-pack/plugin/ml/src/main/java/module-info.java
index 0f3fdd836feca..7a4f955bb79f0 100644
--- a/x-pack/plugin/ml/src/main/java/module-info.java
+++ b/x-pack/plugin/ml/src/main/java/module-info.java
@@ -39,4 +39,6 @@
     exports org.elasticsearch.xpack.ml.action;
     exports org.elasticsearch.xpack.ml.autoscaling;
     exports org.elasticsearch.xpack.ml.notifications;
+    exports org.elasticsearch.xpack.ml.aggs.categorization;
+    exports org.elasticsearch.xpack.ml.job.categorization;
 }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java
index 58feb24480f87..7d5f1d5517de0 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java
@@ -12,11 +12,11 @@
 import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.core.Releasable;
 
-class CategorizationBytesRefHash implements Releasable {
+public class CategorizationBytesRefHash implements Releasable {
 
     private final BytesRefHash bytesRefHash;
 
-    CategorizationBytesRefHash(BytesRefHash bytesRefHash) {
+    public CategorizationBytesRefHash(BytesRefHash bytesRefHash) {
         this.bytesRefHash = bytesRefHash;
     }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java
index 7ef7a8f4e6dd5..d0104bda36acd 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java
@@ -187,7 +187,7 @@ long getBucketOrd() {
             return bucketOrd;
         }
 
-        SerializableTokenListCategory getSerializableCategory() {
+        public SerializableTokenListCategory getSerializableCategory() {
             return serializableCategory;
         }
diff --git a/x-pack/plugin/security/qa/multi-cluster/build.gradle b/x-pack/plugin/security/qa/multi-cluster/build.gradle
index 625b6806ab520..c7b8f81bb7876 100644
--- a/x-pack/plugin/security/qa/multi-cluster/build.gradle
+++ b/x-pack/plugin/security/qa/multi-cluster/build.gradle
@@ -23,6 +23,8 @@ dependencies {
   // esql with enrich
   clusterModules project(':x-pack:plugin:esql')
   clusterModules project(':x-pack:plugin:enrich')
+  clusterModules project(':x-pack:plugin:autoscaling')
+  clusterModules project(':x-pack:plugin:ml')
   clusterModules(project(":modules:ingest-common"))
 }
diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
index f5f9410a145cc..1a236ccb6aa06 100644
--- a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
+++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
@@ -56,11 +56,14 @@ public class RemoteClusterSecurityEsqlIT extends AbstractRemoteClusterSecurityTe
         fulfillingCluster = ElasticsearchCluster.local()
             .name("fulfilling-cluster")
             .nodes(3)
+            .module("x-pack-autoscaling")
             .module("x-pack-esql")
             .module("x-pack-enrich")
+            .module("x-pack-ml")
            .module("ingest-common")
             .apply(commonClusterConfig)
             .setting("remote_cluster.port", "0")
+            .setting("xpack.ml.enabled", "false")
             .setting("xpack.security.remote_cluster_server.ssl.enabled", () -> String.valueOf(SSL_ENABLED_REF.get()))
             .setting("xpack.security.remote_cluster_server.ssl.key", "remote-cluster.key")
             .setting("xpack.security.remote_cluster_server.ssl.certificate", "remote-cluster.crt")
@@ -73,10 +76,13 @@ public class RemoteClusterSecurityEsqlIT extends AbstractRemoteClusterSecurityTe
         queryCluster = ElasticsearchCluster.local()
             .name("query-cluster")
+            .module("x-pack-autoscaling")
             .module("x-pack-esql")
             .module("x-pack-enrich")
+            .module("x-pack-ml")
             .module("ingest-common")
             .apply(commonClusterConfig)
+            .setting("xpack.ml.enabled", "false")
             .setting("xpack.security.remote_cluster_client.ssl.enabled", () -> String.valueOf(SSL_ENABLED_REF.get()))
             .setting("xpack.security.remote_cluster_client.ssl.certificate_authorities", "remote-cluster-ca.crt")
             .setting("xpack.security.authc.token.enabled", "true")
diff --git a/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle b/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle
index b5b8495870259..32fe9dbf9fbc9 100644
--- a/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle
+++ b/x-pack/qa/multi-cluster-search-security/legacy-with-basic-license/build.gradle
@@ -23,11 +23,15 @@ def fulfillingCluster = testClusters.register('fulfilling-cluster') {
   module ':modules:data-streams'
   module ':x-pack:plugin:mapper-constant-keyword'
   module ':x-pack:plugin:async-search'
+  module ':x-pack:plugin:autoscaling'
   module ':x-pack:plugin:esql-core'
   module ':x-pack:plugin:esql'
+  module ':x-pack:plugin:ml'
   module ':modules:ingest-common'
   module ':x-pack:plugin:enrich'
   user username: "test_user", password: "x-pack-test-password"
+
+  setting 'xpack.ml.enabled', 'false'
 }
 
 def queryingCluster = testClusters.register('querying-cluster') {
@@ -38,13 +42,16 @@ def queryingCluster = testClusters.register('querying-cluster') {
   module ':modules:data-streams'
   module ':x-pack:plugin:mapper-constant-keyword'
   module ':x-pack:plugin:async-search'
+  module ':x-pack:plugin:autoscaling'
   module ':x-pack:plugin:esql-core'
   module ':x-pack:plugin:esql'
+  module ':x-pack:plugin:ml'
   module ':modules:ingest-common'
   module ':x-pack:plugin:enrich'
   setting 'cluster.remote.connections_per_cluster', "1"
   user username: "test_user", password: "x-pack-test-password"
 
+  setting 'xpack.ml.enabled', 'false'
   setting 'cluster.remote.my_remote_cluster.skip_unavailable', 'false'
   if (proxyMode) {
     setting 'cluster.remote.my_remote_cluster.mode', 'proxy'