From bc389e42dd05681ac0d11d4ea3f92ae01b47003d Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 11 Sep 2024 11:26:59 -0400 Subject: [PATCH 01/10] Start --- .../compute/operator/AggregatorBenchmark.java | 37 ++--- .../compute/aggregation/GroupingKey.java | 101 +++++++++++++ .../operator/HashAggregationOperator.java | 70 ++++++--- .../operator/OrdinalsGroupingOperator.java | 6 + ...imeSeriesAggregationOperatorFactories.java | 19 ++- .../elasticsearch/compute/OperatorTests.java | 138 ++++++++++++++++-- .../GroupingAggregatorFunctionTestCase.java | 3 +- .../HashAggregationOperatorTests.java | 4 +- .../SequenceBytesRefBlockSourceOperator.java | 6 +- .../TimeSeriesAggregationOperatorTests.java | 16 +- .../AbstractPhysicalOperationProviders.java | 12 +- .../TestPhysicalOperationProviders.java | 16 +- 12 files changed, 349 insertions(+), 79 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index 8b22e50e4e8c9..af0635650c4e0 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -16,13 +16,13 @@ import org.elasticsearch.compute.aggregation.CountAggregatorFunction; import org.elasticsearch.compute.aggregation.CountDistinctDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.CountDistinctLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier; import 
org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MinLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; @@ -124,29 +124,32 @@ private static Operator operator(DriverContext driverContext, String grouping, S driverContext ); } - List groups = switch (grouping) { - case LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG)); - case INTS -> List.of(new BlockHash.GroupSpec(0, ElementType.INT)); - case DOUBLES -> List.of(new BlockHash.GroupSpec(0, ElementType.DOUBLE)); - case BOOLEANS -> List.of(new BlockHash.GroupSpec(0, ElementType.BOOLEAN)); - case BYTES_REFS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF)); - case TWO_LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG), new BlockHash.GroupSpec(1, ElementType.LONG)); + List groups = switch (grouping) { + case LONGS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE)); + case INTS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.INT).get(AggregatorMode.SINGLE)); + case DOUBLES -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.DOUBLE).get(AggregatorMode.SINGLE)); + case BOOLEANS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.BOOLEAN).get(AggregatorMode.SINGLE)); + case BYTES_REFS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.BYTES_REF).get(AggregatorMode.SINGLE)); + case TWO_LONGS -> List.of( + GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE), + GroupingKey.forStatelessGrouping(1, ElementType.LONG).get(AggregatorMode.SINGLE) + ); case 
LONGS_AND_BYTES_REFS -> List.of( - new BlockHash.GroupSpec(0, ElementType.LONG), - new BlockHash.GroupSpec(1, ElementType.BYTES_REF) + GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE), + GroupingKey.forStatelessGrouping(1, ElementType.BYTES_REF).get(AggregatorMode.SINGLE) ); case TWO_LONGS_AND_BYTES_REFS -> List.of( - new BlockHash.GroupSpec(0, ElementType.LONG), - new BlockHash.GroupSpec(1, ElementType.LONG), - new BlockHash.GroupSpec(2, ElementType.BYTES_REF) + GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE), + GroupingKey.forStatelessGrouping(1, ElementType.LONG).get(AggregatorMode.SINGLE), + GroupingKey.forStatelessGrouping(2, ElementType.BYTES_REF).get(AggregatorMode.SINGLE) ); default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]"); }; - return new HashAggregationOperator( + return new HashAggregationOperator.HashAggregationOperatorFactory( + groups, List.of(supplier(op, dataType, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)), - () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false), - driverContext - ); + 16 * 1024 + ).get(driverContext); } private static AggregatorFunctionSupplier supplier(String op, String dataType, int dataChannel) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java new file mode 100644 index 0000000000000..e14867764ca4c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +import java.util.List; + +public record GroupingKey(int channel, AggregatorMode mode, Thing thing) { + public interface Thing { + int intermediateBlockCount(); + + ElementType intermediateElementType(); + + ElementType finalElementType(); + + void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount); + + void receiveIntermediateState(Page page, int offset); + + void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks, int offset); + } + + public interface Supplier { + GroupingKey get(AggregatorMode mode); + } + + public static GroupingKey.Supplier forStatelessGrouping(int channel, ElementType elementType) { + return mode -> new GroupingKey(channel, mode, new Thing() { + @Override + public int intermediateBlockCount() { + return 0; + } + + @Override + public ElementType intermediateElementType() { + return elementType; + } + + @Override + public ElementType finalElementType() { + return elementType; + } + + @Override + public void receiveIntermediateState(Page page, int offset) {} + + @Override + public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount) {} + + @Override + public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks, int offset) {} + + @Override + public String toString() { + return "Stateless"; + } + }); + } + + public static List toBlockHashGroupSpec(List keys) { + return keys.stream().map(GroupingKey::toBlockHashSpec).toList(); + } + + public BlockHash.GroupSpec toBlockHashSpec() { + return new 
BlockHash.GroupSpec(channel, elementType()); // NOCOMMIT this should probably be an evaluator and a BlockType + } + + public ElementType elementType() { + return mode.isOutputPartial() ? thing.intermediateElementType() : thing.finalElementType(); + } + + public void receive(Page page, int offset) { + if (mode.isInputPartial()) { + thing.receiveIntermediateState(page, offset); + } + } + + public int evaluateBlockCount() { + return 1 + (mode.isOutputPartial() ? thing.intermediateBlockCount() : 0); + } + + public void evaluate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + if (mode.isOutputPartial()) { + thing.fetchIntermediateState(driverContext.blockFactory(), blocks, offset + 1, selected.getPositionCount()); + } else { + thing.replaceIntermediateKeys(driverContext.blockFactory(), blocks, offset); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 03a4ca2b0ad5e..bfbc93b81a4c9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.Describable; import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.IntBlock; @@ -27,7 +28,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -36,17 +36,20 @@ import static 
java.util.stream.Collectors.joining; public class HashAggregationOperator implements Operator { - - public record HashAggregationOperatorFactory( - List groups, - List aggregators, - int maxPageSize - ) implements OperatorFactory { + public record HashAggregationOperatorFactory(List groups, List aggregators, int maxPageSize) + implements + OperatorFactory { @Override public Operator get(DriverContext driverContext) { return new HashAggregationOperator( aggregators, - () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false), + groups, + () -> BlockHash.build( + groups.stream().map(GroupingKey::toBlockHashSpec).toList(), + driverContext.blockFactory(), + maxPageSize, + false + ), driverContext ); } @@ -61,15 +64,17 @@ public String describe() { } } - private boolean finished; - private Page output; + private final List aggregators; - private final BlockHash blockHash; + private final List groups; - private final List aggregators; + private final BlockHash blockHash; private final DriverContext driverContext; + private boolean finished; + private Page output; + /** * Nanoseconds this operator has spent hashing grouping keys. 
*/ @@ -86,10 +91,12 @@ public String describe() { @SuppressWarnings("this-escape") public HashAggregationOperator( List aggregators, + List groups, Supplier blockHash, DriverContext driverContext ) { this.aggregators = new ArrayList<>(aggregators.size()); + this.groups = groups; this.driverContext = driverContext; boolean success = false; try { @@ -158,7 +165,12 @@ public void close() { } try (AddInput add = new AddInput()) { checkState(needsInput(), "Operator is already finishing"); - requireNonNull(page, "page is null"); + + int offset = 0; + for (GroupingKey key : groups) { + key.receive(page, offset); + offset += key.evaluateBlockCount(); + } for (int i = 0; i < prepared.length; i++) { prepared[i] = aggregators.get(i).prepareProcessPage(blockHash, page); @@ -192,15 +204,31 @@ public void finish() { try { selected = blockHash.nonEmpty(); Block[] keys = blockHash.getKeys(); - int[] aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray(); - blocks = new Block[keys.length + Arrays.stream(aggBlockCounts).sum()]; - System.arraycopy(keys, 0, blocks, 0, keys.length); - int offset = keys.length; - for (int i = 0; i < aggregators.size(); i++) { - var aggregator = aggregators.get(i); - aggregator.evaluate(blocks, offset, selected, driverContext); - offset += aggBlockCounts[i]; + + int blockCount = 0; + int[] groupBlockCounts = new int[groups.size()]; + for (int g = 0; g < groups.size(); g++) { + groupBlockCounts[g] = groups.get(g).evaluateBlockCount(); + blockCount += groupBlockCounts[g]; + } + int[] aggBlockCounts = new int[aggregators.size()]; + for (int a = 0; a < aggregators.size(); a++) { + aggBlockCounts[a] = aggregators.get(a).evaluateBlockCount(); + blockCount += aggBlockCounts[a]; } + + blocks = new Block[blockCount]; + int offset = 0; + for (int g = 0; g < groups.size(); g++) { + blocks[offset] = keys[g]; + groups.get(g).evaluate(blocks, offset, selected, driverContext); + offset += groupBlockCounts[g]; + } + for (int a = 
0; a < aggregators.size(); a++) { + aggregators.get(a).evaluate(blocks, offset, selected, driverContext); + offset += aggBlockCounts[a]; + } + output = new Page(blocks); success = true; } finally { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java index b5ae35bfc8d7f..b1e3711fd791c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java @@ -17,9 +17,11 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; import org.elasticsearch.compute.Describable; +import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.aggregation.GroupingAggregator.Factory; import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.SeenGroupIds; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.aggregation.blockhash.BlockHash.GroupSpec; @@ -499,6 +501,10 @@ private static class ValuesAggregator implements Releasable { ); this.aggregator = new HashAggregationOperator( aggregatorFactories, + List.of( + // NOCOMMIT double check the mode + GroupingKey.forStatelessGrouping(channelIndex, groupingElementType).get(AggregatorMode.INITIAL) + ), () -> BlockHash.build( List.of(new GroupSpec(channelIndex, groupingElementType)), driverContext.blockFactory(), diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java 
b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java index 1e9ea88b2f1d7..3971f22fd7fa7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java @@ -10,6 +10,7 @@ import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregator; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.aggregation.blockhash.TimeSeriesBlockHash; import org.elasticsearch.compute.data.ElementType; @@ -44,7 +45,7 @@ public final class TimeSeriesAggregationOperatorFactories { public record Initial( int tsHashChannel, int timeBucketChannel, - List groupings, + List groupings, List rates, List nonRates, int maxPageSize @@ -58,9 +59,11 @@ public Operator get(DriverContext driverContext) { for (AggregatorFunctionSupplier f : nonRates) { aggregators.add(f.groupingAggregatorFactory(AggregatorMode.INITIAL)); } + List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INITIAL)).toList(); aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel)); return new HashAggregationOperator( aggregators, + groupings, () -> new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext), driverContext ); @@ -75,7 +78,7 @@ public String describe() { public record Intermediate( int tsHashChannel, int timeBucketChannel, - List groupings, + List groupings, List rates, List nonRates, int maxPageSize @@ -89,6 +92,7 @@ public Operator get(DriverContext driverContext) { for (AggregatorFunctionSupplier f : nonRates) { 
aggregators.add(f.groupingAggregatorFactory(AggregatorMode.INTERMEDIATE)); } + List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INTERMEDIATE)).toList(); aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel)); List hashGroups = List.of( new BlockHash.GroupSpec(tsHashChannel, ElementType.BYTES_REF), @@ -96,6 +100,7 @@ public Operator get(DriverContext driverContext) { ); return new HashAggregationOperator( aggregators, + groupings, () -> BlockHash.build(hashGroups, driverContext.blockFactory(), maxPageSize, false), driverContext ); @@ -108,7 +113,7 @@ public String describe() { } public record Final( - List groupings, + List groupings, List outerRates, List nonRates, int maxPageSize @@ -122,9 +127,11 @@ public Operator get(DriverContext driverContext) { for (AggregatorFunctionSupplier f : nonRates) { aggregators.add(f.groupingAggregatorFactory(AggregatorMode.FINAL)); } + List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.FINAL)).toList(); return new HashAggregationOperator( aggregators, - () -> BlockHash.build(groupings, driverContext.blockFactory(), maxPageSize, false), + groupings, + () -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groupings), driverContext.blockFactory(), maxPageSize, false), driverContext ); } @@ -135,9 +142,9 @@ public String describe() { } } - static List valuesAggregatorForGroupings(List groupings, int timeBucketChannel) { + static List valuesAggregatorForGroupings(List groupings, int timeBucketChannel) { List aggregators = new ArrayList<>(); - for (BlockHash.GroupSpec g : groupings) { + for (GroupingKey g : groupings) { if (g.channel() != timeBucketChannel) { final List channels = List.of(g.channel()); // TODO: perhaps introduce a specialized aggregator for this? 
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index 8b69b5584e65d..8b244ee134ec7 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -34,14 +34,18 @@ import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.compute.aggregation.CountAggregatorFunction; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.aggregation.GroupingAggregator; +import org.elasticsearch.compute.aggregation.GroupingKey; +import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockTestUtils; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.DocBlock; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.Page; @@ -52,6 +56,7 @@ import org.elasticsearch.compute.lucene.ShardContext; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; +import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; @@ -61,8 +66,12 @@ import 
org.elasticsearch.compute.operator.OrdinalsGroupingOperator; import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.RowInTableLookupOperator; +import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.ShuffleDocsOperator; +import org.elasticsearch.compute.operator.TestResultPageSinkOperator; +import org.elasticsearch.compute.operator.topn.TopNEncoder; +import org.elasticsearch.compute.operator.topn.TopNOperator; import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.mapper.KeywordFieldMapper; @@ -81,6 +90,7 @@ import java.util.Map; import java.util.Set; import java.util.TreeMap; +import java.util.stream.Stream; import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL; import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL; @@ -88,6 +98,7 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; // TODO: Move these tests to the right test classes. 
public class OperatorTests extends MapperServiceTestCase { @@ -194,16 +205,11 @@ public String toString() { ) ); operators.add( - new HashAggregationOperator( + new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(GroupingKey.forStatelessGrouping(0, ElementType.BYTES_REF).get(FINAL)), List.of(CountAggregatorFunction.supplier(List.of(1, 2)).groupingAggregatorFactory(FINAL)), - () -> BlockHash.build( - List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF)), - driverContext.blockFactory(), - randomPageSize(), - false - ), - driverContext - ) + randomPageSize() + ).get(driverContext) ); Driver driver = new Driver( driverContext, @@ -230,6 +236,118 @@ public String toString() { assertThat(blockFactory.breaker().getUsed(), equalTo(0L)); } + public void testStatefulGrouping() { + DriverContext driverContext = driverContext(); + Stream input = Stream.of( + new BytesRef("abc"), + new BytesRef("def"), + new BytesRef("abc"), + new BytesRef("abc"), + new BytesRef("abc"), + new BytesRef("abc"), + new BytesRef("blah") + ); + List output = new ArrayList<>(); + List operators = new ArrayList<>(); + + class Example implements GroupingKey.Thing { + int count; + + @Override + public int intermediateBlockCount() { + return 1; + } + + @Override + public ElementType intermediateElementType() { + return ElementType.BYTES_REF; + } + + @Override + public ElementType finalElementType() { + return ElementType.BYTES_REF; + } + + @Override + public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount) { + blocks[offset] = blockFactory.newConstantIntBlockWith(count, positionCount); + } + + @Override + public void receiveIntermediateState(Page page, int offset) { + IntBlock block = page.getBlock(offset + 1); + IntVector vector = block.asVector(); + assertThat(vector.isConstant(), equalTo(true)); + count = vector.getInt(0); + } + + @Override + public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks, int 
offset) { + try ( + BytesRefBlock block = (BytesRefBlock) blocks[offset]; + BytesRefVector.Builder replacement = blockFactory.newBytesRefVectorBuilder(block.getPositionCount()) + ) { + BytesRefVector vector = block.asVector(); + for (int p = 0; p < vector.getPositionCount(); p++) { + replacement.appendBytesRef(new BytesRef(count + vector.getBytesRef(p, new BytesRef()).utf8ToString())); + } + blocks[offset] = replacement.build().asBlock(); + } + } + } + + operators.add( + new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(new GroupingKey(0, INITIAL, new Example() { + @Override + public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount) { + this.count = 10; // NOCOMMIT remove me + super.fetchIntermediateState(blockFactory, blocks, offset, positionCount); + } + })), + List.of(), + 16 * 1024 + ).get(driverContext) + ); + operators.add( + new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(new GroupingKey(0, FINAL, new Example())), + List.of(), + 16 * 1024 + ).get(driverContext) + ); + operators.add( + new TopNOperator( + driverContext.blockFactory(), + driverContext.breaker(), + 3, + List.of(ElementType.BYTES_REF), + List.of(TopNEncoder.UTF8), + List.of(new TopNOperator.SortOrder(0, true, true)), + 16 * 1024 + ) + ); + + Driver driver = new Driver( + driverContext, + new SequenceBytesRefBlockSourceOperator(driverContext.blockFactory(), input), + operators, + new TestResultPageSinkOperator(output::add), + () -> {} + ); + OperatorTestCase.runDriver(driver); + + assertThat(output, hasSize(1)); + assertThat(output.get(0).getBlockCount(), equalTo(1)); + BytesRefBlock block = output.get(0).getBlock(0); + BytesRefVector vector = block.asVector(); + List values = new ArrayList<>(); + for (int p = 0; p < vector.getPositionCount(); p++) { + values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); + } + assertThat(values, equalTo(List.of("7abc", "7blah", "7def"))); + } + public void 
testLimitOperator() { var positions = 100; var limit = randomIntBetween(90, 101); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index de9337f5fce2c..749051b0b3637 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -10,7 +10,6 @@ import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.util.BitArray; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockTestUtils; @@ -90,7 +89,7 @@ protected final Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { supplier = chunkGroups(emitChunkSize, supplier); } return new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(0, ElementType.LONG)), + List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(mode)), List.of(supplier.groupingAggregatorFactory(mode)), randomPageSize() ); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index f2fa94c1feb08..85eaf4b780470 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -8,13 +8,13 @@ package org.elasticsearch.compute.operator; import 
org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunction; import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxLongGroupingAggregatorFunctionTests; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunction; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongGroupingAggregatorFunctionTests; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.ElementType; @@ -53,7 +53,7 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { } return new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(0, ElementType.LONG)), + List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(mode)), List.of( new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode), new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java index 75e71ff697efb..733b77954d42c 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBytesRefBlockSourceOperator.java @@ -8,6 +8,7 @@ package org.elasticsearch.compute.operator; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; import 
org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; @@ -15,8 +16,9 @@ import java.util.stream.Stream; /** - * A source operator whose output is the given double values. This operator produces pages - * containing a single Block. The Block contains the double values from the given list, in order. + * A source operator whose output is the given {@link BytesRef} values. + * This operator produces pages containing a single {@link Block}. The Block + * contains the {@link BytesRef} values from the given list, in order. */ public class SequenceBytesRefBlockSourceOperator extends AbstractBlockSourceOperator { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java index da1a9c9408f90..c40712b59c407 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorTests.java @@ -13,9 +13,9 @@ import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Rounding; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.RateLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockUtils; @@ -265,7 +265,9 @@ public void close() { Operator intialAgg = new TimeSeriesAggregationOperatorFactories.Initial( 1, 3, - IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, 
ElementType.BYTES_REF)).toList(), + IntStream.range(0, nonBucketGroupings.size()) + .mapToObj(n -> GroupingKey.forStatelessGrouping(5 + n, ElementType.BYTES_REF)) + .toList(), List.of(new RateLongAggregatorFunctionSupplier(List.of(4, 2), unitInMillis)), List.of(), between(1, 100) @@ -275,19 +277,21 @@ public void close() { Operator intermediateAgg = new TimeSeriesAggregationOperatorFactories.Intermediate( 0, 1, - IntStream.range(0, nonBucketGroupings.size()).mapToObj(n -> new BlockHash.GroupSpec(5 + n, ElementType.BYTES_REF)).toList(), + IntStream.range(0, nonBucketGroupings.size()) + .mapToObj(n -> GroupingKey.forStatelessGrouping(5 + n, ElementType.BYTES_REF)) + .toList(), List.of(new RateLongAggregatorFunctionSupplier(List.of(2, 3, 4), unitInMillis)), List.of(), between(1, 100) ).get(ctx); // tsid, bucket, rate, grouping1, grouping2 - List finalGroups = new ArrayList<>(); + List finalGroups = new ArrayList<>(); int groupChannel = 3; for (String grouping : groupings) { if (grouping.equals("bucket")) { - finalGroups.add(new BlockHash.GroupSpec(1, ElementType.LONG)); + finalGroups.add(GroupingKey.forStatelessGrouping(1, ElementType.LONG)); } else { - finalGroups.add(new BlockHash.GroupSpec(groupChannel++, ElementType.BYTES_REF)); + finalGroups.add(GroupingKey.forStatelessGrouping(groupChannel++, ElementType.BYTES_REF)); } } Operator finalAgg = new TimeSeriesAggregationOperatorFactories.Final( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index 0e71963e29270..5e23e176312b3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -11,7 +11,7 @@ import 
org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregator; -import org.elasticsearch.compute.aggregation.blockhash.BlockHash; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.operator.AggregationOperator; import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory; @@ -127,7 +127,7 @@ else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == Aggregato } layout.append(groupAttributeLayout); Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id()); - groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute)); + groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute, aggregatorMode)); } if (aggregatorMode == AggregatorMode.FINAL) { @@ -160,7 +160,7 @@ else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == Aggregato ); } else { operatorFactory = new HashAggregationOperatorFactory( - groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(), + groupSpecs.stream().map(GroupSpec::toGroupingKey).toList(), aggregatorFactories, context.pageSize(aggregateExec.estimatedRowSize()) ); @@ -284,12 +284,12 @@ private void aggregatesToFactory( } } - private record GroupSpec(Integer channel, Attribute attribute) { - BlockHash.GroupSpec toHashGroupSpec() { + private record GroupSpec(Integer channel, Attribute attribute, AggregatorMode mode) { + GroupingKey toGroupingKey() { if (channel == null) { throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead"); } - return new BlockHash.GroupSpec(channel, elementType()); + return GroupingKey.forStatelessGrouping(channel, elementType()).get(mode); } ElementType elementType() { diff --git 
a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index 0cd1fa11a7499..85db6c7555ba9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -11,7 +11,9 @@ import org.elasticsearch.common.Randomness; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.Describable; +import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.GroupingAggregator; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; @@ -230,11 +232,12 @@ private class TestHashAggregationOperator extends HashAggregationOperator { TestHashAggregationOperator( List aggregators, + List groups, Supplier blockHash, String columnName, DriverContext driverContext ) { - super(aggregators, blockHash, driverContext); + super(aggregators, groups, blockHash, driverContext); this.columnName = columnName; } @@ -273,14 +276,13 @@ private class TestOrdinalsGroupingAggregationOperatorFactory implements Operator public Operator get(DriverContext driverContext) { Random random = Randomness.get(); int pageSize = random.nextBoolean() ? 
randomIntBetween(random, 1, 16) : randomIntBetween(random, 1, 10 * 1024); + List groupings = List.of( + GroupingKey.forStatelessGrouping(groupByChannel, groupElementType).get(AggregatorMode.INITIAL) + ); return new TestHashAggregationOperator( aggregators, - () -> BlockHash.build( - List.of(new BlockHash.GroupSpec(groupByChannel, groupElementType)), - driverContext.blockFactory(), - pageSize, - false - ), + groupings, + () -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groupings), driverContext.blockFactory(), pageSize, false), columnName, driverContext ); From 712b672e7038539ae5c127136c91777c30824e3a Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 11 Sep 2024 11:36:21 -0400 Subject: [PATCH 02/10] Next --- .../compute/operator/HashAggregationOperator.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index bfbc93b81a4c9..38044153b8a93 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -163,6 +163,7 @@ public void close() { Releasables.closeExpectNoException(prepared); } } + Block[] keys = new Block[groups.size()]; try (AddInput add = new AddInput()) { checkState(needsInput(), "Operator is already finishing"); @@ -178,6 +179,8 @@ public void close() { blockHash.add(wrapPage(page), add); hashNanos += System.nanoTime() - add.hashStart; + } finally { + Releasables.close(keys); } } finally { page.releaseBlocks(); From 37415d27f61156512ea05027a78a99452f9af8e7 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 11 Sep 2024 13:15:26 -0400 Subject: [PATCH 03/10] Like so? 
--- .../compute/operator/AggregatorBenchmark.java | 2 +- .../compute/aggregation/GroupingKey.java | 143 +++++++--- .../operator/HashAggregationOperator.java | 47 ++-- ...imeSeriesAggregationOperatorFactories.java | 25 +- .../elasticsearch/compute/OperatorTests.java | 260 ++++++++++-------- .../AbstractPhysicalOperationProviders.java | 14 +- .../TestPhysicalOperationProviders.java | 4 +- 7 files changed, 290 insertions(+), 205 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index af0635650c4e0..d9ad2bcc95014 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -124,7 +124,7 @@ private static Operator operator(DriverContext driverContext, String grouping, S driverContext ); } - List groups = switch (grouping) { + List groups = switch (grouping) { case LONGS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.LONG).get(AggregatorMode.SINGLE)); case INTS -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.INT).get(AggregatorMode.SINGLE)); case DOUBLES -> List.of(GroupingKey.forStatelessGrouping(0, ElementType.DOUBLE).get(AggregatorMode.SINGLE)); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java index e14867764ca4c..10a0cfb732edd 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java @@ -14,88 +14,147 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.Page; import 
org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasable; +import java.util.ArrayList; import java.util.List; -public record GroupingKey(int channel, AggregatorMode mode, Thing thing) { - public interface Thing { - int intermediateBlockCount(); +public record GroupingKey(AggregatorMode mode, Thing thing) implements EvalOperator.ExpressionEvaluator { + public interface Thing extends Releasable { + int extraIntermediateBlocks(); ElementType intermediateElementType(); ElementType finalElementType(); - void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount); + Block evalRawInput(Page page); - void receiveIntermediateState(Page page, int offset); + Block evalIntermediateInput(Page page); - void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks, int offset); + void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount); + + void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks); } public interface Supplier { - GroupingKey get(AggregatorMode mode); + Factory get(AggregatorMode mode); + } + + public interface Factory { + GroupingKey apply(DriverContext context, int resultOffset); + + ElementType elementType(); + + GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel); } public static GroupingKey.Supplier forStatelessGrouping(int channel, ElementType elementType) { - return mode -> new GroupingKey(channel, mode, new Thing() { + return mode -> new Factory() { @Override - public int intermediateBlockCount() { - return 0; + public GroupingKey apply(DriverContext context, int resultOffset) { + return new GroupingKey(mode, new Load(channel, elementType, resultOffset)); } @Override - public ElementType intermediateElementType() { + public ElementType elementType() { return elementType; } @Override - public ElementType finalElementType() { - 
return elementType; + public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) { + if (channel != timeBucketChannel) { + final List channels = List.of(channel); + // TODO: perhaps introduce a specialized aggregator for this? + return (switch (elementType()) { + case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels); + case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels); + case INT -> new ValuesIntAggregatorFunctionSupplier(channels); + case LONG -> new ValuesLongAggregatorFunctionSupplier(channels); + case BOOLEAN -> new ValuesBooleanAggregatorFunctionSupplier(channels); + case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type"); + }).groupingAggregatorFactory(AggregatorMode.SINGLE); + } + return null; } + }; + } - @Override - public void receiveIntermediateState(Page page, int offset) {} + public static List toBlockHashGroupSpec(List keys) { + List result = new ArrayList<>(keys.size()); + for (int k = 0; k < keys.size(); k++) { + result.add(new BlockHash.GroupSpec(k, keys.get(k).elementType())); + } + return result; + } - @Override - public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount) {} + public ElementType elementType() { + return mode.isOutputPartial() ? thing.intermediateElementType() : thing.finalElementType(); + } - @Override - public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks, int offset) {} + @Override + public Block eval(Page page) { + return mode.isInputPartial() ? thing.evalIntermediateInput(page) : thing.evalRawInput(page); + } - @Override - public String toString() { - return "Stateless"; - } - }); + public int finishBlockCount() { + return mode.isOutputPartial() ? 
1 + thing.extraIntermediateBlocks() : 1; } - public static List toBlockHashGroupSpec(List keys) { - return keys.stream().map(GroupingKey::toBlockHashSpec).toList(); + public void finish(Block[] blocks, IntVector selected, DriverContext driverContext) { + if (mode.isOutputPartial()) { + thing.fetchIntermediateState(driverContext.blockFactory(), blocks, selected.getPositionCount()); + } else { + thing.replaceIntermediateKeys(driverContext.blockFactory(), blocks); + } } - public BlockHash.GroupSpec toBlockHashSpec() { - return new BlockHash.GroupSpec(channel, elementType()); // NOCOMMIT this should probably be an evaluator and a BlockType + public int extraIntermediateBlocks() { + return thing.extraIntermediateBlocks(); } - public ElementType elementType() { - return mode.isOutputPartial() ? thing.intermediateElementType() : thing.finalElementType(); + @Override + public void close() { + thing.close(); } - public void receive(Page page, int offset) { - if (mode.isInputPartial()) { - thing.receiveIntermediateState(page, offset); + private record Load(int channel, ElementType elementType, int resultOffset) implements Thing { + @Override + public int extraIntermediateBlocks() { + return 0; } - } - public int evaluateBlockCount() { - return 1 + (mode.isOutputPartial() ? 
thing.intermediateBlockCount() : 0); - } + @Override + public ElementType intermediateElementType() { + return elementType; + } - public void evaluate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { - if (mode.isOutputPartial()) { - thing.fetchIntermediateState(driverContext.blockFactory(), blocks, offset + 1, selected.getPositionCount()); - } else { - thing.replaceIntermediateKeys(driverContext.blockFactory(), blocks, offset); + @Override + public ElementType finalElementType() { + return elementType; + } + + @Override + public Block evalRawInput(Page page) { + Block b = page.getBlock(channel); + b.incRef(); + return b; + } + + @Override + public Block evalIntermediateInput(Page page) { + Block b = page.getBlock(resultOffset); + b.incRef(); + return b; } + + @Override + public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) {} + + @Override + public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) {} + + @Override + public void close() {} } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 38044153b8a93..5d55cea69349d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -32,24 +32,20 @@ import java.util.Objects; import java.util.function.Supplier; -import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; public class HashAggregationOperator implements Operator { - public record HashAggregationOperatorFactory(List groups, List aggregators, int maxPageSize) - implements - OperatorFactory { + public record HashAggregationOperatorFactory( + List groups, + List aggregators, + 
int maxPageSize + ) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { return new HashAggregationOperator( aggregators, groups, - () -> BlockHash.build( - groups.stream().map(GroupingKey::toBlockHashSpec).toList(), - driverContext.blockFactory(), - maxPageSize, - false - ), + () -> BlockHash.build(GroupingKey.toBlockHashGroupSpec(groups), driverContext.blockFactory(), maxPageSize, false), driverContext ); } @@ -91,12 +87,12 @@ public String describe() { @SuppressWarnings("this-escape") public HashAggregationOperator( List aggregators, - List groups, + List groups, Supplier blockHash, DriverContext driverContext ) { this.aggregators = new ArrayList<>(aggregators.size()); - this.groups = groups; + this.groups = new ArrayList<>(groups.size()); this.driverContext = driverContext; boolean success = false; try { @@ -104,6 +100,12 @@ public HashAggregationOperator( for (GroupingAggregator.Factory a : aggregators) { this.aggregators.add(a.apply(driverContext)); } + int offset = 0; + for (GroupingKey.Factory g : groups) { + GroupingKey key = g.apply(driverContext, offset); + this.groups.add(key); + offset += key.extraIntermediateBlocks() + 1; + } success = true; } finally { if (success == false) { @@ -119,6 +121,8 @@ public boolean needsInput() { @Override public void addInput(Page page) { + checkState(needsInput(), "Operator is already finishing"); + try { GroupingAggregatorFunction.AddInput[] prepared = new GroupingAggregatorFunction.AddInput[aggregators.size()]; class AddInput implements GroupingAggregatorFunction.AddInput { @@ -164,20 +168,17 @@ public void close() { } } Block[] keys = new Block[groups.size()]; + page = wrapPage(page); try (AddInput add = new AddInput()) { - checkState(needsInput(), "Operator is already finishing"); - - int offset = 0; - for (GroupingKey key : groups) { - key.receive(page, offset); - offset += key.evaluateBlockCount(); + for (int g = 0; g < groups.size(); g++) { + keys[g] = 
groups.get(g).eval(page); } for (int i = 0; i < prepared.length; i++) { prepared[i] = aggregators.get(i).prepareProcessPage(blockHash, page); } - blockHash.add(wrapPage(page), add); + blockHash.add(new Page(keys), add); hashNanos += System.nanoTime() - add.hashStart; } finally { Releasables.close(keys); @@ -209,10 +210,8 @@ public void finish() { Block[] keys = blockHash.getKeys(); int blockCount = 0; - int[] groupBlockCounts = new int[groups.size()]; for (int g = 0; g < groups.size(); g++) { - groupBlockCounts[g] = groups.get(g).evaluateBlockCount(); - blockCount += groupBlockCounts[g]; + blockCount += groups.get(g).finishBlockCount(); } int[] aggBlockCounts = new int[aggregators.size()]; for (int a = 0; a < aggregators.size(); a++) { @@ -224,8 +223,8 @@ public void finish() { int offset = 0; for (int g = 0; g < groups.size(); g++) { blocks[offset] = keys[g]; - groups.get(g).evaluate(blocks, offset, selected, driverContext); - offset += groupBlockCounts[g]; + groups.get(g).finish(blocks, selected, driverContext); + offset += groups.get(g).finishBlockCount(); } for (int a = 0; a < aggregators.size(); a++) { aggregators.get(a).evaluate(blocks, offset, selected, driverContext); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java index 3971f22fd7fa7..3c8f35d4dad96 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java @@ -59,7 +59,7 @@ public Operator get(DriverContext driverContext) { for (AggregatorFunctionSupplier f : nonRates) { aggregators.add(f.groupingAggregatorFactory(AggregatorMode.INITIAL)); } - List groupings = this.groupings.stream().map(g -> 
g.get(AggregatorMode.INITIAL)).toList(); + List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INITIAL)).toList(); aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel)); return new HashAggregationOperator( aggregators, @@ -92,7 +92,7 @@ public Operator get(DriverContext driverContext) { for (AggregatorFunctionSupplier f : nonRates) { aggregators.add(f.groupingAggregatorFactory(AggregatorMode.INTERMEDIATE)); } - List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INTERMEDIATE)).toList(); + List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.INTERMEDIATE)).toList(); aggregators.addAll(valuesAggregatorForGroupings(groupings, timeBucketChannel)); List hashGroups = List.of( new BlockHash.GroupSpec(tsHashChannel, ElementType.BYTES_REF), @@ -127,7 +127,7 @@ public Operator get(DriverContext driverContext) { for (AggregatorFunctionSupplier f : nonRates) { aggregators.add(f.groupingAggregatorFactory(AggregatorMode.FINAL)); } - List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.FINAL)).toList(); + List groupings = this.groupings.stream().map(g -> g.get(AggregatorMode.FINAL)).toList(); return new HashAggregationOperator( aggregators, groupings, @@ -142,21 +142,12 @@ public String describe() { } } - static List valuesAggregatorForGroupings(List groupings, int timeBucketChannel) { + static List valuesAggregatorForGroupings(List groupings, int timeBucketChannel) { List aggregators = new ArrayList<>(); - for (GroupingKey g : groupings) { - if (g.channel() != timeBucketChannel) { - final List channels = List.of(g.channel()); - // TODO: perhaps introduce a specialized aggregator for this? 
- var aggregatorSupplier = (switch (g.elementType()) { - case BYTES_REF -> new org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier(channels); - case DOUBLE -> new org.elasticsearch.compute.aggregation.ValuesDoubleAggregatorFunctionSupplier(channels); - case INT -> new org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier(channels); - case LONG -> new org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier(channels); - case BOOLEAN -> new org.elasticsearch.compute.aggregation.ValuesBooleanAggregatorFunctionSupplier(channels); - case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type"); - }); - aggregators.add(aggregatorSupplier.groupingAggregatorFactory(AggregatorMode.SINGLE)); + for (GroupingKey.Factory g : groupings) { + GroupingAggregator.Factory factory = g.valuesAggregatorForGroupingsInTimeSeries(timeBucketChannel); + if (factory != null) { + aggregators.add(factory); } } return aggregators; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index 8b244ee134ec7..eb9dd2873dfc7 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -33,10 +33,10 @@ import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.compute.aggregation.CountAggregatorFunction; import org.elasticsearch.compute.aggregation.GroupingAggregator; import org.elasticsearch.compute.aggregation.GroupingKey; -import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier; import 
org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockTestUtils; @@ -56,7 +56,6 @@ import org.elasticsearch.compute.lucene.ShardContext; import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator; import org.elasticsearch.compute.operator.AbstractPageMappingOperator; -import org.elasticsearch.compute.operator.CannedSourceOperator; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.HashAggregationOperator; @@ -236,118 +235,6 @@ public String toString() { assertThat(blockFactory.breaker().getUsed(), equalTo(0L)); } - public void testStatefulGrouping() { - DriverContext driverContext = driverContext(); - Stream input = Stream.of( - new BytesRef("abc"), - new BytesRef("def"), - new BytesRef("abc"), - new BytesRef("abc"), - new BytesRef("abc"), - new BytesRef("abc"), - new BytesRef("blah") - ); - List output = new ArrayList<>(); - List operators = new ArrayList<>(); - - class Example implements GroupingKey.Thing { - int count; - - @Override - public int intermediateBlockCount() { - return 1; - } - - @Override - public ElementType intermediateElementType() { - return ElementType.BYTES_REF; - } - - @Override - public ElementType finalElementType() { - return ElementType.BYTES_REF; - } - - @Override - public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount) { - blocks[offset] = blockFactory.newConstantIntBlockWith(count, positionCount); - } - - @Override - public void receiveIntermediateState(Page page, int offset) { - IntBlock block = page.getBlock(offset + 1); - IntVector vector = block.asVector(); - assertThat(vector.isConstant(), equalTo(true)); - count = vector.getInt(0); - } - - @Override - public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks, int offset) { - try ( - BytesRefBlock block = (BytesRefBlock) 
blocks[offset]; - BytesRefVector.Builder replacement = blockFactory.newBytesRefVectorBuilder(block.getPositionCount()) - ) { - BytesRefVector vector = block.asVector(); - for (int p = 0; p < vector.getPositionCount(); p++) { - replacement.appendBytesRef(new BytesRef(count + vector.getBytesRef(p, new BytesRef()).utf8ToString())); - } - blocks[offset] = replacement.build().asBlock(); - } - } - } - - operators.add( - new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new GroupingKey(0, INITIAL, new Example() { - @Override - public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int offset, int positionCount) { - this.count = 10; // NOCOMMIT remove me - super.fetchIntermediateState(blockFactory, blocks, offset, positionCount); - } - })), - List.of(), - 16 * 1024 - ).get(driverContext) - ); - operators.add( - new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new GroupingKey(0, FINAL, new Example())), - List.of(), - 16 * 1024 - ).get(driverContext) - ); - operators.add( - new TopNOperator( - driverContext.blockFactory(), - driverContext.breaker(), - 3, - List.of(ElementType.BYTES_REF), - List.of(TopNEncoder.UTF8), - List.of(new TopNOperator.SortOrder(0, true, true)), - 16 * 1024 - ) - ); - - Driver driver = new Driver( - driverContext, - new SequenceBytesRefBlockSourceOperator(driverContext.blockFactory(), input), - operators, - new TestResultPageSinkOperator(output::add), - () -> {} - ); - OperatorTestCase.runDriver(driver); - - assertThat(output, hasSize(1)); - assertThat(output.get(0).getBlockCount(), equalTo(1)); - BytesRefBlock block = output.get(0).getBlock(0); - BytesRefVector vector = block.asVector(); - List values = new ArrayList<>(); - for (int p = 0; p < vector.getPositionCount(); p++) { - values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); - } - assertThat(values, equalTo(List.of("7abc", "7blah", "7def"))); - } - public void testLimitOperator() { var positions = 100; var limit = 
randomIntBetween(90, 101); @@ -507,4 +394,149 @@ static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query qu limit ); } + + public void testStatefulGrouping() { + DriverContext driverContext = driverContext(); + Stream input = Stream.of( + new BytesRef("abc"), + new BytesRef("def"), + new BytesRef("abc"), + new BytesRef("abc"), + new BytesRef("abc"), + new BytesRef("abc"), + new BytesRef("blah") + ); + List output = new ArrayList<>(); + List operators = new ArrayList<>(); + + operators.add( + new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(new ExampleStatefulGroupingFunction.Factory(INITIAL, 0)), + List.of(), + 16 * 1024 + ).get(driverContext) + ); + operators.add( + new HashAggregationOperator.HashAggregationOperatorFactory( + List.of(new ExampleStatefulGroupingFunction.Factory(FINAL, 0)), + List.of(), + 16 * 1024 + ).get(driverContext) + ); + operators.add( + new TopNOperator( + driverContext.blockFactory(), + driverContext.breaker(), + 3, + List.of(ElementType.BYTES_REF), + List.of(TopNEncoder.UTF8), + List.of(new TopNOperator.SortOrder(0, true, true)), + 16 * 1024 + ) + ); + + Driver driver = new Driver( + driverContext, + new SequenceBytesRefBlockSourceOperator(driverContext.blockFactory(), input), + operators, + new TestResultPageSinkOperator(output::add), + () -> {} + ); + OperatorTestCase.runDriver(driver); + + assertThat(output, hasSize(1)); + assertThat(output.get(0).getBlockCount(), equalTo(1)); + BytesRefBlock block = output.get(0).getBlock(0); + BytesRefVector vector = block.asVector(); + List values = new ArrayList<>(); + for (int p = 0; p < vector.getPositionCount(); p++) { + values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); + } + assertThat(values, equalTo(List.of("7abc", "7blah", "7def"))); + } + + static class ExampleStatefulGroupingFunction implements GroupingKey.Thing { + record Factory(AggregatorMode mode, int inputChannel) implements GroupingKey.Factory { + @Override + public 
GroupingKey apply(DriverContext context, int resultOffset) { + return new GroupingKey(mode, new ExampleStatefulGroupingFunction(inputChannel, resultOffset)); + } + + @Override + public ElementType elementType() { + return ElementType.BYTES_REF; + } + + @Override + public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) { + throw new UnsupportedOperationException(); + } + } + + private final int inputChannel; + private final int resultOffset; + + int count; + + ExampleStatefulGroupingFunction(int inputChannel, int resultOffset) { + this.inputChannel = inputChannel; + this.resultOffset = resultOffset; + } + + @Override + public int extraIntermediateBlocks() { + return 1; + } + + @Override + public ElementType intermediateElementType() { + return ElementType.BYTES_REF; + } + + @Override + public ElementType finalElementType() { + return ElementType.BYTES_REF; + } + + @Override + public Block evalRawInput(Page page) { + count += page.getPositionCount(); + Block block = page.getBlock(inputChannel); + block.incRef(); + return block; + } + + @Override + public Block evalIntermediateInput(Page page) { + IntBlock block = page.getBlock(resultOffset + 1); + IntVector vector = block.asVector(); + assertThat(vector.isConstant(), equalTo(true)); + count = vector.getInt(0); + Block b = page.getBlock(resultOffset); + b.incRef(); + return b; + } + + @Override + public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) { + blocks[resultOffset + 1] = blockFactory.newConstantIntBlockWith(count, positionCount); + } + + @Override + public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) { + try ( + BytesRefBlock block = (BytesRefBlock) blocks[resultOffset]; + BytesRefVector.Builder replacement = blockFactory.newBytesRefVectorBuilder(block.getPositionCount()) + ) { + BytesRefVector vector = block.asVector(); + for (int p = 0; p < vector.getPositionCount(); p++) { + 
replacement.appendBytesRef(new BytesRef(count + vector.getBytesRef(p, new BytesRef()).utf8ToString())); + } + blocks[resultOffset] = replacement.build().asBlock(); + } + } + + @Override + public void close() {} + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index 5e23e176312b3..63e6c6a410e5b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -127,7 +127,7 @@ else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == Aggregato } layout.append(groupAttributeLayout); Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id()); - groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute, aggregatorMode)); + groupSpecs.add(new GroupSpec(groupInput == null ? 
null : groupInput.channel(), groupAttribute)); } if (aggregatorMode == AggregatorMode.FINAL) { @@ -159,8 +159,12 @@ else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == Aggregato context ); } else { + List groupings = new ArrayList<>(groupSpecs.size()); + for (GroupSpec group : groupSpecs) { + groupings.add(group.toGroupingKey().get(aggregatorMode)); + } operatorFactory = new HashAggregationOperatorFactory( - groupSpecs.stream().map(GroupSpec::toGroupingKey).toList(), + groupings, aggregatorFactories, context.pageSize(aggregateExec.estimatedRowSize()) ); @@ -284,12 +288,12 @@ private void aggregatesToFactory( } } - private record GroupSpec(Integer channel, Attribute attribute, AggregatorMode mode) { - GroupingKey toGroupingKey() { + private record GroupSpec(Integer channel, Attribute attribute) { + GroupingKey.Supplier toGroupingKey() { if (channel == null) { throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead"); } - return GroupingKey.forStatelessGrouping(channel, elementType()).get(mode); + return GroupingKey.forStatelessGrouping(channel, elementType()); } ElementType elementType() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index 85db6c7555ba9..8d7d4750e8e49 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -232,7 +232,7 @@ private class TestHashAggregationOperator extends HashAggregationOperator { TestHashAggregationOperator( List aggregators, - List groups, + List groups, Supplier blockHash, String columnName, DriverContext driverContext @@ -276,7 +276,7 @@ private class TestOrdinalsGroupingAggregationOperatorFactory implements 
Operator public Operator get(DriverContext driverContext) { Random random = Randomness.get(); int pageSize = random.nextBoolean() ? randomIntBetween(random, 1, 16) : randomIntBetween(random, 1, 10 * 1024); - List groupings = List.of( + List groupings = List.of( GroupingKey.forStatelessGrouping(groupByChannel, groupElementType).get(AggregatorMode.INITIAL) ); return new TestHashAggregationOperator( From 1fb09b1bafbb4bba94ff682be5bc8c9915788652 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Wed, 11 Sep 2024 14:33:48 -0400 Subject: [PATCH 04/10] Fix one test --- .../compute/aggregation/GroupingKey.java | 30 ++++--------------- .../operator/HashAggregationOperator.java | 2 +- .../operator/OrdinalsGroupingOperator.java | 2 +- .../elasticsearch/compute/OperatorTests.java | 14 ++------- 4 files changed, 11 insertions(+), 37 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java index 10a0cfb732edd..b7f581fe19ea6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java @@ -24,10 +24,6 @@ public record GroupingKey(AggregatorMode mode, Thing thing) implements EvalOpera public interface Thing extends Releasable { int extraIntermediateBlocks(); - ElementType intermediateElementType(); - - ElementType finalElementType(); - Block evalRawInput(Page page); Block evalIntermediateInput(Page page); @@ -44,7 +40,7 @@ public interface Supplier { public interface Factory { GroupingKey apply(DriverContext context, int resultOffset); - ElementType elementType(); + ElementType intermediateElementType(); GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel); } @@ -53,11 +49,11 @@ public static GroupingKey.Supplier forStatelessGrouping(int 
channel, ElementType return mode -> new Factory() { @Override public GroupingKey apply(DriverContext context, int resultOffset) { - return new GroupingKey(mode, new Load(channel, elementType, resultOffset)); + return new GroupingKey(mode, new Load(channel, resultOffset)); } @Override - public ElementType elementType() { + public ElementType intermediateElementType() { return elementType; } @@ -66,7 +62,7 @@ public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int t if (channel != timeBucketChannel) { final List channels = List.of(channel); // TODO: perhaps introduce a specialized aggregator for this? - return (switch (elementType()) { + return (switch (intermediateElementType()) { case BYTES_REF -> new ValuesBytesRefAggregatorFunctionSupplier(channels); case DOUBLE -> new ValuesDoubleAggregatorFunctionSupplier(channels); case INT -> new ValuesIntAggregatorFunctionSupplier(channels); @@ -83,15 +79,11 @@ public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int t public static List toBlockHashGroupSpec(List keys) { List result = new ArrayList<>(keys.size()); for (int k = 0; k < keys.size(); k++) { - result.add(new BlockHash.GroupSpec(k, keys.get(k).elementType())); + result.add(new BlockHash.GroupSpec(k, keys.get(k).intermediateElementType())); } return result; } - public ElementType elementType() { - return mode.isOutputPartial() ? thing.intermediateElementType() : thing.finalElementType(); - } - @Override public Block eval(Page page) { return mode.isInputPartial() ? 
thing.evalIntermediateInput(page) : thing.evalRawInput(page); @@ -118,22 +110,12 @@ public void close() { thing.close(); } - private record Load(int channel, ElementType elementType, int resultOffset) implements Thing { + private record Load(int channel, int resultOffset) implements Thing { @Override public int extraIntermediateBlocks() { return 0; } - @Override - public ElementType intermediateElementType() { - return elementType; - } - - @Override - public ElementType finalElementType() { - return elementType; - } - @Override public Block evalRawInput(Page page) { Block b = page.getBlock(channel); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index 5d55cea69349d..6e6ccf8442a54 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -254,7 +254,7 @@ public void close() { if (output != null) { output.releaseBlocks(); } - Releasables.close(blockHash, () -> Releasables.close(aggregators)); + Releasables.close(blockHash, Releasables.wrap(aggregators), Releasables.wrap(groups)); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java index b1e3711fd791c..8f50c0959d8ad 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java @@ -506,7 +506,7 @@ private static class ValuesAggregator implements Releasable { GroupingKey.forStatelessGrouping(channelIndex, 
groupingElementType).get(AggregatorMode.INITIAL) ), () -> BlockHash.build( - List.of(new GroupSpec(channelIndex, groupingElementType)), + List.of(new GroupSpec(0, groupingElementType)), driverContext.blockFactory(), maxPageSize, false diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index eb9dd2873dfc7..b1f5c72fcc894 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.compute; +import com.carrotsearch.randomizedtesting.annotations.Seed; + import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.LongPoint; @@ -463,7 +465,7 @@ public GroupingKey apply(DriverContext context, int resultOffset) { } @Override - public ElementType elementType() { + public ElementType intermediateElementType() { return ElementType.BYTES_REF; } @@ -488,16 +490,6 @@ public int extraIntermediateBlocks() { return 1; } - @Override - public ElementType intermediateElementType() { - return ElementType.BYTES_REF; - } - - @Override - public ElementType finalElementType() { - return ElementType.BYTES_REF; - } - @Override public Block evalRawInput(Page page) { count += page.getPositionCount(); From 522c0bcd0ce1ecb6ce69d1de4c4d5225f4addbed Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Tue, 17 Sep 2024 17:51:36 -0400 Subject: [PATCH 05/10] foooooooooo --- .../function/grouping/Categorize.java | 193 +++++++++++++++++- .../grouping/CategorizeOperatorTests.java | 162 +++++++++++++++ .../categorization/TokenListCategorizer.java | 9 +- 3 files changed, 362 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java diff 
--git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 82c836a6f9d49..13fe4fc21ee86 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -10,13 +10,29 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.core.WhitespaceTokenizer; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.compute.aggregation.GroupingAggregator; +import org.elasticsearch.compute.aggregation.GroupingKey; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.core.Releasables; import org.elasticsearch.index.analysis.CharFilterFactory; import 
org.elasticsearch.index.analysis.CustomAnalyzer; import org.elasticsearch.index.analysis.TokenFilterFactory; @@ -31,6 +47,8 @@ import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash; import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; +import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory; import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer; import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; @@ -135,7 +153,7 @@ protected TypeResolution resolveType() { @Override public DataType dataType() { - return DataType.INTEGER; + return DataType.KEYWORD; } @Override @@ -156,4 +174,177 @@ public Expression field() { public String toString() { return "Categorize{field=" + field + "}"; } + + public GroupingKey.Supplier groupingKey(Function toEvaluator) { + return mode -> new GroupingKeyFactory(source(), toEvaluator.apply(field), mode); + } + + record GroupingKeyFactory(Source source, ExpressionEvaluator.Factory field, AggregatorMode mode) implements GroupingKey.Factory { + @Override + public GroupingKey apply(DriverContext context, int resultOffset) { + ExpressionEvaluator field = this.field.get(context); + CategorizeEvaluator evaluator = null; + TokenListCategorizer.CloseableTokenListCategorizer categorizer = null; + try { + categorizer = new TokenListCategorizer.CloseableTokenListCategorizer( + new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())), + CategorizationPartOfSpeechDictionary.getInstance(), + 0.70f + ); + evaluator = new CategorizeEvaluator( + source, + field, + new CategorizationAnalyzer( + // TODO(jan): get the correct analyzer in here, see + // CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer + new CustomAnalyzer( + 
TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new), + new CharFilterFactory[0], + new TokenFilterFactory[0] + ), + true + ), + categorizer, + context + ); + field = null; + GroupingKey result = new GroupingKey(mode, new GroupingKeyThing(resultOffset, evaluator, categorizer)); + categorizer = null; + evaluator = null; + return result; + } finally { + Releasables.close(field, evaluator, categorizer); + } + + } + + @Override + public ElementType intermediateElementType() { + return ElementType.INT; + } + + @Override + public GroupingAggregator.Factory valuesAggregatorForGroupingsInTimeSeries(int timeBucketChannel) { + throw new UnsupportedOperationException("not supported in time series"); + } + } + + private static class GroupingKeyThing implements GroupingKey.Thing { + private final int resultOffset; + private final CategorizeEvaluator evaluator; + private final TokenListCategorizer.CloseableTokenListCategorizer categorizer; + + private GroupingKeyThing( + int resultOffset, + CategorizeEvaluator evaluator, + TokenListCategorizer.CloseableTokenListCategorizer categorizer + ) { + this.resultOffset = resultOffset; + this.evaluator = evaluator; + this.categorizer = categorizer; + } + + @Override + public Block evalRawInput(Page page) { + return evaluator.eval(page); + } + + @Override + public int extraIntermediateBlocks() { + return 1; + } + + @Override + public void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount) { + blocks[resultOffset + 1] = buildIntermediateBlock(blockFactory, positionCount); + } + + @Override + public Block evalIntermediateInput(Page page) { + BytesRefBlock intermediate = page.getBlock(resultOffset + 1); + if (intermediate.areAllValuesNull() == false) { + readIntermediate(intermediate.getBytesRef(0, new BytesRef())); + } + // NOCOMMIT this should remap the ints in the input block to whatever the *new* ids are + Block result = page.getBlock(resultOffset); + result.incRef(); + return result; + } + 
+ @Override + public void replaceIntermediateKeys(BlockFactory blockFactory, Block[] blocks) { + // NOCOMMIT this offset can't be the same in the result array and intermediate array + IntBlock keys = (IntBlock) blocks[resultOffset]; + blocks[resultOffset] = finalKeys(blockFactory, keys); + System.err.println(blocks[resultOffset]); + } + + @Override + public void close() { + evaluator.close(); + } + + private Block buildIntermediateBlock(BlockFactory blockFactory, int positionCount) { + if (categorizer.getCategoryCount() == 0) { + return blockFactory.newConstantNullBlock(positionCount); + } + try (BytesStreamOutput out = new BytesStreamOutput()) { + // TODO be more careful here. + out.writeVInt(categorizer.getCategoryCount()); + for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) { + category.writeTo(out); + } + return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private void readIntermediate(BytesRef bytes) { + try (StreamInput in = new BytesArray(bytes).streamInput()) { + int count = in.readVInt(); + for (int i = 0; i < count; i++) { + SerializableTokenListCategory category = new SerializableTokenListCategory(in); + categorizer.mergeWireCategory(category); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private OrdinalBytesRefBlock finalKeys(BlockFactory blockFactory, IntBlock keys) { + keys.incRef(); + return new OrdinalBytesRefBlock(keys, finalBytes(blockFactory)); + } + + /** + * A lookup table containing the category names. 
+ */ + private BytesRefVector finalBytes(BlockFactory blockFactory) { + try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) { + BytesRefBuilder scratch = new BytesRefBuilder(); + for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) { + // NOCOMMIT build tokens properly + BytesRef[] tokens = category.getKeyTokens(); + if (tokens.length == 0) { + scratch.append((byte) '*'); + result.appendBytesRef(scratch.get()); + scratch.clear(); + continue; + } + scratch.append(tokens[0]); + for (int i = 1; i < tokens.length; i++) { + scratch.append((byte) ' '); + scratch.append(tokens[i]); + } + scratch.append((byte) ' '); + scratch.append((byte) '.'); + scratch.append((byte) '*'); + result.appendBytesRef(scratch.get()); + scratch.clear(); + } + return result.build(); + } + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java new file mode 100644 index 0000000000000..9171ba890e8db --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.grouping; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.compute.aggregation.GroupingKey; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.DriverRunner; +import org.elasticsearch.compute.operator.HashAggregationOperator; +import org.elasticsearch.compute.operator.LocalSourceOperator; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.topn.TopNEncoder; +import org.elasticsearch.compute.operator.topn.TopNOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.FixedExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.core.tree.Source; +import 
org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.junit.After; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL; +import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class CategorizeOperatorTests extends ESTestCase { + public void testCategorization() { + DriverContext driverContext = driverContext(); + LocalSourceOperator.BlockSupplier input = () -> { + try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) { + builder.appendBytesRef(new BytesRef("words words words hello nik")); + builder.appendBytesRef(new BytesRef("words words words goodbye nik")); + builder.appendBytesRef(new BytesRef("words words words hello jan")); + builder.appendBytesRef(new BytesRef("words words words goodbye jan")); + builder.appendBytesRef(new BytesRef("words words words hello kitty")); + builder.appendBytesRef(new BytesRef("words words words goodbye blue sky")); + return new Block[] { builder.build().asBlock() }; + } + }; + List output = new ArrayList<>(); + try { + List operators = new ArrayList<>(); + + Categorize cat = new Categorize(Source.EMPTY, AbstractFunctionTestCase.field("f", DataType.KEYWORD)); + GroupingKey.Supplier key = cat.groupingKey(AbstractFunctionTestCase::evaluator); + + operators.add( + new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(INITIAL)), List.of(), 16 * 1024).get( + driverContext + ) + ); + operators.add( + new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(FINAL)), List.of(), 16 * 1024).get(driverContext) + ); + operators.add( + new TopNOperator( + 
driverContext.blockFactory(), + driverContext.breaker(), + 3, + List.of(ElementType.BYTES_REF), + List.of(TopNEncoder.UTF8), + List.of(new TopNOperator.SortOrder(0, true, true)), + 16 * 1024 + ) + ); + + Driver driver = new Driver( + driverContext, + new LocalSourceOperator(input), + operators, + new PageConsumerOperator(output::add), + () -> {} + ); + runDriver(driver); + + assertThat(output, hasSize(1)); + assertThat(output.get(0).getBlockCount(), equalTo(1)); + BytesRefBlock block = output.get(0).getBlock(0); + BytesRefVector vector = block.asVector(); + List values = new ArrayList<>(); + for (int p = 0; p < vector.getPositionCount(); p++) { + values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); + } + assertThat(values, equalTo(List.of( + "words words words goodbye .*", + "words words words goodbye blue sky .*", + "words words words hello .+"))); + } finally { + Releasables.close(() -> Iterators.map(output.iterator(), (Page p) -> p::releaseBlocks)); + } + } + + private final List breakers = Collections.synchronizedList(new ArrayList<>()); + + private DriverContext driverContext() { + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking(); + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + breakers.add(breaker); + return new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); + } + + @After + public void allMemoryReleased() { + for (CircuitBreaker breaker : breakers) { + assertThat(breaker.getUsed(), equalTo(0L)); + } + } + + public static void runDriver(Driver driver) { + ThreadPool threadPool = new TestThreadPool( + getTestClass().getSimpleName(), + new FixedExecutorBuilder(Settings.EMPTY, "esql", 1, 1024, "esql", EsExecutors.TaskTrackingConfig.DEFAULT) + ); + var driverRunner = new DriverRunner(threadPool.getThreadContext()) { + @Override + protected void start(Driver driver, ActionListener driverListener) { + 
Driver.start(threadPool.getThreadContext(), threadPool.executor("esql"), driver, between(1, 10000), driverListener); + } + }; + PlainActionFuture future = new PlainActionFuture<>(); + try { + driverRunner.runToCompletion(List.of(driver), future); + future.actionGet(TimeValue.timeValueSeconds(30)); + } finally { + terminate(threadPool); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java index d0088edcb0805..e1d3398e9d264 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java @@ -206,7 +206,7 @@ public TokenListCategory mergeWireCategory(SerializableTokenListCategory seriali return mergedCategory; } - private synchronized TokenListCategory computeCategory( + private synchronized TokenListCategory computeCategory( // NOCOMMIT why synchronized? 
in ESQL at least there are no threads around List weightedTokenIds, List workTokenUniqueIds, int workWeight, @@ -424,6 +424,13 @@ public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) { .toArray(InternalCategorizationAggregation.Bucket[]::new); } + public List toCategories(int size) { + return categoriesByNumMatches.stream() + .limit(size) + .map(category -> new SerializableTokenListCategory(category, bytesRefHash)) + .toList(); + } + public InternalCategorizationAggregation.Bucket[] toOrderedBuckets( int size, long minNumMatches, From 96e65056df945c32b051d01e04a4babc63ae0ce4 Mon Sep 17 00:00:00 2001 From: Jan Kuipers Date: Wed, 18 Sep 2024 13:58:15 +0200 Subject: [PATCH 06/10] Output correct regexes --- .../function/grouping/Categorize.java | 17 +---------------- .../grouping/CategorizeOperatorTests.java | 6 +++--- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 13fe4fc21ee86..e7af3fd9f8e77 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -324,22 +324,7 @@ private BytesRefVector finalBytes(BlockFactory blockFactory) { try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) { BytesRefBuilder scratch = new BytesRefBuilder(); for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) { - // NOCOMMIT build tokens properly - BytesRef[] tokens = category.getKeyTokens(); - if (tokens.length == 0) { - scratch.append((byte) '*'); - result.appendBytesRef(scratch.get()); - scratch.clear(); - continue; - } - scratch.append(tokens[0]); - for 
(int i = 1; i < tokens.length; i++) { - scratch.append((byte) ' '); - scratch.append(tokens[i]); - } - scratch.append((byte) ' '); - scratch.append((byte) '.'); - scratch.append((byte) '*'); + scratch.copyChars(category.getRegex()); result.appendBytesRef(scratch.get()); scratch.clear(); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java index 9171ba890e8db..046420c643502 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java @@ -116,9 +116,9 @@ public void testCategorization() { values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); } assertThat(values, equalTo(List.of( - "words words words goodbye .*", - "words words words goodbye blue sky .*", - "words words words hello .+"))); + ".*?words.+?words.+?words.+?goodbye.*?", + ".*?words.+?words.+?words.+?goodbye.+?blue.+?sky.*?", + ".*?words.+?words.+?words.+?hello.*?"))); } finally { Releasables.close(() -> Iterators.map(output.iterator(), (Page p) -> p::releaseBlocks)); } From f6ef350cfa23e1e4c29e87b5e347ce01f78f5a9e Mon Sep 17 00:00:00 2001 From: Jan Kuipers Date: Wed, 18 Sep 2024 13:59:00 +0200 Subject: [PATCH 07/10] Remap intermediate category IDs --- .../compute/aggregation/GroupingKey.java | 10 +++--- .../elasticsearch/compute/OperatorTests.java | 4 +-- .../function/grouping/Categorize.java | 34 +++++++++++++------ 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java index 
b7f581fe19ea6..b7bf7615c3835 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingKey.java @@ -20,13 +20,13 @@ import java.util.ArrayList; import java.util.List; -public record GroupingKey(AggregatorMode mode, Thing thing) implements EvalOperator.ExpressionEvaluator { +public record GroupingKey(AggregatorMode mode, Thing thing, BlockFactory blockFactory) implements EvalOperator.ExpressionEvaluator { public interface Thing extends Releasable { int extraIntermediateBlocks(); Block evalRawInput(Page page); - Block evalIntermediateInput(Page page); + Block evalIntermediateInput(BlockFactory blockFactory, Page page); void fetchIntermediateState(BlockFactory blockFactory, Block[] blocks, int positionCount); @@ -49,7 +49,7 @@ public static GroupingKey.Supplier forStatelessGrouping(int channel, ElementType return mode -> new Factory() { @Override public GroupingKey apply(DriverContext context, int resultOffset) { - return new GroupingKey(mode, new Load(channel, resultOffset)); + return new GroupingKey(mode, new Load(channel, resultOffset), context.blockFactory()); } @Override @@ -86,7 +86,7 @@ public static List toBlockHashGroupSpec(List idMap; if (intermediate.areAllValuesNull() == false) { - readIntermediate(intermediate.getBytesRef(0, new BytesRef())); + idMap = readIntermediate(intermediate.getBytesRef(0, new BytesRef())); + } else { + idMap = Collections.emptyMap(); + } + IntBlock oldIds = page.getBlock(resultOffset); + try (IntBlock.Builder newIds = blockFactory.newIntBlockBuilder(oldIds.getTotalValueCount())) { + for (int i = 0; i < oldIds.getTotalValueCount(); i++) { + newIds.appendInt(idMap.get(i)); + } + return newIds.build(); } - // NOCOMMIT this should remap the ints in the input block to whatever the *new* ids are - Block result = page.getBlock(resultOffset); - result.incRef(); - return result; } @Override @@ 
-300,13 +308,17 @@ private Block buildIntermediateBlock(BlockFactory blockFactory, int positionCoun } } - private void readIntermediate(BytesRef bytes) { + private Map readIntermediate(BytesRef bytes) { + Map idMap = new HashMap<>(); try (StreamInput in = new BytesArray(bytes).streamInput()) { int count = in.readVInt(); - for (int i = 0; i < count; i++) { + for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) { SerializableTokenListCategory category = new SerializableTokenListCategory(in); - categorizer.mergeWireCategory(category); + int newCategoryId = categorizer.mergeWireCategory(category).getId(); + System.err.println("category id map: " + oldCategoryId + " -> " + newCategoryId); + idMap.put(oldCategoryId, newCategoryId); } + return idMap; } catch (IOException e) { throw new RuntimeException(e); } From a3854418a44ac01a1c3c53325c916b6adebed947 Mon Sep 17 00:00:00 2001 From: Jan Kuipers Date: Wed, 18 Sep 2024 16:28:21 +0200 Subject: [PATCH 08/10] Fix mem leak --- .../xpack/esql/expression/function/grouping/Categorize.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 38576976ab60c..7fbf0c4f2e708 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -325,7 +325,6 @@ private Map readIntermediate(BytesRef bytes) { } private OrdinalBytesRefBlock finalKeys(BlockFactory blockFactory, IntBlock keys) { - keys.incRef(); return new OrdinalBytesRefBlock(keys, finalBytes(blockFactory)); } From 5a46823b5b997956f0ebd5c44bfcd92daa65d12d Mon Sep 17 00:00:00 2001 From: Jan Kuipers Date: Thu, 19 Sep 2024 11:48:25 +0200 Subject: [PATCH 09/10] spotless --- 
.../compute/operator/OrdinalsGroupingOperator.java | 7 +------ .../test/java/org/elasticsearch/compute/OperatorTests.java | 2 -- .../esql/expression/function/grouping/Categorize.java | 6 +++++- .../xpack/ml/aggs/categorization/TokenListCategorizer.java | 3 ++- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java index 8f50c0959d8ad..52fa93cc43927 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java @@ -505,12 +505,7 @@ private static class ValuesAggregator implements Releasable { // NOCOMMIT double check the mode GroupingKey.forStatelessGrouping(channelIndex, groupingElementType).get(AggregatorMode.INITIAL) ), - () -> BlockHash.build( - List.of(new GroupSpec(0, groupingElementType)), - driverContext.blockFactory(), - maxPageSize, - false - ), + () -> BlockHash.build(List.of(new GroupSpec(0, groupingElementType)), driverContext.blockFactory(), maxPageSize, false), driverContext ); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java index 17e92f66e4614..a1f0eb77ef2d4 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java @@ -7,8 +7,6 @@ package org.elasticsearch.compute; -import com.carrotsearch.randomizedtesting.annotations.Seed; - import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.LongPoint; diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 7fbf0c4f2e708..8f304f0deefe7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -210,7 +210,11 @@ public GroupingKey apply(DriverContext context, int resultOffset) { context ); field = null; - GroupingKey result = new GroupingKey(mode, new GroupingKeyThing(resultOffset, evaluator, categorizer), context.blockFactory()); + GroupingKey result = new GroupingKey( + mode, + new GroupingKeyThing(resultOffset, evaluator, categorizer), + context.blockFactory() + ); categorizer = null; evaluator = null; return result; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java index e1d3398e9d264..9295bff840c40 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java @@ -206,7 +206,8 @@ public TokenListCategory mergeWireCategory(SerializableTokenListCategory seriali return mergedCategory; } - private synchronized TokenListCategory computeCategory( // NOCOMMIT why synchronized? in ESQL at least there are no threads around + private synchronized TokenListCategory computeCategory( + // NOCOMMIT why synchronized? 
in ESQL at least there are no threads around List weightedTokenIds, List workTokenUniqueIds, int workWeight, From 7685a0ca3d63ebca5b1b9544e51744e31e732b58 Mon Sep 17 00:00:00 2001 From: Jan Kuipers Date: Thu, 19 Sep 2024 11:48:48 +0200 Subject: [PATCH 10/10] Test categorize operator on multiple nodes --- .../function/grouping/Categorize.java | 2 +- .../grouping/CategorizeOperatorTests.java | 166 +++++++++++++++++- 2 files changed, 160 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java index 8f304f0deefe7..77d8ef1780777 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java @@ -319,7 +319,7 @@ private Map readIntermediate(BytesRef bytes) { for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) { SerializableTokenListCategory category = new SerializableTokenListCategory(in); int newCategoryId = categorizer.mergeWireCategory(category).getId(); - System.err.println("category id map: " + oldCategoryId + " -> " + newCategoryId); + System.err.println("category id map: " + oldCategoryId + " -> " + newCategoryId + " (" + category.getRegex() + ")"); idMap.put(oldCategoryId, newCategoryId); } return idMap; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java index 046420c643502..1eabbe188e660 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java +++ 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeOperatorTests.java @@ -10,7 +10,6 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.Randomness; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.settings.Settings; @@ -33,6 +32,7 @@ import org.elasticsearch.compute.operator.LocalSourceOperator; import org.elasticsearch.compute.operator.Operator; import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.SourceOperator; import org.elasticsearch.compute.operator.topn.TopNEncoder; import org.elasticsearch.compute.operator.topn.TopNOperator; import org.elasticsearch.core.Releasables; @@ -48,9 +48,8 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.Iterator; import java.util.List; -import java.util.stream.LongStream; -import java.util.stream.Stream; import static org.elasticsearch.compute.aggregation.AggregatorMode.FINAL; import static org.elasticsearch.compute.aggregation.AggregatorMode.INITIAL; @@ -115,15 +114,168 @@ public void testCategorization() { for (int p = 0; p < vector.getPositionCount(); p++) { values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); } - assertThat(values, equalTo(List.of( - ".*?words.+?words.+?words.+?goodbye.*?", - ".*?words.+?words.+?words.+?goodbye.+?blue.+?sky.*?", - ".*?words.+?words.+?words.+?hello.*?"))); + assertThat( + values, + equalTo( + List.of( + ".*?words.+?words.+?words.+?goodbye.*?", + ".*?words.+?words.+?words.+?goodbye.+?blue.+?sky.*?", + ".*?words.+?words.+?words.+?hello.*?" 
+ ) + ) + ); } finally { Releasables.close(() -> Iterators.map(output.iterator(), (Page p) -> p::releaseBlocks)); } } + /** + * {@link SourceOperator} that returns a sequence of pre-built {@link Page}s. + * TODO: this class is copy-pasted from the esql:compute plugin; fix that + */ + public static class CannedSourceOperator extends SourceOperator { + + private final Iterator page; + + public CannedSourceOperator(Iterator page) { + this.page = page; + } + + @Override + public void finish() { + while (page.hasNext()) { + page.next(); + } + } + + @Override + public boolean isFinished() { + return false == page.hasNext(); + } + + @Override + public Page getOutput() { + return page.next(); + } + + @Override + public void close() { + // release pages in the case of early termination - failure + while (page.hasNext()) { + page.next().releaseBlocks(); + } + } + } + + public void testCategorization_multipleNodes() { + DriverContext driverContext = driverContext(); + LocalSourceOperator.BlockSupplier input1 = () -> { + try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) { + builder.appendBytesRef(new BytesRef("a")); + builder.appendBytesRef(new BytesRef("b")); + builder.appendBytesRef(new BytesRef("words words words goodbye jan")); + builder.appendBytesRef(new BytesRef("words words words goodbye nik")); + builder.appendBytesRef(new BytesRef("words words words hello jan")); + builder.appendBytesRef(new BytesRef("c")); + return new Block[] { builder.build().asBlock() }; + } + }; + + LocalSourceOperator.BlockSupplier input2 = () -> { + try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) { + builder.appendBytesRef(new BytesRef("words words words hello nik")); + builder.appendBytesRef(new BytesRef("c")); + builder.appendBytesRef(new BytesRef("words words words goodbye chris")); + builder.appendBytesRef(new BytesRef("d")); + builder.appendBytesRef(new BytesRef("e")); + return new Block[] { 
builder.build().asBlock() }; + } + }; + + List intermediateOutput = new ArrayList<>(); + List finalOutput = new ArrayList<>(); + + try { + Categorize cat = new Categorize(Source.EMPTY, AbstractFunctionTestCase.field("f", DataType.KEYWORD)); + GroupingKey.Supplier key = cat.groupingKey(AbstractFunctionTestCase::evaluator); + + Driver driver = new Driver( + driverContext, + new LocalSourceOperator(input1), + List.of( + new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(INITIAL)), List.of(), 16 * 1024).get( + driverContext + ) + ), + new PageConsumerOperator(intermediateOutput::add), + () -> {} + ); + runDriver(driver); + + driver = new Driver( + driverContext, + new LocalSourceOperator(input2), + List.of( + new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(INITIAL)), List.of(), 16 * 1024).get( + driverContext + ) + ), + new PageConsumerOperator(intermediateOutput::add), + () -> {} + ); + runDriver(driver); + + assertThat(intermediateOutput, hasSize(2)); + + driver = new Driver( + driverContext, + new CannedSourceOperator(intermediateOutput.iterator()), + List.of( + new HashAggregationOperator.HashAggregationOperatorFactory(List.of(key.get(FINAL)), List.of(), 16 * 1024).get( + driverContext + ), + new TopNOperator( + driverContext.blockFactory(), + driverContext.breaker(), + 10, + List.of(ElementType.BYTES_REF), + List.of(TopNEncoder.UTF8), + List.of(new TopNOperator.SortOrder(0, true, true)), + 16 * 1024 + ) + ), + new PageConsumerOperator(finalOutput::add), + () -> {} + ); + runDriver(driver); + + assertThat(finalOutput, hasSize(1)); + assertThat(finalOutput.get(0).getBlockCount(), equalTo(1)); + BytesRefBlock block = finalOutput.get(0).getBlock(0); + BytesRefVector vector = block.asVector(); + List values = new ArrayList<>(); + for (int p = 0; p < vector.getPositionCount(); p++) { + values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString()); + } + assertThat( + values, + equalTo( + List.of( + ".*?a.*?", + 
".*?b.*?", + ".*?c.*?", + ".*?d.*?", + ".*?e.*?", + ".*?words.+?words.+?words.+?goodbye.*?", + ".*?words.+?words.+?words.+?hello.*?" + ) + ) + ); + } finally { + Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks)); + } + } + private final List breakers = Collections.synchronizedList(new ArrayList<>()); private DriverContext driverContext() {