diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
index f23a4b07d8719..f239a7ce8b139 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
@@ -17,13 +17,13 @@
 import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
 import org.elasticsearch.compute.aggregation.CountDistinctDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.CountDistinctLongAggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MinLongAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
-import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -125,29 +125,32 @@ private static Operator operator(DriverContext driverContext, String grouping, S
                 driverContext
             );
         }
-        List<BlockHash.GroupSpec> groups = switch (grouping) {
-            case LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
-            case INTS -> List.of(new BlockHash.GroupSpec(0, ElementType.INT));
-            case DOUBLES -> List.of(new BlockHash.GroupSpec(0, ElementType.DOUBLE));
-            case BOOLEANS -> List.of(new BlockHash.GroupSpec(0, ElementType.BOOLEAN));
-            case BYTES_REFS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
-            case TWO_LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG), new BlockHash.GroupSpec(1, ElementType.LONG));
+        List<GroupingKey.Factory> groups = switch (grouping) {
+            case LONGS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE));
+            case INTS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.INT).get(AggregatorMode.SINGLE));
+            case DOUBLES -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.DOUBLE).get(AggregatorMode.SINGLE));
+            case BOOLEANS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.BOOLEAN).get(AggregatorMode.SINGLE));
+            case BYTES_REFS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.BYTES_REF).get(AggregatorMode.SINGLE));
+            case TWO_LONGS -> List.of(
+                GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE),
+                GroupingKey.forStatelessGrouping(1, ElementType.LONG).get(AggregatorMode.SINGLE)
+            );
             case LONGS_AND_BYTES_REFS -> List.of(
-                new BlockHash.GroupSpec(0, ElementType.LONG),
-                new BlockHash.GroupSpec(1, ElementType.BYTES_REF)
+                GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE),
+                GroupingKey.forStatelessGrouping(1, ElementType.BYTES_REF).get(AggregatorMode.SINGLE)
             );
             case TWO_LONGS_AND_BYTES_REFS -> List.of(
-                new BlockHash.GroupSpec(0, ElementType.LONG),
-                new BlockHash.GroupSpec(1, ElementType.LONG),
-                new BlockHash.GroupSpec(2, ElementType.BYTES_REF)
+                GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE),
+                GroupingKey.forStatelessGrouping(1, ElementType.LONG).get(AggregatorMode.SINGLE),
+                GroupingKey.forStatelessGrouping(2, ElementType.BYTES_REF).get(AggregatorMode.SINGLE)
            );
             default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
         };
-        return new HashAggregationOperator(
+        return new HashAggregationOperator.HashAggregationOperatorFactory(
+            groups,
             List.of(supplier(op, dataType, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
-            () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
-            driverContext
-        );
+            16 * 1024
+        ).get(driverContext);
     }
 
     private static AggregatorFunctionSupplier supplier(String op, String dataType, int dataChannel) {
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java
new file mode 100644
index 0000000000000..b7bf7615c3835
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java
@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.core.Releasable;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public record GroupingKey(AggregatorMode mode, Thing thing, BlockFactory blockFactory) implements EvalOperator.ExpressionEvaluator {
+    public interface Thing extends Releasable {
+        int extraIntermediateBlocks();
+
+        Block evalRawInput(Page page);
+
+        Block evalIntermediateInput(BlockFactory blockFactory, Page page);
+
+        void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount);
+
+        void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks);
+    }
+
+    public interface Supplier {
+        Factory get(AggregatorMode mode);
+    }
+
+    public interface Factory {
+        GroupingKey apply(DriverContext context, int resultOffset);
+
+        ElementType intermediateElementType();
+
+        GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel);
+    }
+
+    public static GroupingKey.Supplier forStatelessGrouping(int channel, ElementType elementType) {
+        return mode -> new Factory() {
+            @Override
+            public GroupingKey apply(DriverContext context, int resultOffset) {
+                return new GroupingKey(mode, new Load(channel, resultOffset), context.blockFactory());
+            }
+
+            @Override
+            public ElementType intermediateElementType() {
+                return elementType;
+            }
+
+            @Override
+            public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) {
+                if (channel != timeBucketChannel) {
+                    final List<Integer> channels = List.of(channel);
+                    // TODO: perhaps introduce a specialized aggregator for this?
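+                    // Collect this key's values with a VALUES aggregator so later
+                    // stages can rebuild the grouping column for every time series.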
+                    return (switch (intermediateElementType()) {
+                        case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels);
+                        case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels);
+                        case INT -> new ValuesIntAggregatorFunctionSupplier(channels);
+                        case LONG -> new ValuesLongAggregatorFunctionSupplier(channels);
+                        case BOOLEAN -> new ValuesBooleanAggregatorFunctionSupplier(channels);
+                        case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type");
+                    }).groupingAggregatorFactory(AggregatorMode.SINGLE);
+                }
+                return null;
+            }
+        };
+    }
+
+    public static List<BlockHash.GroupSpec> toBlockHashGroupSpec(List<GroupingKey.Factory> keys) {
+        List<BlockHash.GroupSpec> result = new ArrayList<>(keys.size());
+        for (int k = 0; k < keys.size(); k++) {
+            result.add(new BlockHash.GroupSpec(k, keys.get(k).intermediateElementType()));
+        }
+        return result;
+    }
+
+    @Override
+    public Block eval(Page page) {
+        return mode.isInputPartial() ? thing.evalIntermediateInput(blockFactory, page) : thing.evalRawInput(page);
+    }
+
+    public int finishBlockCount() {
+        return mode.isOutputPartial() ? 1 + thing.extraIntermediateBlocks() : 1;
+    }
+
+    public void finish(Block[] blocks, IntVector selected, DriverContext driverContext) {
+        if (mode.isOutputPartial()) {
+            thing.fetchIntermediateState(driverContext.blockFactory(), blocks, selected.getPositionCount());
+        } else {
+            thing.replaceIntermediateKeys(driverContext.blockFactory(), blocks);
+        }
+    }
+
+    public int extraIntermediateBlocks() {
+        return thing.extraIntermediateBlocks();
+    }
+
+    @Override
+    public void close() {
+        thing.close();
+    }
+
+    private record Load(int channel, int resultOffset) implements Thing {
+        @Override
+        public int extraIntermediateBlocks() {
+            return 0;
+        }
+
+        @Override
+        public Block evalRawInput(Page page) {
+            Block b = page.getBlock(channel);
+            b.incRef();
+            return b;
+        }
+
+        @Override
+        public Block evalIntermediateInput(BlockFactory blockFactory, Page page) {
+            Block b = page.getBlock(resultOffset);
+            b.incRef();
+            return b;
+        }
+
+        @Override
+        public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) {}
+
+        @Override
+        public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) {}
+
+        @Override
+        public void close() {}
+    }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
index 03a4ca2b0ad5e..6e6ccf8442a54 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
@@ -16,6 +16,7 @@
 import org.elasticsearch.compute.Describable;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.IntBlock;
@@ -27,18 +28,15 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.function.Supplier;
 
-import static java.util.Objects.requireNonNull;
 import static java.util.stream.Collectors.joining;
 
 public class HashAggregationOperator implements Operator {
-
     public record HashAggregationOperatorFactory(
-        List<BlockHash.GroupSpec> groups,
+        List<GroupingKey.Factory> groups,
         List<GroupingAggregator.Factory> aggregators,
         int maxPageSize
     ) implements OperatorFactory {
@@ -46,7 +44,8 @@ public record HashAggregationOperatorFactory(
         public Operator get(DriverContext driverContext) {
             return new HashAggregationOperator(
                 aggregators,
-                () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
+                groups,
+                () -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groups), driverContext.blockFactory(), maxPageSize, false),
                 driverContext
             );
         }
@@ -61,15 +60,17 @@ public String describe() {
         }
     }
 
-    private boolean finished;
-    private Page output;
+    private final List<GroupingAggregator> aggregators;
 
-    private final BlockHash blockHash;
+    private final List<GroupingKey> groups;
 
-    private final List<GroupingAggregator> aggregators;
+    private final BlockHash blockHash;
 
     private final DriverContext driverContext;
 
+    private boolean finished;
+    private Page output;
+
     /**
      * Nanoseconds this operator has spent hashing grouping keys.
     */
@@ -86,10 +87,12 @@ public String describe() {
     @SuppressWarnings("this-escape")
     public HashAggregationOperator(
         List<GroupingAggregator.Factory> aggregators,
+        List<GroupingKey.Factory> groups,
         Supplier<BlockHash> blockHash,
         DriverContext driverContext
     ) {
         this.aggregators = new ArrayList<>(aggregators.size());
+        this.groups = new ArrayList<>(groups.size());
         this.driverContext = driverContext;
         boolean success = false;
         try {
@@ -97,6 +100,12 @@ public HashAggregationOperator(
             for (GroupingAggregator.Factory a : aggregators) {
                 this.aggregators.add(a.apply(driverContext));
             }
+            int offset = 0;
+            for (GroupingKey.Factory g : groups) {
+                GroupingKey key = g.apply(driverContext, offset);
+                this.groups.add(key);
+                offset += key.extraIntermediateBlocks() + 1;
+            }
             success = true;
         } finally {
             if (success == false) {
@@ -112,6 +121,8 @@ public boolean needsInput() {
 
     @Override
     public void addInput(Page page) {
+        checkState(needsInput(), "Operator is already finishing");
+
         try {
             GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()];
             class AddInput implements GroupingAggregatorFunction.AddInput {
@@ -156,16 +167,21 @@ public void close() {
                     Releasables.closeExpectNoException(prepared);
                 }
             }
+            Block[] keys = new Block[groups.size()];
+            page = wrapPage(page);
             try (AddInput add = new AddInput()) {
-                checkState(needsInput(), "Operator is already finishing");
-                requireNonNull(page, "page is null");
+                for (int g = 0; g < groups.size(); g++) {
+                    keys[g] = groups.get(g).eval(page);
+                }
                 for (int i = 0; i < prepared.length; i++) {
                     prepared[i] = aggregators.get(i).prepareProcessPage(blockHash, page);
                 }
-                blockHash.add(wrapPage(page), add);
+                blockHash.add(new Page(keys), add);
                 hashNanos += System.nanoTime() - add.hashStart;
+            } finally {
+                Releasables.close(keys);
             }
         } finally {
             page.releaseBlocks();
@@ -192,15 +208,29 @@ public void finish() {
         try {
             selected = blockHash.nonEmpty();
             Block[] keys = blockHash.getKeys();
-            int[] aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray();
-            blocks = new Block[keys.length + Arrays.stream(aggBlockCounts).sum()];
-            System.arraycopy(keys, 0, blocks, 0, keys.length);
-            int offset = keys.length;
-            for (int i = 0; i < aggregators.size(); i++) {
-                var aggregator = aggregators.get(i);
-                aggregator.evaluate(blocks, offset, selected, driverContext);
-                offset += aggBlockCounts[i];
+
+            int blockCount = 0;
+            for (int g = 0; g < groups.size(); g++) {
+                blockCount += groups.get(g).finishBlockCount();
+            }
+            int[] aggBlockCounts = new int[aggregators.size()];
+            for (int a = 0; a < aggregators.size(); a++) {
+                aggBlockCounts[a] = aggregators.get(a).evaluateBlockCount();
+                blockCount += aggBlockCounts[a];
             }
+
+            blocks = new Block[blockCount];
+            int offset = 0;
+            for (int g = 0; g < groups.size(); g++) {
+                blocks[offset] = keys[g];
+                groups.get(g).finish(blocks, selected, driverContext);
+                offset += groups.get(g).finishBlockCount();
+            }
+            for (int a = 0; a < aggregators.size(); a++) {
+                aggregators.get(a).evaluate(blocks, offset, selected, driverContext);
+                offset += aggBlockCounts[a];
+            }
+
             output = new Page(blocks);
             success = true;
         } finally {
@@ -224,7 +254,7 @@ public void close() {
         if (output != null) {
             output.releaseBlocks();
         }
-        Releasables.close(blockHash, () -> Releasables.close(aggregators));
+        Releasables.close(blockHash, Releasables.wrap(aggregators), Releasables.wrap(groups));
     }
 
     @Override
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java
index b5ae35bfc8d7f..52fa93cc43927 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java
@@ -17,9 +17,11 @@
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.compute.Describable;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
 import org.elasticsearch.compute.aggregation.GroupingAggregator.Factory;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash.GroupSpec;
@@ -499,12 +501,11 @@ private static class ValuesAggregator implements Releasable {
             );
             this.aggregator = new HashAggregationOperator(
                 aggregatorFactories,
-                () -> BlockHash.build(
-                    List.of(new GroupSpec(channelIndex, groupingElementType)),
-                    driverContext.blockFactory(),
-                    maxPageSize,
-                    false
+                List.of(
+                    // NOCOMMIT double check the mode
+                    GroupingKey.forStatelessGrouping(channelIndex, groupingElementType).get(AggregatorMode.INITIAL)
                 ),
+                () -> BlockHash.build(List.of(new GroupSpec(0, groupingElementType)), driverContext.blockFactory(), maxPageSize, false),
                 driverContext
             );
         }
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java
index 1e9ea88b2f1d7..3c8f35d4dad96 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java
@@ -10,6 +10,7 @@
 import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.aggregation.blockhash.TimeSeriesBlockHash;
 import org.elasticsearch.compute.data.ElementType;
@@ -44,7 +45,7 @@ public final class TimeSeriesAggregationOperatorFactories {
     public record Initial(
         int tsHashChannel,
         int timeBucketChannel,
-        List<BlockHash.GroupSpec> groupings,
+        List<GroupingKey.Supplier> groupings,
         List<AggregatorFunctionSupplier> rates,
         List<AggregatorFunctionSupplier> nonRates,
         int maxPageSize
@@ -58,9 +59,11 @@ public Operator get(DriverContext driverContext) {
             for (AggregatorFunctionSupplier f : nonRates) {
                 aggregators.add(f.groupingAggregatorFactory(AggregatorMode.INITIAL));
             }
+            List<GroupingKey.Factory> groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INITIAL)).toList();
             aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel));
             return new HashAggregationOperator(
                 aggregators,
+                groupings,
                 () -> new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext),
                 driverContext
             );
@@ -75,7 +78,7 @@ public String describe() {
     public record Intermediate(
         int tsHashChannel,
         int timeBucketChannel,
-        List<BlockHash.GroupSpec> groupings,
+        List<GroupingKey.Supplier> groupings,
         List<AggregatorFunctionSupplier> rates,
         List<AggregatorFunctionSupplier> nonRates,
         int maxPageSize
@@ -89,6 +92,7 @@ public Operator get(DriverContext driverContext) {
             for (AggregatorFunctionSupplier f : nonRates) {
                 aggregators.add(f.groupingAggregatorFactory(AggregatorMode.INTERMEDIATE));
             }
+            List<GroupingKey.Factory> groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INTERMEDIATE)).toList();
             aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel));
             List<BlockHash.GroupSpec> hashGroups = List.of(
                 new BlockHash.GroupSpec(tsHashChannel, ElementType.BYTES_REF),
@@ -96,6 +100,7 @@
             );
             return new HashAggregationOperator(
                 aggregators,
+                groupings,
                 () -> BlockHash.build(hashGroups, driverContext.blockFactory(), maxPageSize, false),
                 driverContext
             );
@@ -108,7 +113,7 @@ public String describe() {
         }
     }
 
     public record Final(
-        List<BlockHash.GroupSpec> groupings,
+        List<GroupingKey.Supplier> groupings,
         List<AggregatorFunctionSupplier> outerRates,
         List<AggregatorFunctionSupplier> nonRates,
         int maxPageSize
@@ -122,9 +127,11 @@ public Operator get(DriverContext driverContext) {
             for (AggregatorFunctionSupplier f : nonRates) {
                 aggregators.add(f.groupingAggregatorFactory(AggregatorMode.FINAL));
             }
+            List<GroupingKey.Factory> groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.FINAL)).toList();
             return new HashAggregationOperator(
                 aggregators,
-                () -> BlockHash.build(groupings, driverContext.blockFactory(), maxPageSize, false),
+                groupings,
+                () -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groupings), driverContext.blockFactory(), maxPageSize, false),
                 driverContext
             );
         }
@@ -135,21 +142,12 @@ public String describe() {
         }
     }
 
-    static List<GroupingAggregator.Factory> valuesAggregatorForGroupings(List<BlockHash.GroupSpec> groupings, int timeBucketChannel) {
+    static List<GroupingAggregator.Factory> valuesAggregatorForGroupings(List<GroupingKey.Factory> groupings, int timeBucketChannel) {
         List<GroupingAggregator.Factory> aggregators = new ArrayList<>();
-        for (BlockHash.GroupSpec g : groupings) {
-            if (g.channel() != timeBucketChannel) {
-                final List<Integer> channels = List.of(g.channel());
-                // TODO: perhaps introduce a specialized aggregator for this?
-                var aggregatorSupplier = (switch (g.elementType()) {
-                    case BYTES_REF -> new org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier(channels);
-                    case DOUBLE -> new org.elasticsearch.compute.aggregation.ValuesDoubleAggregatorFunctionSupplier(channels);
-                    case INT -> new org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier(channels);
-                    case LONG -> new org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier(channels);
-                    case BOOLEAN -> new org.elasticsearch.compute.aggregation.ValuesBooleanAggregatorFunctionSupplier(channels);
-                    case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type");
-                });
-                aggregators.add(aggregatorSupplier.groupingAggregatorFactory(AggregatorMode.SINGLE));
+        for (GroupingKey.Factory g : groupings) {
+            GroupingAggregator.Factory factory = g.valuesAggregatorForGroupingsInTimeSeries(timeBucketChannel);
+            if (factory != null) {
+                aggregators.add(factory);
             }
         }
         return aggregators;
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
index 8b69b5584e65d..a1f0eb77ef2d4 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
@@ -33,15 +33,19 @@
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
-import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
+import org.elasticsearch.compute.aggregation.GroupingAggregator;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BlockTestUtils;
 import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.Page;
@@ -61,8 +65,12 @@
 import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
 import org.elasticsearch.compute.operator.PageConsumerOperator;
 import org.elasticsearch.compute.operator.RowInTableLookupOperator;
+import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator;
 import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator;
 import org.elasticsearch.compute.operator.ShuffleDocsOperator;
+import org.elasticsearch.compute.operator.TestResultPageSinkOperator;
+import org.elasticsearch.compute.operator.topn.TopNEncoder;
+import org.elasticsearch.compute.operator.topn.TopNOperator;
 import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
@@ -81,6 +89,7 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL;
 import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL;
@@ -88,6 +97,7 @@
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 
 // TODO: Move these tests to the right test classes.
 public class OperatorTests extends MapperServiceTestCase {
@@ -194,16 +204,11 @@ public String toString() {
             )
         );
         operators.add(
-            new HashAggregationOperator(
+            new HashAggregationOperator.HashAggregationOperatorFactory(
+                List.of(GroupingKey.forStatelessGrouping(0, ElementType.BYTES_REF).get(FINAL)),
                 List.of(CountAggregatorFunction.supplier(List.of(1, 2)).groupingAggregatorFactory(FINAL)),
-                () -> BlockHash.build(
-                    List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF)),
-                    driverContext.blockFactory(),
-                    randomPageSize(),
-                    false
-                ),
-                driverContext
-            )
+                randomPageSize()
+            ).get(driverContext)
         );
         Driver driver = new Driver(
@@ -389,4 +394,139 @@ static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query qu
             limit
         );
     }
+
+    public void testStatefulGrouping() {
+        DriverContext driverContext = driverContext();
+        Stream<BytesRef> input = Stream.of(
+            new BytesRef("abc"),
+            new BytesRef("def"),
+            new BytesRef("abc"),
+            new BytesRef("abc"),
+            new BytesRef("abc"),
+            new BytesRef("abc"),
+            new BytesRef("blah")
+        );
+        List<Page> output = new ArrayList<>();
+        List<Operator> operators = new ArrayList<>();
+
+        operators.add(
+            new HashAggregationOperator.HashAggregationOperatorFactory(
+                List.of(new ExampleStatefulGroupingFunction.Factory(INITIAL, 0)),
+                List.of(),
+                16 * 1024
+            ).get(driverContext)
+        );
+        operators.add(
+            new HashAggregationOperator.HashAggregationOperatorFactory(
+                List.of(new ExampleStatefulGroupingFunction.Factory(FINAL, 0)),
+                List.of(),
+                16 * 1024
+            ).get(driverContext)
+        );
+        operators.add(
+            new TopNOperator(
+                driverContext.blockFactory(),
+                driverContext.breaker(),
+                3,
+                List.of(ElementType.BYTES_REF),
+                List.of(TopNEncoder.UTF8),
+                List.of(new TopNOperator.SortOrder(0, true, true)),
+                16 * 1024
+            )
+        );
+
+        Driver driver = new Driver(
+            driverContext,
+            new SequenceBytesRefBlockSourceOperator(driverContext.blockFactory(), input),
+            operators,
+            new TestResultPageSinkOperator(output::add),
+            () -> {}
+        );
+        OperatorTestCase.runDriver(driver);
+
+        assertThat(output, hasSize(1));
+        assertThat(output.get(0).getBlockCount(), equalTo(1));
+        BytesRefBlock block = output.get(0).getBlock(0);
+        BytesRefVector vector = block.asVector();
+        List<String> values = new ArrayList<>();
+        for (int p = 0; p < vector.getPositionCount(); p++) {
+            values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString());
+        }
+        assertThat(values, equalTo(List.of("7abc", "7blah", "7def")));
+    }
+
+    static class ExampleStatefulGroupingFunction implements GroupingKey.Thing {
+        record Factory(AggregatorMode mode, int inputChannel) implements GroupingKey.Factory {
+            @Override
+            public GroupingKey apply(DriverContext context, int resultOffset) {
+                return new GroupingKey(mode, new ExampleStatefulGroupingFunction(inputChannel, resultOffset), context.blockFactory());
+            }
+
+            @Override
+            public ElementType intermediateElementType() {
+                return ElementType.BYTES_REF;
+            }
+
+            @Override
+            public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) {
+                throw new UnsupportedOperationException();
+            }
+        }
+
+        private final int inputChannel;
+        private final int resultOffset;
+
+        int count;
+
+        ExampleStatefulGroupingFunction(int inputChannel, int resultOffset) {
+            this.inputChannel = inputChannel;
+            this.resultOffset = resultOffset;
+        }
+
+        @Override
+        public int extraIntermediateBlocks() {
+            return 1;
+        }
+
+        @Override
+        public Block evalRawInput(Page page) {
+            count += page.getPositionCount();
+            Block block = page.getBlock(inputChannel);
+            block.incRef();
+            return block;
+        }
+
+        @Override
+        public Block evalIntermediateInput(BlockFactory blockFactory, Page page) {
+            IntBlock block = page.getBlock(resultOffset + 1);
+            IntVector vector = block.asVector();
+            assertThat(vector.isConstant(), equalTo(true));
+            count = vector.getInt(0);
+            Block b = page.getBlock(resultOffset);
+            b.incRef();
+            return b;
+        }
+
+        @Override
+        public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) {
+            blocks[resultOffset + 1] = blockFactory.newConstantIntBlockWith(count, positionCount);
+        }
+
+        @Override
+        public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) {
+            try (
+                BytesRefBlock block = (BytesRefBlock) blocks[resultOffset];
+                BytesRefVector.Builder replacement = blockFactory.newBytesRefVectorBuilder(block.getPositionCount())
+            ) {
+                BytesRefVector vector = block.asVector();
+                for (int p = 0; p < vector.getPositionCount(); p++) {
+                    replacement.appendBytesRef(new BytesRef(count + vector.getBytesRef(p, new BytesRef()).utf8ToString()));
+                }
+                blocks[resultOffset] = replacement.build().asBlock();
+            }
+        }
+
+        @Override
+        public void close() {}
+    }
 }
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
index de9337f5fce2c..749051b0b3637 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
@@ -10,7 +10,6 @@
 import org.apache.lucene.document.InetAddressPoint;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.util.BitArray;
-import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BlockTestUtils;
@@ -90,7 +89,7 @@ protected final Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
             supplier = chunkGroups(emitChunkSize, supplier);
         }
         return new HashAggregationOperator.HashAggregationOperatorFactory(
-            List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+            List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(mode)),
             List.of(supplier.groupingAggregatorFactory(mode)),
             randomPageSize()
         );
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
index f2fa94c1feb08..85eaf4b780470 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
@@ -8,13 +8,13 @@
 package org.elasticsearch.compute.operator;
 
 import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunction;
 import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.MaxLongGroupingAggregatorFunctionTests;
 import org.elasticsearch.compute.aggregation.SumLongAggregatorFunction;
 import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumLongGroupingAggregatorFunctionTests;
-import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.ElementType;
@@ -53,7 +53,7 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) {
         }
 
         return new HashAggregationOperator.HashAggregationOperatorFactory(
-            List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+            List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(mode)),
             List.of(
                 new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
                 new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java
index 75e71ff697efb..733b77954d42c 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java
@@ -8,6 +8,7 @@
 package org.elasticsearch.compute.operator;
 
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.Page;
 
@@ -15,8 +16,9 @@
 import java.util.stream.Stream;
 
 /**
- * A source operator whose output is the given double values. This operator produces pages
- * containing a single Block. The Block contains the double values from the given list, in order.
+ * A source operator whose output is the given {@link BytesRef} values.
+ * This operator produces pages containing a single {@link Block}. The Block
+ * contains the {@link BytesRef} values from the given list, in order.
  */
 public class SequenceBytesRefBlockSourceOperator extends AbstractBlockSourceOperator {
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java
index da1a9c9408f90..c40712b59c407 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java
@@ -13,9 +13,9 @@
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.Rounding;
 import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.RateLongAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier;
-import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BlockUtils;
@@ -265,7 +265,9 @@ public void close() {
         Operator intialAgg = new TimeSeriesAggregationOperatorFactories.Initial(
             1,
             3,
-            IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(),
+            IntStream.range(0, nonBucketGroupings.size())
+                .mapToObj(n -> GroupingKey.forStatelessGrouping(5 + n, ElementType.BYTES_REF))
+                .toList(),
             List.of(new RateLongAggregatorFunctionSupplier(List.of(4, 2), unitInMillis)),
             List.of(),
             between(1, 100)
@@ -275,19 +277,21 @@ public void close() {
         Operator intermediateAgg = new TimeSeriesAggregationOperatorFactories.Intermediate(
             0,
             1,
-            IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(),
+            IntStream.range(0, nonBucketGroupings.size())
+                .mapToObj(n -> GroupingKey.forStatelessGrouping(5 + n, ElementType.BYTES_REF))
+                .toList(),
             List.of(new RateLongAggregatorFunctionSupplier(List.of(2, 3, 4), unitInMillis)),
             List.of(),
             between(1, 100)
         ).get(ctx);
         // tsid, bucket, rate, grouping1, grouping2
-        List<BlockHash.GroupSpec> finalGroups = new ArrayList<>();
+        List<GroupingKey.Supplier> finalGroups = new ArrayList<>();
         int groupChannel = 3;
         for (String grouping : groupings) {
             if (grouping.equals("bucket")) {
-                finalGroups.add(new BlockHash.GroupSpec(1, ElementType.LONG));
+                finalGroups.add(GroupingKey.forStatelessGrouping(1, ElementType.LONG));
             } else {
-                finalGroups.add(new BlockHash.GroupSpec(groupChannel++, ElementType.BYTES_REF));
+                finalGroups.add(GroupingKey.forStatelessGrouping(groupChannel++, ElementType.BYTES_REF));
             }
         }
         Operator finalAgg = new TimeSeriesAggregationOperatorFactories.Final(
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
index 82c836a6f9d49..77d8ef1780777 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
@@ -10,13 +10,29 @@
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingAggregator;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.ann.Evaluator;
 import org.elasticsearch.compute.ann.Fixed;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.analysis.CharFilterFactory;
 import org.elasticsearch.index.analysis.CustomAnalyzer;
 import org.elasticsearch.index.analysis.TokenFilterFactory;
@@ -31,11 +47,15 @@
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
 import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Function;
 
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
@@ -135,7 +155,7 @@ protected TypeResolution resolveType() {
 
     @Override
     public DataType dataType() {
-        return DataType.INTEGER;
+        return DataType.KEYWORD;
     }
 
     @Override
@@ -156,4 +176,175 @@ public Expression field() {
     public String toString() {
         return "Categorize{field=" + field + "}";
     }
+
+    public GroupingKey.Supplier groupingKey(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
+        return mode -> new GroupingKeyFactory(source(), toEvaluator.apply(field), mode);
+    }
+
+    record GroupingKeyFactory(Source source, ExpressionEvaluator.Factory field, AggregatorMode mode) implements GroupingKey.Factory {
+        @Override
+        public GroupingKey apply(DriverContext context, int resultOffset) {
+            ExpressionEvaluator field = this.field.get(context);
+            CategorizeEvaluator evaluator = null;
+            TokenListCategorizer.CloseableTokenListCategorizer categorizer = null;
+            try {
+                categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
+                    new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
+                    CategorizationPartOfSpeechDictionary.getInstance(),
+                    0.70f
+                );
+                evaluator = new CategorizeEvaluator(
+                    source,
+                    field,
+                    new CategorizationAnalyzer(
+                        // TODO(jan): get the correct analyzer in here, see
+                        // CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
+                        new CustomAnalyzer(
+                            TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+                            new CharFilterFactory[0],
+                            new TokenFilterFactory[0]
+                        ),
+                        true
+                    ),
+                    categorizer,
+                    context
+                );
+                field = null;
+                GroupingKey result = new GroupingKey(
+                    mode,
+                    new GroupingKeyThing(resultOffset, evaluator, categorizer),
+                    context.blockFactory()
+                );
+                categorizer = null;
+                evaluator = null;
+                return result;
+            } finally {
+                Releasables.close(field, evaluator, categorizer);
+            }
+        }
+
+        @Override
+        public ElementType intermediateElementType() {
+            return ElementType.INT;
+        }
+
+        @Override
+        public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) {
+            throw new UnsupportedOperationException("not supported in time series");
+        }
+    }
+
+    private static class GroupingKeyThing implements GroupingKey.Thing {
+        private final int resultOffset;
+        private final CategorizeEvaluator evaluator;
+        private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
+
+        private GroupingKeyThing(
+            int resultOffset,
+            CategorizeEvaluator evaluator,
+            TokenListCategorizer.CloseableTokenListCategorizer categorizer
+        ) {
+            this.resultOffset = resultOffset;
+            this.evaluator = evaluator;
+            this.categorizer = categorizer;
+        }
+
+        @Override
+        public Block evalRawInput(Page page) {
+            return evaluator.eval(page);
+        }
+
+        @Override
+        public int extraIntermediateBlocks() {
+            return 1;
+        }
+
+        @Override
+        public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) {
+            blocks[resultOffset + 1] = buildIntermediateBlock(blockFactory, positionCount);
+        }
+
+        @Override
+        public Block evalIntermediateInput(BlockFactory blockFactory, Page page) {
+            BytesRefBlock intermediate = page.getBlock(resultOffset + 1);
+            Map<Integer, Integer> idMap;
+            if (intermediate.areAllValuesNull() == false) {
+                idMap = readIntermediate(intermediate.getBytesRef(0, new BytesRef()));
+            } else {
+                idMap = Collections.emptyMap();
+            }
+            IntBlock oldIds = page.getBlock(resultOffset);
+            try (IntBlock.Builder newIds = blockFactory.newIntBlockBuilder(oldIds.getTotalValueCount())) {
+                for (int i = 0; i < oldIds.getTotalValueCount(); i++) {
+                    newIds.appendInt(idMap.get(oldIds.getInt(i)));
+                }
+                return newIds.build();
+            }
+        }
+
+        @Override
+        public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) {
+            // NOCOMMIT this offset can't be the same in the result array and intermediate array
+            IntBlock keys = (IntBlock) blocks[resultOffset];
+            blocks[resultOffset] = finalKeys(blockFactory, keys);
+            System.err.println(blocks[resultOffset]);
+        }
+
+        @Override
+        public void close() {
+            evaluator.close();
+        }
+
+        private Block buildIntermediateBlock(BlockFactory blockFactory, int positionCount) {
+            if (categorizer.getCategoryCount() == 0) {
+                return blockFactory.newConstantNullBlock(positionCount);
+            }
+            try (BytesStreamOutput out = new BytesStreamOutput()) {
+                // TODO be more careful here.
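+                // Wire format: a vInt category count followed by each serialized
+                // SerializableTokenListCategory, broadcast as one constant block.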
+                out.writeVInt(categorizer.getCategoryCount());
+                for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
+                    category.writeTo(out);
+                }
+                return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+            Map<Integer, Integer> idMap = new HashMap<>();
+            try (StreamInput in = new BytesArray(bytes).streamInput()) {
+                int count = in.readVInt();
+                for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
+                    SerializableTokenListCategory category = new SerializableTokenListCategory(in);
+                    int newCategoryId = categorizer.mergeWireCategory(category).getId();
+                    System.err.println("category id map: " + oldCategoryId + " -> " + newCategoryId + " (" + category.getRegex() + ")");
+                    idMap.put(oldCategoryId, newCategoryId);
+                }
+                return idMap;
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        private OrdinalBytesRefBlock finalKeys(BlockFactory blockFactory, IntBlock keys) {
+            return new OrdinalBytesRefBlock(keys, finalBytes(blockFactory));
+        }
+
+        /**
+         * A lookup table containing the category names.
+         */
+        private BytesRefVector finalBytes(BlockFactory blockFactory) {
+            try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
+                BytesRefBuilder scratch = new BytesRefBuilder();
+                for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
+                    scratch.copyChars(category.getRegex());
+                    result.appendBytesRef(scratch.get());
+                    scratch.clear();
+                }
+                return result.build();
+            }
+        }
+    }
 }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index 0e71963e29270..63e6c6a410e5b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -11,7 +11,7 @@
 import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
-import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.operator.AggregationOperator;
 import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
@@ -159,8 +159,12 @@ else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == Aggregato
                 context
             );
         } else {
+            List<GroupingKey.Factory> groupings = new ArrayList<>(groupSpecs.size());
+            for (GroupSpec group : groupSpecs) {
+                groupings.add(group.toGroupingKey().get(aggregatorMode));
+            }
             operatorFactory = new HashAggregationOperatorFactory(
-                groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
+                groupings,
                 aggregatorFactories,
                 context.pageSize(aggregateExec.estimatedRowSize())
             );
@@ -285,11 +289,11 @@ private void aggregatesToFactory(
     }
 
     private record GroupSpec(Integer channel, Attribute attribute) {
-        BlockHash.GroupSpec toHashGroupSpec() {
+        GroupingKey.Supplier toGroupingKey() {
             if (channel == null) {
                 throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
             }
-            return new BlockHash.GroupSpec(channel, elementType());
+            return GroupingKey.forStatelessGrouping(channel, elementType());
         }
 
         ElementType elementType() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java
new file mode 100644
index 0000000000000..1eabbe188e660
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java
@@ -0,0 +1,314 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.grouping;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.compute.aggregation.GroupingKey;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.DriverRunner;
+import org.elasticsearch.compute.operator.HashAggregationOperator;
+import org.elasticsearch.compute.operator.LocalSourceOperator;
+import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+import org.elasticsearch.compute.operator.topn.TopNEncoder;
+import org.elasticsearch.compute.operator.topn.TopNOperator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.FixedExecutorBuilder;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
+import org.junit.After;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL;
+import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class CategorizeOperatorTests extends ESTestCase {
+    public void testCategorization() {
+        DriverContext driverContext = driverContext();
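+        // Six raw messages that categorization should collapse into the three
+        // category patterns asserted below: "hello", "goodbye", and "goodbye blue sky".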
+        LocalSourceOperator.BlockSupplier input = () -> {
+            try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
+                builder.appendBytesRef(new BytesRef("words words words hello nik"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye nik"));
+                builder.appendBytesRef(new BytesRef("words words words hello jan"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye jan"));
+                builder.appendBytesRef(new BytesRef("words words words hello kitty"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye blue sky"));
+                return new Block[] { builder.build().asBlock() };
+            }
+        };
+        List<Page> output = new ArrayList<>();
+        try {
+            List<Operator> operators = new ArrayList<>();
+
+            Categorize cat = new Categorize(Source.EMPTY, AbstractFunctionTestCase.field("f", DataType.KEYWORD));
+            GroupingKey.Supplier key = cat.groupingKey(AbstractFunctionTestCase::evaluator);
+
+            operators.add(
+                new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(INITIAL)), List.of(), 16 * 1024).get(
+                    driverContext
+                )
+            );
+            operators.add(
+                new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(FINAL)), List.of(), 16 * 1024).get(driverContext)
+            );
+            operators.add(
+                new TopNOperator(
+                    driverContext.blockFactory(),
+                    driverContext.breaker(),
+                    3,
+                    List.of(ElementType.BYTES_REF),
+                    List.of(TopNEncoder.UTF8),
+                    List.of(new TopNOperator.SortOrder(0, true, true)),
+                    16 * 1024
+                )
+            );
+
+            Driver driver = new Driver(
+                driverContext,
+                new LocalSourceOperator(input),
+                operators,
+                new PageConsumerOperator(output::add),
+                () -> {}
+            );
+            runDriver(driver);
+
+            assertThat(output, hasSize(1));
+            assertThat(output.get(0).getBlockCount(), equalTo(1));
+            BytesRefBlock block = output.get(0).getBlock(0);
+            BytesRefVector vector = block.asVector();
+            List<String> values = new ArrayList<>();
+            for (int p = 0; p < vector.getPositionCount(); p++) {
+                values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString());
+            }
+            assertThat(
+                values,
+                equalTo(
+                    List.of(
+                        ".*?words.+?words.+?words.+?goodbye.*?",
+                        ".*?words.+?words.+?words.+?goodbye.+?blue.+?sky.*?",
+                        ".*?words.+?words.+?words.+?hello.*?"
+                    )
+                )
+            );
+        } finally {
+            Releasables.close(() -> Iterators.map(output.iterator(), (Page p) -> p::releaseBlocks));
+        }
+    }
+
+    /**
+     * {@link SourceOperator} that returns a sequence of pre-built {@link Page}s.
+     * TODO: this class is copy-pasted from the esql:compute plugin; fix that
+     */
+    public static class CannedSourceOperator extends SourceOperator {
+
+        private final Iterator<Page> page;
+
+        public CannedSourceOperator(Iterator<Page> page) {
+            this.page = page;
+        }
+
+        @Override
+        public void finish() {
+            while (page.hasNext()) {
+                page.next();
+            }
+        }
+
+        @Override
+        public boolean isFinished() {
+            return false == page.hasNext();
+        }
+
+        @Override
+        public Page getOutput() {
+            return page.next();
+        }
+
+        @Override
+        public void close() {
+            // release pages in the case of early termination - failure
+            while (page.hasNext()) {
+                page.next().releaseBlocks();
+            }
+        }
+    }
+
+    public void testCategorization_multipleNodes() {
+        DriverContext driverContext = driverContext();
+        LocalSourceOperator.BlockSupplier input1 = () -> {
+            try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
+                builder.appendBytesRef(new BytesRef("a"));
+                builder.appendBytesRef(new BytesRef("b"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye jan"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye nik"));
+                builder.appendBytesRef(new BytesRef("words words words hello jan"));
+                builder.appendBytesRef(new BytesRef("c"));
+                return new Block[] { builder.build().asBlock() };
+            }
+        };
+
+        LocalSourceOperator.BlockSupplier input2 = () -> {
+            try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
+                builder.appendBytesRef(new BytesRef("words words words hello nik"));
+                builder.appendBytesRef(new BytesRef("c"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye chris"));
+                builder.appendBytesRef(new BytesRef("d"));
+                builder.appendBytesRef(new BytesRef("e"));
+                return new Block[] { builder.build().asBlock() };
+            }
+        };
+
+        List<Page> intermediateOutput = new ArrayList<>();
+        List<Page> finalOutput = new ArrayList<>();
+
+        try {
+            Categorize cat = new Categorize(Source.EMPTY, AbstractFunctionTestCase.field("f", DataType.KEYWORD));
+            GroupingKey.Supplier key = cat.groupingKey(AbstractFunctionTestCase::evaluator);
+
+            Driver driver = new Driver(
+                driverContext,
+                new LocalSourceOperator(input1),
+                List.of(
+                    new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(INITIAL)), List.of(), 16 * 1024).get(
+                        driverContext
+                    )
+                ),
+                new PageConsumerOperator(intermediateOutput::add),
+                () -> {}
+            );
+            runDriver(driver);
+
+            driver = new Driver(
+                driverContext,
+                new LocalSourceOperator(input2),
+                List.of(
+                    new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(INITIAL)), List.of(), 16 * 1024).get(
+                        driverContext
+                    )
+                ),
+                new PageConsumerOperator(intermediateOutput::add),
+                () -> {}
+            );
+            runDriver(driver);
+
+            assertThat(intermediateOutput, hasSize(2));
+
+            driver = new Driver(
+                driverContext,
+                new CannedSourceOperator(intermediateOutput.iterator()),
+                List.of(
+                    new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(FINAL)), List.of(), 16 * 1024).get(
+                        driverContext
+                    ),
+                    new TopNOperator(
+                        driverContext.blockFactory(),
+                        driverContext.breaker(),
+                        10,
+                        List.of(ElementType.BYTES_REF),
+                        List.of(TopNEncoder.UTF8),
+                        List.of(new TopNOperator.SortOrder(0, true, true)),
+                        16 * 1024
+                    )
+                ),
+                new PageConsumerOperator(finalOutput::add),
+                () -> {}
+            );
+            runDriver(driver);
+
+            assertThat(finalOutput, hasSize(1));
+            assertThat(finalOutput.get(0).getBlockCount(), equalTo(1));
+            BytesRefBlock block = finalOutput.get(0).getBlock(0);
+            BytesRefVector vector = block.asVector();
+            List<String> values = new ArrayList<>();
+            for (int p = 0; p < vector.getPositionCount(); p++) {
+                values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString());
+            }
+            assertThat(
+                values,
+                equalTo(
+                    List.of(
+                        ".*?a.*?",
+                        ".*?b.*?",
+                        ".*?c.*?",
+                        ".*?d.*?",
+                        ".*?e.*?",
+                        ".*?words.+?words.+?words.+?goodbye.*?",
+                        ".*?words.+?words.+?words.+?hello.*?"
+                    )
+                )
+            );
+        } finally {
+            Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
+        }
+    }
+
+    private final List<CircuitBreaker> breakers = Collections.synchronizedList(new ArrayList<>());
+
+    private DriverContext driverContext() {
+        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
+        CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
+        breakers.add(breaker);
+        return new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
+    }
+
+    @After
+    public void allMemoryReleased() {
+        for (CircuitBreaker breaker : breakers) {
+            assertThat(breaker.getUsed(), equalTo(0L));
+        }
+    }
+
+    public static void runDriver(Driver driver) {
+        ThreadPool threadPool = new TestThreadPool(
+            getTestClass().getSimpleName(),
+            new FixedExecutorBuilder(Settings.EMPTY, "esql", 1, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT)
+        );
+        var driverRunner = new DriverRunner(threadPool.getThreadContext()) {
+            @Override
+            protected void start(Driver driver, ActionListener<Void> driverListener) {
+                Driver.start(threadPool.getThreadContext(), threadPool.executor("esql"), driver, between(1, 10000), driverListener);
+            }
+        };
+        PlainActionFuture<Void> future = new PlainActionFuture<>();
+        try {
+            driverRunner.runToCompletion(List.of(driver), future);
+            future.actionGet(TimeValue.timeValueSeconds(30));
+        } finally {
+            terminate(threadPool);
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java
index c811643c8daea..4e8cefa81f127 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java
@@ -11,7 +11,9 @@
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.compute.Describable;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
+import org.elasticsearch.compute.aggregation.GroupingKey;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
@@ -230,11 +232,12 @@ private class TestHashAggregationOperator extends HashAggregationOperator {
 
         TestHashAggregationOperator(
             List<GroupingAggregator.Factory> aggregators,
+            List<GroupingKey.Factory> groups,
             Supplier<BlockHash> blockHash,
             String columnName,
             DriverContext driverContext
         ) {
-            super(aggregators, blockHash, driverContext);
+            super(aggregators, groups, blockHash, driverContext);
             this.columnName = columnName;
         }
 
@@ -273,14 +276,13 @@ private class TestOrdinalsGroupingAggregationOperatorFactory implements Operator
         public Operator get(DriverContext driverContext) {
             Random random = Randomness.get();
             int pageSize = random.nextBoolean() ? randomIntBetween(random, 1, 16) : randomIntBetween(random, 1, 10 * 1024);
+            List<GroupingKey.Factory> groupings = List.of(
+                GroupingKey.forStatelessGrouping(groupByChannel, groupElementType).get(AggregatorMode.INITIAL)
+            );
             return new TestHashAggregationOperator(
                 aggregators,
-                () -> BlockHash.build(
-                    List.of(new BlockHash.GroupSpec(groupByChannel, groupElementType)),
-                    driverContext.blockFactory(),
-                    pageSize,
-                    false
-                ),
+                groupings,
+                () -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groupings), driverContext.blockFactory(), pageSize, false),
                 columnName,
                 driverContext
             );
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java
index d0088edcb0805..9295bff840c40 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java
@@ -207,6 +207,7 @@ public TokenListCategory mergeWireCategory(SerializableTokenListCategory seriali
     }
 
     private synchronized TokenListCategory computeCategory(
+        // NOCOMMIT why synchronized? in ESQL at least there are no threads around
        List<TokenAndWeight> weightedTokenIds,
        List<TokenAndWeight> workTokenUniqueIds,
         int workWeight,
@@ -424,6 +425,13 @@ public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
             .toArray(InternalCategorizationAggregation.Bucket[]::new);
     }
 
+    public List<SerializableTokenListCategory> toCategories(int size) {
+        return categoriesByNumMatches.stream()
+            .limit(size)
+            .map(category -> new SerializableTokenListCategory(category, bytesRefHash))
+            .toList();
+    }
+
     public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(
         int size,
         long minNumMatches,