diff --git a/docs/changelog/118173.yaml b/docs/changelog/118173.yaml new file mode 100644 index 0000000000000..a3c9054674ba5 --- /dev/null +++ b/docs/changelog/118173.yaml @@ -0,0 +1,5 @@ +pr: 118173 +summary: ES|QL categorize with multiple groupings +area: Machine Learning +type: feature +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index 9b53e6558f4db..191d6443264ca 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -180,13 +180,16 @@ public static BlockHash buildCategorizeBlockHash( List groups, AggregatorMode aggregatorMode, BlockFactory blockFactory, - AnalysisRegistry analysisRegistry + AnalysisRegistry analysisRegistry, + int emitBatchSize ) { - if (groups.size() != 1) { - throw new IllegalArgumentException("only a single CATEGORIZE group can used"); + if (groups.size() == 1) { + return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry); + } else { + assert groups.get(0).isCategorize(); + assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize); + return new CategorizePackedValuesBlockHash(groups, blockFactory, aggregatorMode, analysisRegistry, emitBatchSize); } - - return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry); } /** diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java index 35c6faf84e623..f83776fbdbc85 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java @@ -44,7 +44,7 @@ import java.util.Objects; /** - * Base BlockHash implementation for {@code Categorize} grouping function. + * BlockHash implementation for {@code Categorize} grouping function. */ public class CategorizeBlockHash extends BlockHash { @@ -53,11 +53,9 @@ public class CategorizeBlockHash extends BlockHash { ); private static final int NULL_ORD = 0; - // TODO: this should probably also take an emitBatchSize private final int channel; private final AggregatorMode aggregatorMode; private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; - private final CategorizeEvaluator evaluator; /** @@ -95,12 +93,14 @@ public class CategorizeBlockHash extends BlockHash { } } + boolean seenNull() { + return seenNull; + } + @Override public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { - if (aggregatorMode.isInputPartial() == false) { - addInitial(page, addInput); - } else { - addIntermediate(page, addInput); + try (IntBlock block = add(page)) { + addInput.add(0, block); } } @@ -129,50 +129,38 @@ public void close() { Releasables.close(evaluator, categorizer); } + private IntBlock add(Page page) { + return aggregatorMode.isInputPartial() == false ? addInitial(page) : addIntermediate(page); + } + /** * Adds initial (raw) input to the state. */ - private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) { - try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) { - addInput.add(0, result); - } + IntBlock addInitial(Page page) { + return (IntBlock) evaluator.eval(page.getBlock(channel)); } /** * Adds intermediate state to the state. */ - private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) { + private IntBlock addIntermediate(Page page) { if (page.getPositionCount() == 0) { - return; + return null; } BytesRefBlock categorizerState = page.getBlock(channel); if (categorizerState.areAllValuesNull()) { seenNull = true; - try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) { - addInput.add(0, newIds); - } - return; - } - - Map idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef())); - try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) { - int fromId = idMap.containsKey(0) ? 0 : 1; - int toId = fromId + idMap.size(); - for (int i = fromId; i < toId; i++) { - newIdsBuilder.appendInt(idMap.get(i)); - } - try (IntBlock newIds = newIdsBuilder.build()) { - addInput.add(0, newIds); - } + return blockFactory.newConstantIntBlockWith(NULL_ORD, 1); } + return recategorize(categorizerState.getBytesRef(0, new BytesRef()), null).asBlock(); } /** - * Read intermediate state from a block. - * - * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}. + * Reads the intermediate state from a block and recategorizes the provided IDs. + * If no IDs are provided, the IDs are the IDs in the categorizer's state in order. + * (So 0...N-1 or 1...N, depending on whether null is present.) */ - private Map readIntermediate(BytesRef bytes) { + IntVector recategorize(BytesRef bytes, IntVector ids) { Map idMap = new HashMap<>(); try (StreamInput in = new BytesArray(bytes).streamInput()) { if (in.readBoolean()) { @@ -185,10 +173,22 @@ private Map readIntermediate(BytesRef bytes) { // +1 because the 0 ordinal is reserved for null idMap.put(oldCategoryId + 1, newCategoryId + 1); } - return idMap; } catch (IOException e) { throw new RuntimeException(e); } + try (IntVector.Builder newIdsBuilder = blockFactory.newIntVectorBuilder(idMap.size())) { + if (ids == null) { + int idOffset = idMap.containsKey(0) ? 0 : 1; + for (int i = 0; i < idMap.size(); i++) { + newIdsBuilder.appendInt(idMap.get(i + idOffset)); + } + } else { + for (int i = 0; i < ids.getPositionCount(); i++) { + newIdsBuilder.appendInt(idMap.get(ids.getInt(i))); + } + } + return newIdsBuilder.build(); + } } /** @@ -198,15 +198,20 @@ private Block buildIntermediateBlock() { if (categorizer.getCategoryCount() == 0) { return blockFactory.newConstantNullBlock(seenNull ? 1 : 0); } + int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0); + // We're returning a block with N positions just because the Page must have all blocks with the same position count! + return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), positionCount); + } + + BytesRef serializeCategorizer() { + // TODO: This BytesStreamOutput is not accounted for by the circuit breaker. Fix that! try (BytesStreamOutput out = new BytesStreamOutput()) { out.writeBoolean(seenNull); out.writeVInt(categorizer.getCategoryCount()); for (SerializableTokenListCategory category : categorizer.toCategoriesById()) { category.writeTo(out); } - // We're returning a block with N positions just because the Page must have all blocks with the same position count! - int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0); - return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount); + return out.bytes().toBytesRef(); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java new file mode 100644 index 0000000000000..20874cb10ceb8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation.blockhash; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.ReleasableIterator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.index.analysis.AnalysisRegistry; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * BlockHash implementation for {@code Categorize} grouping function as first + * grouping expression, followed by one or mode other grouping expressions. + *

+ * For the first grouping (the {@code Categorize} grouping function), a + * {@code CategorizeBlockHash} is used, which outputs integers (category IDs). + * Next, a {@code PackedValuesBlockHash} is used on the category IDs and the + * other groupings (which are not {@code Categorize}s). + */ +public class CategorizePackedValuesBlockHash extends BlockHash { + + private final List specs; + private final AggregatorMode aggregatorMode; + private final Block[] blocks; + private final CategorizeBlockHash categorizeBlockHash; + private final PackedValuesBlockHash packedValuesBlockHash; + + CategorizePackedValuesBlockHash( + List specs, + BlockFactory blockFactory, + AggregatorMode aggregatorMode, + AnalysisRegistry analysisRegistry, + int emitBatchSize + ) { + super(blockFactory); + this.specs = specs; + this.aggregatorMode = aggregatorMode; + blocks = new Block[specs.size()]; + + List delegateSpecs = new ArrayList<>(); + delegateSpecs.add(new GroupSpec(0, ElementType.INT)); + for (int i = 1; i < specs.size(); i++) { + delegateSpecs.add(new GroupSpec(i, specs.get(i).elementType())); + } + + boolean success = false; + try { + categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry); + packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + @Override + public void add(Page page, GroupingAggregatorFunction.AddInput addInput) { + try (IntBlock categories = getCategories(page)) { + blocks[0] = categories; + for (int i = 1; i < specs.size(); i++) { + blocks[i] = page.getBlock(specs.get(i).channel()); + } + packedValuesBlockHash.add(new Page(blocks), addInput); + } + } + + private IntBlock getCategories(Page page) { + if (aggregatorMode.isInputPartial() == false) { + return categorizeBlockHash.addInitial(page); + } else { + BytesRefBlock stateBlock = page.getBlock(0); + BytesRef stateBytes = stateBlock.getBytesRef(0, new BytesRef()); + try (StreamInput in = new BytesArray(stateBytes).streamInput()) { + BytesRef categorizerState = in.readBytesRef(); + try (IntVector ids = IntVector.readFrom(blockFactory, in)) { + return categorizeBlockHash.recategorize(categorizerState, ids).asBlock(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public Block[] getKeys() { + Block[] keys = packedValuesBlockHash.getKeys(); + if (aggregatorMode.isOutputPartial() == false) { + // For final output, the keys are the category regexes. + try ( + BytesRefBlock regexes = (BytesRefBlock) categorizeBlockHash.getKeys()[0]; + BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(keys[0].getPositionCount()) + ) { + IntVector idsVector = (IntVector) keys[0].asVector(); + int idsOffset = categorizeBlockHash.seenNull() ? 0 : -1; + BytesRef scratch = new BytesRef(); + for (int i = 0; i < idsVector.getPositionCount(); i++) { + int id = idsVector.getInt(i); + if (id == 0) { + builder.appendNull(); + } else { + builder.appendBytesRef(regexes.getBytesRef(id + idsOffset, scratch)); + } + } + keys[0].close(); + keys[0] = builder.build(); + } + } else { + // For intermediate output, the keys are the delegate PackedValuesBlockHash's + // keys, with the category IDs replaced by the categorizer's internal state + // together with the list of category IDs. + BytesRef state; + // TODO: This BytesStreamOutput is not accounted for by the circuit breaker. Fix that! + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.writeBytesRef(categorizeBlockHash.serializeCategorizer()); + ((IntVector) keys[0].asVector()).writeTo(out); + state = out.bytes().toBytesRef(); + } catch (IOException e) { + throw new RuntimeException(e); + } + keys[0].close(); + keys[0] = blockFactory.newConstantBytesRefBlockWith(state, keys[0].getPositionCount()); + } + return keys; + } + + @Override + public IntVector nonEmpty() { + return packedValuesBlockHash.nonEmpty(); + } + + @Override + public BitArray seenGroupIds(BigArrays bigArrays) { + return packedValuesBlockHash.seenGroupIds(bigArrays); + } + + @Override + public final ReleasableIterator lookup(Page page, ByteSizeValue targetBlockSize) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + Releasables.close(categorizeBlockHash, packedValuesBlockHash); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 6f8386ec08de1..ccddfdf5cc74a 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -51,7 +51,13 @@ public Operator get(DriverContext driverContext) { if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) { return new HashAggregationOperator( aggregators, - () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry), + () -> BlockHash.buildCategorizeBlockHash( + groups, + aggregatorMode, + driverContext.blockFactory(), + analysisRegistry, + maxPageSize + ), driverContext ); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index f8428b7c33568..587deda650a23 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -130,9 +130,6 @@ public void close() { } finally { page.releaseBlocks(); } - - // TODO: randomize values? May give wrong results - // TODO: assert the categorizer state after adding pages. } public void testCategorizeRawMultivalue() { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java new file mode 100644 index 0000000000000..cfa023af3d18a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java @@ -0,0 +1,248 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation.blockhash; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.analysis.common.CommonAnalysisPlugin; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.HashAggregationOperator; +import org.elasticsearch.compute.operator.LocalSourceOperator; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.indices.analysis.AnalysisModule; +import org.elasticsearch.plugins.scanners.StablePluginsRegistry; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class CategorizePackedValuesBlockHashTests extends BlockHashTestCase { + + private AnalysisRegistry analysisRegistry; + + @Before + private void initAnalysisRegistry() throws IOException { + analysisRegistry = new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()), + new StablePluginsRegistry() + ).getAnalysisRegistry(); + } + + public void testCategorize_withDriver() { + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking(); + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); + boolean withNull = randomBoolean(); + boolean withMultivalues = randomBoolean(); + + List groupSpecs = List.of( + new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true), + new BlockHash.GroupSpec(1, ElementType.INT, false) + ); + + LocalSourceOperator.BlockSupplier input1 = () -> { + try ( + BytesRefBlock.Builder messagesBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(10); + IntBlock.Builder idsBuilder = driverContext.blockFactory().newIntBlockBuilder(10) + ) { + if (withMultivalues) { + messagesBuilder.beginPositionEntry(); + } + messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.1")); + messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.2")); + if (withMultivalues) { + messagesBuilder.endPositionEntry(); + } + idsBuilder.appendInt(7); + if (withMultivalues == false) { + idsBuilder.appendInt(7); + } + + messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.3")); + messagesBuilder.appendBytesRef(new BytesRef("connection error")); + messagesBuilder.appendBytesRef(new BytesRef("connection error")); + messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.4")); + idsBuilder.appendInt(42); + idsBuilder.appendInt(7); + idsBuilder.appendInt(42); + idsBuilder.appendInt(7); + + if (withNull) { + messagesBuilder.appendNull(); + idsBuilder.appendInt(43); + } + return new Block[] { messagesBuilder.build(), idsBuilder.build() }; + } + }; + LocalSourceOperator.BlockSupplier input2 = () -> { + try ( + BytesRefBlock.Builder messagesBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(10); + IntBlock.Builder idsBuilder = driverContext.blockFactory().newIntBlockBuilder(10) + ) { + messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.1")); + messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.2")); + messagesBuilder.appendBytesRef(new BytesRef("disconnected")); + messagesBuilder.appendBytesRef(new BytesRef("connection error")); + idsBuilder.appendInt(111); + idsBuilder.appendInt(7); + idsBuilder.appendInt(7); + idsBuilder.appendInt(42); + if (withNull) { + messagesBuilder.appendNull(); + idsBuilder.appendNull(); + } + return new Block[] { messagesBuilder.build(), idsBuilder.build() }; + } + }; + + List intermediateOutput = new ArrayList<>(); + + Driver driver = new Driver( + driverContext, + new LocalSourceOperator(input1), + List.of( + new HashAggregationOperator.HashAggregationOperatorFactory( + groupSpecs, + AggregatorMode.INITIAL, + List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(0)).groupingAggregatorFactory(AggregatorMode.INITIAL)), + 16 * 1024, + analysisRegistry + ).get(driverContext) + ), + new PageConsumerOperator(intermediateOutput::add), + () -> {} + ); + runDriver(driver); + + driver = new Driver( + driverContext, + new LocalSourceOperator(input2), + List.of( + new HashAggregationOperator.HashAggregationOperatorFactory( + groupSpecs, + AggregatorMode.INITIAL, + List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(0)).groupingAggregatorFactory(AggregatorMode.INITIAL)), + 16 * 1024, + analysisRegistry + ).get(driverContext) + ), + new PageConsumerOperator(intermediateOutput::add), + () -> {} + ); + runDriver(driver); + + List finalOutput = new ArrayList<>(); + + driver = new Driver( + driverContext, + new CannedSourceOperator(intermediateOutput.iterator()), + List.of( + new HashAggregationOperator.HashAggregationOperatorFactory( + groupSpecs, + AggregatorMode.FINAL, + List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(2)).groupingAggregatorFactory(AggregatorMode.FINAL)), + 16 * 1024, + analysisRegistry + ).get(driverContext) + ), + new PageConsumerOperator(finalOutput::add), + () -> {} + ); + runDriver(driver); + + assertThat(finalOutput, hasSize(1)); + assertThat(finalOutput.get(0).getBlockCount(), equalTo(3)); + BytesRefBlock outputMessages = finalOutput.get(0).getBlock(0); + IntBlock outputIds = finalOutput.get(0).getBlock(1); + BytesRefBlock outputValues = finalOutput.get(0).getBlock(2); + assertThat(outputIds.getPositionCount(), equalTo(outputMessages.getPositionCount())); + assertThat(outputValues.getPositionCount(), equalTo(outputMessages.getPositionCount())); + Map>> result = new HashMap<>(); + for (int i = 0; i < outputMessages.getPositionCount(); i++) { + BytesRef messageBytesRef = ((BytesRef) BlockUtils.toJavaObject(outputMessages, i)); + String message = messageBytesRef == null ? null : messageBytesRef.utf8ToString(); + result.computeIfAbsent(message, key -> new HashMap<>()); + + Integer id = (Integer) BlockUtils.toJavaObject(outputIds, i); + result.get(message).computeIfAbsent(id, key -> new HashSet<>()); + + Object values = BlockUtils.toJavaObject(outputValues, i); + if (values == null) { + result.get(message).get(id).add(null); + } else { + if ((values instanceof List) == false) { + values = List.of(values); + } + for (Object valueObject : (List) values) { + BytesRef value = (BytesRef) valueObject; + result.get(message).get(id).add(value.utf8ToString()); + } + } + } + Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks)); + + Map>> expectedResult = Map.of( + ".*?connected.+?to.*?", + Map.of( + 7, + Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"), + 42, + Set.of("connected to 1.1.3"), + 111, + Set.of("connected to 2.1.1") + ), + ".*?connection.+?error.*?", + Map.of(7, Set.of("connection error"), 42, Set.of("connection error")), + ".*?disconnected.*?", + Map.of(7, Set.of("disconnected")) + ); + if (withNull) { + expectedResult = new HashMap<>(expectedResult); + expectedResult.put(null, new HashMap<>()); + expectedResult.get(null).put(null, new HashSet<>()); + expectedResult.get(null).get(null).add(null); + expectedResult.get(null).put(43, new HashSet<>()); + expectedResult.get(null).get(43).add(null); + } + assertThat(result, equalTo(expectedResult)); + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec index 4ce43961a7077..5ad62dd7a21a8 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec @@ -60,6 +60,19 @@ COUNT():long | VALUES(str):keyword | category:keyword 1 | [a, b, c] | .*?disconnected.*? ; +limit before stats +required_capability: categorize_v5 + +FROM sample_data | SORT message | LIMIT 4 + | STATS count=COUNT() BY category=CATEGORIZE(message) + | SORT category +; + +count:long | category:keyword + 3 | .*?Connected.+?to.*? + 1 | .*?Connection.+?error.*? +; + skips stopwords required_capability: categorize_v5 @@ -615,3 +628,159 @@ COUNT():long | x:keyword 3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?] 1 | [.*?Disconnected.*?,.*?Disconnected.*?] ; + +multiple groupings with categorize and ip +required_capability: categorize_multiple_groupings + +FROM sample_data + | STATS count=COUNT() BY category=CATEGORIZE(message), client_ip + | SORT category, client_ip +; + +count:long | category:keyword | client_ip:ip + 1 | .*?Connected.+?to.*? | 172.21.2.113 + 1 | .*?Connected.+?to.*? | 172.21.2.162 + 1 | .*?Connected.+?to.*? | 172.21.3.15 + 3 | .*?Connection.+?error.*? | 172.21.3.15 + 1 | .*?Disconnected.*? | 172.21.0.5 +; + +multiple groupings with categorize and bucketed timestamp +required_capability: categorize_multiple_groupings + +FROM sample_data + | STATS count=COUNT() BY category=CATEGORIZE(message), timestamp=BUCKET(@timestamp, 1 HOUR) + | SORT category, timestamp +; + +count:long | category:keyword | timestamp:datetime + 2 | .*?Connected.+?to.*? | 2023-10-23T12:00:00.000Z + 1 | .*?Connected.+?to.*? | 2023-10-23T13:00:00.000Z + 3 | .*?Connection.+?error.*? | 2023-10-23T13:00:00.000Z + 1 | .*?Disconnected.*? | 2023-10-23T13:00:00.000Z +; + + +multiple groupings with categorize and limit before stats +required_capability: categorize_multiple_groupings + +FROM sample_data | SORT message | LIMIT 5 + | STATS count=COUNT() BY category=CATEGORIZE(message), client_ip + | SORT category, client_ip +; + +count:long | category:keyword | client_ip:ip + 1 | .*?Connected.+?to.*? | 172.21.2.113 + 1 | .*?Connected.+?to.*? | 172.21.2.162 + 1 | .*?Connected.+?to.*? | 172.21.3.15 + 2 | .*?Connection.+?error.*? | 172.21.3.15 +; + +multiple groupings with categorize and nulls +required_capability: categorize_multiple_groupings + +FROM employees + | STATS SUM(languages) BY category=CATEGORIZE(job_positions), gender + | SORT category DESC, gender ASC + | LIMIT 5 +; + +SUM(languages):long | category:keyword | gender:keyword + 11 | null | F + 16 | null | M + 14 | .*?Tech.+?Lead.*? | F + 23 | .*?Tech.+?Lead.*? | M + 9 | .*?Tech.+?Lead.*? | null +; + +multiple groupings with categorize and a field that's always null +required_capability: categorize_multiple_groupings + +FROM sample_data + | EVAL nullfield = null + | STATS count=COUNT() BY category=CATEGORIZE(nullfield), client_ip + | SORT client_ip +; + +count:long | category:keyword | client_ip:ip + 1 | null | 172.21.0.5 + 1 | null | 172.21.2.113 + 1 | null | 172.21.2.162 + 4 | null | 172.21.3.15 +; + +multiple groupings with categorize and the same text field +required_capability: categorize_multiple_groupings + +FROM sample_data + | STATS count=COUNT() BY category=CATEGORIZE(message), message + | SORT message +; + +count:long | category:keyword | message:keyword + 1 | .*?Connected.+?to.*? | Connected to 10.1.0.1 + 1 | .*?Connected.+?to.*? | Connected to 10.1.0.2 + 1 | .*?Connected.+?to.*? | Connected to 10.1.0.3 + 3 | .*?Connection.+?error.*? | Connection error + 1 | .*?Disconnected.*? | Disconnected +; + +multiple additional complex groupings with categorize +required_capability: categorize_multiple_groupings + +FROM sample_data + | STATS count=COUNT(), duration=SUM(event_duration) BY category=CATEGORIZE(message), SUBSTRING(message, 1, 7), ip_part=TO_LONG(SUBSTRING(TO_STRING(client_ip), 8, 1)), hour=BUCKET(@timestamp, 1 HOUR) + | SORT ip_part, category +; + +count:long | duration:long | category:keyword | SUBSTRING(message, 1, 7):keyword | ip_part:long | hour:datetime + 1 | 1232382 | .*?Disconnected.*? | Disconn | 0 | 2023-10-23T13:00:00.000Z + 2 | 6215122 | .*?Connected.+?to.*? | Connect | 2 | 2023-10-23T12:00:00.000Z + 1 | 1756467 | .*?Connected.+?to.*? | Connect | 3 | 2023-10-23T13:00:00.000Z + 3 | 14027356 | .*?Connection.+?error.*? | Connect | 3 | 2023-10-23T13:00:00.000Z +; + +multiple groupings with categorize and some constants including null +required_capability: categorize_multiple_groupings + +FROM sample_data + | STATS count=MV_COUNT(VALUES(message)) BY category=CATEGORIZE(message), null, constant="constant" + | SORT category +; + +count:integer | category:keyword | null:null | constant:keyword + 3 | .*?Connected.+?to.*? | null | constant + 1 | .*?Connection.+?error.*? | null | constant + 1 | .*?Disconnected.*? | null | constant +; + +multiple groupings with categorize and aggregation filters +required_capability: categorize_multiple_groupings + +FROM employees + | STATS lang_low=AVG(languages) WHERE salary<=50000, lang_high=AVG(languages) WHERE salary>50000 BY category=CATEGORIZE(job_positions), gender + | SORT category, gender + | LIMIT 5 +; + +lang_low:double | lang_high:double | category:keyword | gender:keyword + 2.0 | 5.0 | .*?Accountant.*? | F + 3.0 | 2.5 | .*?Accountant.*? | M + 5.0 | 2.0 | .*?Accountant.*? | null + 3.0 | 3.25 | .*?Architect.*? | F + 3.75 | null | .*?Architect.*? | M +; + +multiple groupings with categorize on null row +required_capability: categorize_multiple_groupings + +ROW message = null, str = ["a", "b", "c"] + | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message), str + | SORT str +; + +COUNT():long | VALUES(str):keyword | category:keyword | str:keyword + 1 | [a, b, c] | null | a + 1 | [a, b, c] | null | b + 1 | [a, b, c] | null | c +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 8619c0461ac35..6d61988ac5cb9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -407,6 +407,10 @@ public enum Cap { */ CATEGORIZE_V5, + /** + * Support for multiple groupings in "CATEGORIZE". + */ + CATEGORIZE_MULTIPLE_GROUPINGS, /** * QSTR function */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 07ba2c2a6ef3c..890a4866ccdde 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -325,11 +325,15 @@ private static void checkAggregate(LogicalPlan p, Set failures) { private static void checkCategorizeGrouping(Aggregate agg, Set failures) { // Forbid CATEGORIZE grouping function with other groupings if (agg.groupings().size() > 1) { - agg.groupings().forEach(g -> { + agg.groupings().subList(1, agg.groupings().size()).forEach(g -> { g.forEachDown( Categorize.class, categorize -> failures.add( - fail(categorize, "cannot use CATEGORIZE grouping function [{}] with multiple groupings", categorize.sourceText()) + fail( + categorize, + "CATEGORIZE grouping function [{}] can only be in the first grouping expression", + categorize.sourceText() + ) ) ); }); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index ded913a78bdf1..a100dd64915f1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -95,7 +95,8 @@ public boolean foldable() { @Override public Nullability nullable() { - // Both nulls and empty strings result in null values + // Null strings and strings that don't produce tokens after analysis lead to null values. + // This includes empty strings, only whitespace, (hexa)decimal numbers and stopwords. return Nullability.TRUE; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index d58d233168e2b..bcce4ea228c22 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1894,38 +1894,35 @@ public void testIntervalAsString() { ); } - public void testCategorizeSingleGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); - - query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)"); - query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)"); + public void testCategorizeOnlyFirstGrouping() { + query("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name)"); + query("FROM test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)"); + query("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no"); + query("FROM test | STATS COUNT(*) BY a = CATEGORIZE(first_name), b = emp_no"); assertEquals( - "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", - error("from test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no") + "1:39: CATEGORIZE grouping function [CATEGORIZE(first_name)] can only be in the first grouping expression", + error("FROM test | STATS COUNT(*) BY emp_no, CATEGORIZE(first_name)") ); assertEquals( - "1:39: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", - error("FROM test | STATS COUNT(*) BY emp_no, CATEGORIZE(first_name)") + "1:55: CATEGORIZE grouping function [CATEGORIZE(last_name)] can only be in the first grouping expression", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(last_name)") ); assertEquals( - "1:35: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", - error("FROM test | STATS COUNT(*) BY a = CATEGORIZE(first_name), b = emp_no") + "1:55: CATEGORIZE grouping function [CATEGORIZE(first_name)] can only be in the first grouping expression", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(first_name)") ); assertEquals( - "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings\n" - + "line 1:55: cannot use CATEGORIZE grouping function [CATEGORIZE(last_name)] with multiple groupings", - error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(last_name)") + "1:63: CATEGORIZE grouping function [CATEGORIZE(last_name)] can only be in the first grouping expression", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no, CATEGORIZE(last_name)") ); assertEquals( - "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings", - error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(first_name)") + "1:63: CATEGORIZE grouping function [CATEGORIZE(first_name)] can only be in the first grouping expression", + error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no, CATEGORIZE(first_name)") ); } public void testCategorizeNestedGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); - query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)"); assertEquals( @@ -1939,8 +1936,6 @@ public void testCategorizeNestedGrouping() { } public void testCategorizeWithinAggregations() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); - query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)"); query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY cat = CATEGORIZE(first_name)"); query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY CATEGORIZE(first_name)"); @@ -1969,8 +1964,6 @@ public void testCategorizeWithinAggregations() { } public void testCategorizeWithFilteredAggregations() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); - query("FROM test | STATS COUNT(*) WHERE first_name == \"John\" BY CATEGORIZE(last_name)"); query("FROM test | STATS COUNT(*) WHERE last_name == \"Doe\" BY CATEGORIZE(last_name)"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 737bb2eb23a6f..5a12d26a19a20 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils; @@ -1212,8 +1211,6 @@ public void testCombineProjectionWithAggregationFirstAndAliasedGroupingUsedInAgg * \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..] */ public void testCombineProjectionWithCategorizeGrouping() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); - var plan = plan(""" from test | eval k = first_name, k1 = k @@ -3949,8 +3946,6 @@ public void testNestedExpressionsInGroups() { * \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..] */ public void testNestedExpressionsInGroupsWithCategorize() { - assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled()); - var plan = optimizedPlan(""" from test | stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))