diff --git a/docs/changelog/127629.yaml b/docs/changelog/127629.yaml new file mode 100644 index 0000000000000..20ae5eebfb3a4 --- /dev/null +++ b/docs/changelog/127629.yaml @@ -0,0 +1,5 @@ +pr: 127629 +summary: ES|QL SAMPLE aggregation function +area: Machine Learning +type: feature +issues: [] diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/sample.md b/docs/reference/query-languages/esql/_snippets/functions/description/sample.md new file mode 100644 index 0000000000000..bc767a278a6cb --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/sample.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +**Description** + +Collects sample values for a field. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/sample.md b/docs/reference/query-languages/esql/_snippets/functions/examples/sample.md new file mode 100644 index 0000000000000..af31afc75413d --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/sample.md @@ -0,0 +1,14 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +FROM employees +| STATS sample = SAMPLE(gender, 5) +``` + +| sample:keyword | +| --- | +| [F, M, M, F, M] | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/sample.md b/docs/reference/query-languages/esql/_snippets/functions/layout/sample.md new file mode 100644 index 0000000000000..dd5acebcdb5d2 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/sample.md @@ -0,0 +1,23 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +## `SAMPLE` [esql-sample] + +**Syntax** + +:::{image} ../../../images/functions/sample.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/sample.md +::: + +:::{include} ../description/sample.md +::: + +:::{include} ../types/sample.md +::: + +:::{include} ../examples/sample.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/sample.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/sample.md new file mode 100644 index 0000000000000..12d4332371198 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/sample.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`field` +: The field to collect sample values for. + +`limit` +: The maximum number of values to collect. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/sample.md b/docs/reference/query-languages/esql/_snippets/functions/types/sample.md new file mode 100644 index 0000000000000..b45d32a607936 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/sample.md @@ -0,0 +1,21 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. 
+ +**Supported types** + +| field | limit | result | +| --- | --- | --- | +| boolean | integer | boolean | +| cartesian_point | integer | cartesian_point | +| cartesian_shape | integer | cartesian_shape | +| date | integer | date | +| date_nanos | integer | date_nanos | +| double | integer | double | +| geo_point | integer | geo_point | +| geo_shape | integer | geo_shape | +| integer | integer | integer | +| ip | integer | ip | +| keyword | integer | keyword | +| long | integer | long | +| text | integer | keyword | +| version | integer | version | + diff --git a/docs/reference/query-languages/esql/images/functions/sample.svg b/docs/reference/query-languages/esql/images/functions/sample.svg new file mode 100644 index 0000000000000..2870c29a6922c --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/sample.svg @@ -0,0 +1 @@ +SAMPLE(field,limit) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/sample.json b/docs/reference/query-languages/esql/kibana/definition/functions/sample.json new file mode 100644 index 0000000000000..df8ceadff07bf --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/sample.json @@ -0,0 +1,265 @@ +{ + "comment" : "This is generated by ESQL’s AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "sample", + "description" : "Collects sample values for a field.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "boolean", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field", + "type" : "cartesian_point", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "cartesian_point" + }, + { + "params" : [ + { + "name" : "field", + "type" : "cartesian_shape", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "cartesian_shape" + }, + { + "params" : [ + { + "name" : "field", + "type" : "date", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "date" + }, + { + "params" : [ + { + "name" : "field", + "type" : "date_nanos", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "date_nanos" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." 
+ } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "geo_point", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "geo_point" + }, + { + "params" : [ + { + "name" : "field", + "type" : "geo_shape", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "geo_shape" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "ip", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "ip" + }, + { + "params" : [ + { + "name" : "field", + "type" : "keyword", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "long" + }, + { + "params" : [ + { + "name" : "field", + "type" : "text", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "field", + "type" : "version", + "optional" : false, + "description" : "The field to collect sample values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + } + ], + "variadic" : false, + "returnType" : "version" + } + ], + "examples" : [ + "FROM employees\n| STATS sample = SAMPLE(gender, 5)" + ], + "preview" : false, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/sample.md b/docs/reference/query-languages/esql/kibana/docs/functions/sample.md new file mode 100644 index 0000000000000..78391f41c2e8c --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/sample.md @@ -0,0 +1,9 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +### SAMPLE +Collects sample values for a field. 
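+
+The generated example below samples without grouping. As a sketch (the
+`employees` index and its `last_name` and `gender` fields are assumed from
+the test data set), `SAMPLE` also works per group with `BY`:
+
+```esql
+FROM employees
+| STATS sample = SAMPLE(last_name, 2) BY gender
+```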
+ +```esql +FROM employees +| STATS sample = SAMPLE(gender, 5) +``` diff --git a/server/src/main/java/org/elasticsearch/common/hash/MurmurHash3.java b/server/src/main/java/org/elasticsearch/common/hash/MurmurHash3.java index 31c6401f32549..0abbb5d935156 100644 --- a/server/src/main/java/org/elasticsearch/common/hash/MurmurHash3.java +++ b/server/src/main/java/org/elasticsearch/common/hash/MurmurHash3.java @@ -81,7 +81,7 @@ static class IntermediateResult { private static long C1 = 0x87c37b91114253d5L; private static long C2 = 0x4cf5ad432745937fL; - protected static long fmix(long k) { + public static long fmix(long k) { k ^= k >>> 33; k *= 0xff51afd7ed558ccdL; k ^= k >>> 33; diff --git a/test/framework/src/main/java/org/elasticsearch/test/MixWithIncrement.java b/test/framework/src/main/java/org/elasticsearch/test/MixWithIncrement.java new file mode 100644 index 0000000000000..1d3f0c3cb06dd --- /dev/null +++ b/test/framework/src/main/java/org/elasticsearch/test/MixWithIncrement.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.test; + +import com.carrotsearch.randomizedtesting.SeedDecorator; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.hash.MurmurHash3; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * The {@link Randomness} class creates random generators with the same seed + * in every thread. + *

+ * <p>
+ * This means that repeatedly calling:
+ * <pre>
+ *   {@code
+ *     new Thread(() -> System.out.println(Randomness.get().nextInt())).start();
+ *   }
+ * </pre>
+ * will print the same number in every thread.
+ * <p>
+ * For some use cases, this is not desirable, e.g. when testing that the random
+ * behavior obeys certain statistical properties.
+ * <p>
+ * To fix this, annotate a test class with:
+ * <pre>
+ *   {@code
+ *     @SeedDecorators(MixWithIncrement.class)
+ *   }
+ * </pre>
+ * In this way, an additional seed is mixed into the seed of the random generators.
+ * This additional seed can be updated by calling:
+ * <pre>
+ *   {@code
+ *     MixWithIncrement.next()
+ *   }
+ * </pre>
+ * to make sure that new threads will get a different seed. + */ +public class MixWithIncrement implements SeedDecorator { + + private static final AtomicLong mix = new AtomicLong(1); + + @Override + public void initialize(Class aClass) { + next(); + } + + public long decorate(long seed) { + return seed ^ mix.get(); + } + + public static void next() { + mix.updateAndGet(MurmurHash3::fmix); + } +} diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 1655a07110d96..66867ae668fcc 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -649,6 +649,33 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/aggregation/StdDevDoubleAggregator.java" } + File sampleAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st") + template { + it.properties = booleanProperties + it.inputFile = sampleAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java" + } + template { + it.properties = bytesRefProperties + it.inputFile = sampleAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java" + } + template { + it.properties = doubleProperties + it.inputFile = sampleAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java" + } + template { + it.properties = intProperties + it.inputFile = sampleAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/SampleIntAggregator.java" + } + template { + it.properties = longProperties + it.inputFile = sampleAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/SampleLongAggregator.java" + } + File topAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st") template { it.properties = intProperties diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java new file mode 100644 index 0000000000000..e0388c216ae56 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBooleanAggregator.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.topn.DefaultUnsortableTopNEncoder; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +import org.elasticsearch.common.Randomness; +import java.util.random.RandomGenerator; +// end generated imports + +/** + * Sample N field values for boolean. + *

+ * <p>
+ *     This class is generated. Edit `X-SampleAggregator.java.st` to edit this file.
+ * </p>
+ * <p>
+ *     This works by prepending a random long to the value, and then collecting the
+ *     top values. This gives a uniform random sample of the values. See also:
+ *     <a href="https://en.wikipedia.org/wiki/Reservoir_sampling#With_random_sort">Wikipedia Reservoir Sampling</a>
+ * </p>
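+ * <p>
+ *     A hypothetical worked example of the random-sort idea (the values and keys
+ *     are invented, and the real comparison runs over the encoded key bytes):
+ * </p>
+ * <pre>
+ *   {@code
+ *     values:       a    b    c
+ *     random keys:  17   3    42    // the prepended random longs
+ *     // with limit = 2, the two entries with the smallest keys survive,
+ *     // giving the sample [b, a]; i.i.d. keys make every 2-element
+ *     // subset of the values equally likely
+ *   }
+ * </pre>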
+ */ +@Aggregator({ @IntermediateState(name = "sample", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class SampleBooleanAggregator { + private static final DefaultUnsortableTopNEncoder ENCODER = new DefaultUnsortableTopNEncoder(); + + public static SingleState initSingle(BigArrays bigArrays, int limit) { + return new SingleState(bigArrays, limit); + } + + public static void combine(SingleState state, boolean value) { + state.add(value); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.internalState.sort.collect(values.getBytesRef(i, scratch), 0); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory())); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit) { + return new GroupingState(bigArrays, limit); + } + + public static void combine(GroupingState state, int groupId, boolean value) { + state.add(groupId, value); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.sort.collect(values.getBytesRef(i, scratch), groupId); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); + } + + private static Block stripWeights(DriverContext driverContext, Block block) { + if (block.areAllValuesNull()) { + return block; + } + BytesRefBlock bytesRefBlock = (BytesRefBlock) block; + try (BooleanBlock.Builder booleanBlock = driverContext.blockFactory().newBooleanBlockBuilder(bytesRefBlock.getPositionCount())) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (bytesRefBlock.isNull(position)) { + booleanBlock.appendNull(); + } else { + int valueCount = bytesRefBlock.getValueCount(position); + if (valueCount > 1) { + booleanBlock.beginPositionEntry(); + } + int start = bytesRefBlock.getFirstValueIndex(position); + int end = start + valueCount; + for (int i = start; i < end; i++) { + BytesRef value = bytesRefBlock.getBytesRef(i, scratch).clone(); + ENCODER.decodeLong(value); + booleanBlock.appendBoolean(ENCODER.decodeBoolean(value)); + } + if (valueCount > 1) { + booleanBlock.endPositionEntry(); + } + } + } + block.close(); + return booleanBlock.build(); + } + } + + public static class GroupingState implements GroupingAggregatorState { + private final BytesRefBucketedSort sort; + private final BreakingBytesRefBuilder bytesRefBuilder; + + private GroupingState(BigArrays bigArrays, int limit) { + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "sample", bigArrays, SortOrder.ASC, limit); + boolean success = false; + try { + this.bytesRefBuilder = new BreakingBytesRefBuilder(breaker, "sample"); + success = true; + } 
finally { + if (success == false) { + Releasables.closeExpectNoException(sort); + } + } + } + + public void add(int groupId, boolean value) { + ENCODER.encodeLong(Randomness.get().nextLong(), bytesRefBuilder); + ENCODER.encodeBoolean(value, bytesRefBuilder); + sort.collect(bytesRefBuilder.bytesRefView(), groupId); + bytesRefBuilder.clear(); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort, bytesRefBuilder); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit) { + this.internalState = new GroupingState(bigArrays, limit); + } + + public void add(boolean value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java new file mode 100644 index 0000000000000..2291e759d9677 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleBytesRefAggregator.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.topn.DefaultUnsortableTopNEncoder; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +import org.elasticsearch.common.Randomness; +import java.util.random.RandomGenerator; +// end generated imports + +/** + * Sample N field values for BytesRef. + *

+ * <p>
+ *     This class is generated. Edit `X-SampleAggregator.java.st` to edit this file.
+ * </p>
+ * <p>
+ *     This works by prepending a random long to the value, and then collecting the
+ *     top values. This gives a uniform random sample of the values. See also:
+ *     <a href="https://en.wikipedia.org/wiki/Reservoir_sampling#With_random_sort">Wikipedia Reservoir Sampling</a>
+ * </p>
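+ * <p>
+ *     A conceptual sketch of the per-entry intermediate state (not the exact
+ *     byte format produced by {@code DefaultUnsortableTopNEncoder}):
+ * </p>
+ * <pre>
+ *   {@code
+ *     entry = encodeLong(randomKey) ++ encodeBytesRef(value)
+ *     // the random key is the prefix, so ranking entries byte-wise ranks them
+ *     // by key bytes; stripWeights() later decodes and discards the key,
+ *     // returning only the sampled values
+ *   }
+ * </pre>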
+ */ +@Aggregator({ @IntermediateState(name = "sample", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class SampleBytesRefAggregator { + private static final DefaultUnsortableTopNEncoder ENCODER = new DefaultUnsortableTopNEncoder(); + + public static SingleState initSingle(BigArrays bigArrays, int limit) { + return new SingleState(bigArrays, limit); + } + + public static void combine(SingleState state, BytesRef value) { + state.add(value); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.internalState.sort.collect(values.getBytesRef(i, scratch), 0); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory())); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit) { + return new GroupingState(bigArrays, limit); + } + + public static void combine(GroupingState state, int groupId, BytesRef value) { + state.add(groupId, value); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.sort.collect(values.getBytesRef(i, scratch), groupId); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); + } + + private static Block stripWeights(DriverContext driverContext, Block block) { + if (block.areAllValuesNull()) { + return block; + } + BytesRefBlock bytesRefBlock = (BytesRefBlock) block; + try (BytesRefBlock.Builder BytesRefBlock = driverContext.blockFactory().newBytesRefBlockBuilder(bytesRefBlock.getPositionCount())) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (bytesRefBlock.isNull(position)) { + BytesRefBlock.appendNull(); + } else { + int valueCount = bytesRefBlock.getValueCount(position); + if (valueCount > 1) { + BytesRefBlock.beginPositionEntry(); + } + int start = bytesRefBlock.getFirstValueIndex(position); + int end = start + valueCount; + for (int i = start; i < end; i++) { + BytesRef value = bytesRefBlock.getBytesRef(i, scratch).clone(); + ENCODER.decodeLong(value); + BytesRefBlock.appendBytesRef(ENCODER.decodeBytesRef(value, scratch)); + } + if (valueCount > 1) { + BytesRefBlock.endPositionEntry(); + } + } + } + block.close(); + return BytesRefBlock.build(); + } + } + + public static class GroupingState implements GroupingAggregatorState { + private final BytesRefBucketedSort sort; + private final BreakingBytesRefBuilder bytesRefBuilder; + + private GroupingState(BigArrays bigArrays, int limit) { + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "sample", bigArrays, SortOrder.ASC, limit); + boolean success = false; + try { + this.bytesRefBuilder = new BreakingBytesRefBuilder(breaker, "sample"); 
+ success = true; + } finally { + if (success == false) { + Releasables.closeExpectNoException(sort); + } + } + } + + public void add(int groupId, BytesRef value) { + ENCODER.encodeLong(Randomness.get().nextLong(), bytesRefBuilder); + ENCODER.encodeBytesRef(value, bytesRefBuilder); + sort.collect(bytesRefBuilder.bytesRefView(), groupId); + bytesRefBuilder.clear(); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort, bytesRefBuilder); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit) { + this.internalState = new GroupingState(bigArrays, limit); + } + + public void add(BytesRef value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java new file mode 100644 index 0000000000000..6a5a33bb06255 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleDoubleAggregator.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.topn.DefaultUnsortableTopNEncoder; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +import org.elasticsearch.common.Randomness; +import java.util.random.RandomGenerator; +// end generated imports + +/** + * Sample N field values for double. + *

+ * <p>
+ *     This class is generated. Edit `X-SampleAggregator.java.st` to edit this file.
+ * </p>
+ * <p>
+ *     This works by prepending a random long to the value, and then collecting the
+ *     top values. This gives a uniform random sample of the values. See also:
+ *     <a href="https://en.wikipedia.org/wiki/Reservoir_sampling#With_random_sort">Wikipedia Reservoir Sampling</a>
+ * </p>
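+ * <p>
+ *     A sketch (with invented keys) of why merging intermediate states composes:
+ *     if one shard kept entries with random keys {@code {3, 17}} and another kept
+ *     {@code {5, 8}}, then with {@code limit = 2} the merge retains {@code {3, 5}},
+ *     exactly the entries a single pass over the combined input would have kept.
+ * </p>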
+ */ +@Aggregator({ @IntermediateState(name = "sample", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class SampleDoubleAggregator { + private static final DefaultUnsortableTopNEncoder ENCODER = new DefaultUnsortableTopNEncoder(); + + public static SingleState initSingle(BigArrays bigArrays, int limit) { + return new SingleState(bigArrays, limit); + } + + public static void combine(SingleState state, double value) { + state.add(value); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.internalState.sort.collect(values.getBytesRef(i, scratch), 0); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory())); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit) { + return new GroupingState(bigArrays, limit); + } + + public static void combine(GroupingState state, int groupId, double value) { + state.add(groupId, value); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.sort.collect(values.getBytesRef(i, scratch), groupId); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); + } + + private static Block stripWeights(DriverContext driverContext, Block block) { + if (block.areAllValuesNull()) { + return block; + } + BytesRefBlock bytesRefBlock = (BytesRefBlock) block; + try (DoubleBlock.Builder doubleBlock = driverContext.blockFactory().newDoubleBlockBuilder(bytesRefBlock.getPositionCount())) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (bytesRefBlock.isNull(position)) { + doubleBlock.appendNull(); + } else { + int valueCount = bytesRefBlock.getValueCount(position); + if (valueCount > 1) { + doubleBlock.beginPositionEntry(); + } + int start = bytesRefBlock.getFirstValueIndex(position); + int end = start + valueCount; + for (int i = start; i < end; i++) { + BytesRef value = bytesRefBlock.getBytesRef(i, scratch).clone(); + ENCODER.decodeLong(value); + doubleBlock.appendDouble(ENCODER.decodeDouble(value)); + } + if (valueCount > 1) { + doubleBlock.endPositionEntry(); + } + } + } + block.close(); + return doubleBlock.build(); + } + } + + public static class GroupingState implements GroupingAggregatorState { + private final BytesRefBucketedSort sort; + private final BreakingBytesRefBuilder bytesRefBuilder; + + private GroupingState(BigArrays bigArrays, int limit) { + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "sample", bigArrays, SortOrder.ASC, limit); + boolean success = false; + try { + this.bytesRefBuilder = new BreakingBytesRefBuilder(breaker, "sample"); + success = true; + } finally { + 
if (success == false) { + Releasables.closeExpectNoException(sort); + } + } + } + + public void add(int groupId, double value) { + ENCODER.encodeLong(Randomness.get().nextLong(), bytesRefBuilder); + ENCODER.encodeDouble(value, bytesRefBuilder); + sort.collect(bytesRefBuilder.bytesRefView(), groupId); + bytesRefBuilder.clear(); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort, bytesRefBuilder); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit) { + this.internalState = new GroupingState(bigArrays, limit); + } + + public void add(double value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java new file mode 100644 index 0000000000000..762367387bc27 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleIntAggregator.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.topn.DefaultUnsortableTopNEncoder; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +import org.elasticsearch.common.Randomness; +import java.util.random.RandomGenerator; +// end generated imports + +/** + * Sample N field values for int. + *

+ * <p>
+ *     This class is generated. Edit `X-SampleAggregator.java.st` to edit this file.
+ * </p>
+ * <p>
+ *     This works by prepending a random long to the value, and then collecting the
+ *     top values. This gives a uniform random sample of the values. See also:
+ *     <a href="https://en.wikipedia.org/wiki/Reservoir_sampling#With_random_sort">Wikipedia Reservoir Sampling</a>
+ * </p>
+ */ +@Aggregator({ @IntermediateState(name = "sample", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class SampleIntAggregator { + private static final DefaultUnsortableTopNEncoder ENCODER = new DefaultUnsortableTopNEncoder(); + + public static SingleState initSingle(BigArrays bigArrays, int limit) { + return new SingleState(bigArrays, limit); + } + + public static void combine(SingleState state, int value) { + state.add(value); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.internalState.sort.collect(values.getBytesRef(i, scratch), 0); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory())); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit) { + return new GroupingState(bigArrays, limit); + } + + public static void combine(GroupingState state, int groupId, int value) { + state.add(groupId, value); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.sort.collect(values.getBytesRef(i, scratch), groupId); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); + } + + private static Block stripWeights(DriverContext driverContext, Block block) { + if (block.areAllValuesNull()) { + return block; + } + BytesRefBlock bytesRefBlock = (BytesRefBlock) block; + try (IntBlock.Builder intBlock = driverContext.blockFactory().newIntBlockBuilder(bytesRefBlock.getPositionCount())) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (bytesRefBlock.isNull(position)) { + intBlock.appendNull(); + } else { + int valueCount = bytesRefBlock.getValueCount(position); + if (valueCount > 1) { + intBlock.beginPositionEntry(); + } + int start = bytesRefBlock.getFirstValueIndex(position); + int end = start + valueCount; + for (int i = start; i < end; i++) { + BytesRef value = bytesRefBlock.getBytesRef(i, scratch).clone(); + ENCODER.decodeLong(value); + intBlock.appendInt(ENCODER.decodeInt(value)); + } + if (valueCount > 1) { + intBlock.endPositionEntry(); + } + } + } + block.close(); + return intBlock.build(); + } + } + + public static class GroupingState implements GroupingAggregatorState { + private final BytesRefBucketedSort sort; + private final BreakingBytesRefBuilder bytesRefBuilder; + + private GroupingState(BigArrays bigArrays, int limit) { + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "sample", bigArrays, SortOrder.ASC, limit); + boolean success = false; + try { + this.bytesRefBuilder = new BreakingBytesRefBuilder(breaker, "sample"); + success = true; + } finally { + if (success == false) { + 
Releasables.closeExpectNoException(sort); + } + } + } + + public void add(int groupId, int value) { + ENCODER.encodeLong(Randomness.get().nextLong(), bytesRefBuilder); + ENCODER.encodeInt(value, bytesRefBuilder); + sort.collect(bytesRefBuilder.bytesRefView(), groupId); + bytesRefBuilder.clear(); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort, bytesRefBuilder); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit) { + this.internalState = new GroupingState(bigArrays, limit); + } + + public void add(int value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java new file mode 100644 index 0000000000000..1cb5931575513 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/SampleLongAggregator.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.topn.DefaultUnsortableTopNEncoder; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +import org.elasticsearch.common.Randomness; +import java.util.random.RandomGenerator; +// end generated imports + +/** + * Sample N field values for long. + *

+ * <p>
+ *     This class is generated. Edit `X-SampleAggregator.java.st` to edit this file.
+ * </p>
+ * <p>
+ *     This works by prepending a random long to the value, and then collecting the
+ *     top values. This gives a uniform random sample of the values. See also:
+ *     <a href="https://en.wikipedia.org/wiki/Reservoir_sampling#With_random_sort">Wikipedia Reservoir Sampling</a>
+ * </p>
+ */ +@Aggregator({ @IntermediateState(name = "sample", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class SampleLongAggregator { + private static final DefaultUnsortableTopNEncoder ENCODER = new DefaultUnsortableTopNEncoder(); + + public static SingleState initSingle(BigArrays bigArrays, int limit) { + return new SingleState(bigArrays, limit); + } + + public static void combine(SingleState state, long value) { + state.add(value); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.internalState.sort.collect(values.getBytesRef(i, scratch), 0); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory())); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit) { + return new GroupingState(bigArrays, limit); + } + + public static void combine(GroupingState state, int groupId, long value) { + state.add(groupId, value); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.sort.collect(values.getBytesRef(i, scratch), groupId); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); + } + + private static Block stripWeights(DriverContext driverContext, Block block) { + if (block.areAllValuesNull()) { + return block; + } + BytesRefBlock bytesRefBlock = (BytesRefBlock) block; + try (LongBlock.Builder longBlock = driverContext.blockFactory().newLongBlockBuilder(bytesRefBlock.getPositionCount())) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (bytesRefBlock.isNull(position)) { + longBlock.appendNull(); + } else { + int valueCount = bytesRefBlock.getValueCount(position); + if (valueCount > 1) { + longBlock.beginPositionEntry(); + } + int start = bytesRefBlock.getFirstValueIndex(position); + int end = start + valueCount; + for (int i = start; i < end; i++) { + BytesRef value = bytesRefBlock.getBytesRef(i, scratch).clone(); + ENCODER.decodeLong(value); + longBlock.appendLong(ENCODER.decodeLong(value)); + } + if (valueCount > 1) { + longBlock.endPositionEntry(); + } + } + } + block.close(); + return longBlock.build(); + } + } + + public static class GroupingState implements GroupingAggregatorState { + private final BytesRefBucketedSort sort; + private final BreakingBytesRefBuilder bytesRefBuilder; + + private GroupingState(BigArrays bigArrays, int limit) { + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "sample", bigArrays, SortOrder.ASC, limit); + boolean success = false; + try { + this.bytesRefBuilder = new BreakingBytesRefBuilder(breaker, "sample"); + success = true; + } finally { + if (success == false) { + 
Releasables.closeExpectNoException(sort); + } + } + } + + public void add(int groupId, long value) { + ENCODER.encodeLong(Randomness.get().nextLong(), bytesRefBuilder); + ENCODER.encodeLong(value, bytesRefBuilder); + sort.collect(bytesRefBuilder.bytesRefView(), groupId); + bytesRefBuilder.clear(); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort, bytesRefBuilder); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit) { + this.internalState = new GroupingState(bigArrays, limit); + } + + public void add(long value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java index b9ee302f45b24..c0e299d57f6bb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeDoubleGroupingAggregatorFunction.java @@ -13,7 +13,8 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -72,7 +73,12 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); } @@ -88,7 +94,12 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int 
positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); } @@ -103,31 +114,45 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { + if (groups.isNull(groupPosition)) { continue; } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v)); + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v)); + } } } } - private void addRawInput(int positionOffset, IntVector groups, DoubleVector values, + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - var valuePosition = groupPosition + positionOffset; - FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition)); + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition)); + } } } - private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -149,7 +174,7 @@ private void addRawInput(int positionOffset, IntBlock groups, DoubleBlock values } } - private void addRawInput(int positionOffset, IntBlock groups, DoubleVector values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -165,6 +190,30 @@ private void addRawInput(int positionOffset, IntBlock groups, 
DoubleVector value } } + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(v), values.getDouble(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeDoubleAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getDouble(valuePosition)); + } + } + @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java index ad3f37cd22a00..df4b6c843ff75 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeFloatGroupingAggregatorFunction.java @@ -13,7 +13,8 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.FloatVector; -import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -72,7 +73,12 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); } @@ -88,7 +94,12 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); } @@ -103,31 +114,45 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { + if (groups.isNull(groupPosition)) { continue; } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v)); + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v)); + } } } } - private void addRawInput(int positionOffset, IntVector groups, FloatVector values, + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - var valuePosition = groupPosition + positionOffset; - FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition)); + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition)); + } } } - private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -149,7 +174,7 @@ private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values, } } - private void addRawInput(int positionOffset, IntBlock groups, FloatVector values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -165,6 +190,30 @@ private void addRawInput(int positionOffset, IntBlock groups, FloatVector values } } + private void addRawInput(int positionOffset, IntVector groups, FloatBlock values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + 
FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(v), values.getFloat(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeFloatAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getFloat(valuePosition)); + } + } + @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java index 9253aa51831b2..d0252f8b420d0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeIntGroupingAggregatorFunction.java @@ -11,6 +11,8 @@ import java.util.List; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; @@ -70,7 +72,12 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); } @@ -86,7 +93,12 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); } @@ -101,31 +113,45 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntVector groups, IntBlock values, + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { + if (groups.isNull(groupPosition)) { continue; } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v)); + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + 
groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v)); + } } } } - private void addRawInput(int positionOffset, IntVector groups, IntVector values, + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - var valuePosition = groupPosition + positionOffset; - FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition)); + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition)); + } } } - private void addRawInput(int positionOffset, IntBlock groups, IntBlock values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -147,7 +173,7 @@ private void addRawInput(int positionOffset, IntBlock groups, IntBlock values, } } - private void addRawInput(int positionOffset, IntBlock groups, IntVector values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -163,6 +189,30 @@ private void addRawInput(int positionOffset, IntBlock groups, IntVector values, } } + private void addRawInput(int positionOffset, IntVector groups, IntBlock values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(v), values.getInt(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeIntAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getInt(valuePosition)); + } + } + @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git 
a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java index e5a372c767b73..8506d1e8d527b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/FirstOverTimeLongGroupingAggregatorFunction.java @@ -11,7 +11,8 @@ import java.util.List; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.LongVector; @@ -70,7 +71,12 @@ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenG } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesBlock, timestampsVector); } @@ -86,7 +92,12 @@ public void close() { } return new GroupingAggregatorFunction.AddInput() { @Override - public void add(int positionOffset, IntBlock groupIds) { + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { addRawInput(positionOffset, groupIds, valuesVector, timestampsVector); } @@ -101,31 +112,45 @@ public void close() { }; } - private void addRawInput(int positionOffset, IntVector groups, LongBlock values, + private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - if (values.isNull(groupPosition + positionOffset)) { + if (groups.isNull(groupPosition)) { continue; } - int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); - int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); - for (int v = valuesStart; v < valuesEnd; v++) { - FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v)); + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v)); + } } } } - private void addRawInput(int positionOffset, IntVector groups, LongVector values, + private void addRawInput(int positionOffset, IntArrayBlock 
groups, LongVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { - int groupId = groups.getInt(groupPosition); - var valuePosition = groupPosition + positionOffset; - FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition)); + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition)); + } } } - private void addRawInput(int positionOffset, IntBlock groups, LongBlock values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -147,7 +172,7 @@ private void addRawInput(int positionOffset, IntBlock groups, LongBlock values, } } - private void addRawInput(int positionOffset, IntBlock groups, LongVector values, + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values, LongVector timestamps) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { if (groups.isNull(groupPosition)) { @@ -163,6 +188,30 @@ private void addRawInput(int positionOffset, IntBlock groups, LongVector values, } } + private void addRawInput(int positionOffset, IntVector groups, LongBlock values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(v), values.getLong(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values, + LongVector timestamps) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + var valuePosition = groupPosition + positionOffset; + FirstOverTimeLongAggregator.combine(state, groupId, timestamps.getLong(valuePosition), values.getLong(valuePosition)); + } + } + @Override public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { state.enableGroupIdTracking(seenGroupIds); diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunction.java new file mode 100644 index 0000000000000..45a6a2d060813 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunction.java @@ -0,0 +1,167 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link SampleBooleanAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class SampleBooleanAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final DriverContext driverContext; + + private final SampleBooleanAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + public SampleBooleanAggregatorFunction(DriverContext driverContext, List<Integer> channels, + SampleBooleanAggregator.SingleState state, int limit) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + } + + public static SampleBooleanAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit) { + return new SampleBooleanAggregatorFunction(driverContext, channels, SampleBooleanAggregator.initSingle(driverContext.bigArrays(), limit), limit); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + BooleanBlock block = page.getBlock(channels.get(0)); + BooleanVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + BooleanBlock block = page.getBlock(channels.get(0)); + BooleanVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(BooleanVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + SampleBooleanAggregator.combine(state, vector.getBoolean(i)); + } + } + + private void addRawVector(BooleanVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + SampleBooleanAggregator.combine(state, vector.getBoolean(i)); + } + } + + private void addRawBlock(BooleanBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleBooleanAggregator.combine(state, block.getBoolean(i)); + } + } + } + + private void addRawBlock(BooleanBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { +
continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleBooleanAggregator.combine(state, block.getBoolean(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + assert sample.getPositionCount() == 1; + BytesRef scratch = new BytesRef(); + SampleBooleanAggregator.combineIntermediate(state, sample); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = SampleBooleanAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..13c65284663fd --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunctionSupplier.java @@ -0,0 +1,50 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link SampleBooleanAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
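+ * <p>Illustrative sketch only, not generated output: the supplier carries the {@code limit} and hands out per-driver functions. The single-element channel list below is an assumed wiring; real channel indices come from the physical plan. + * <pre>{@code + * var supplier = new SampleBooleanAggregatorFunctionSupplier(10); + * SampleBooleanAggregatorFunction agg = supplier.aggregator(driverContext, List.of(0)); + * }</pre>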
+ */ +public final class SampleBooleanAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + public SampleBooleanAggregatorFunctionSupplier(int limit) { + this.limit = limit; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return SampleBooleanAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return SampleBooleanGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public SampleBooleanAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return SampleBooleanAggregatorFunction.create(driverContext, channels, limit); + } + + @Override + public SampleBooleanGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return SampleBooleanGroupingAggregatorFunction.create(channels, driverContext, limit); + } + + @Override + public String describe() { + return "sample of booleans"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..cec8ea8b6c21a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBooleanGroupingAggregatorFunction.java @@ -0,0 +1,260 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link SampleBooleanAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead.
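+ * <p>Note (editorial gloss on the generated shape): group ids arrive in one of three forms, matching the three {@code AddInput#add} overloads below. A dense {@link IntVector} maps exactly one group to each position, while {@link IntArrayBlock} and {@link IntBigArrayBlock} positions may be null or multivalued and are therefore walked with {@code getFirstValueIndex}/{@code getValueCount} so each value is combined into every group at that position.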
+ */ +public final class SampleBooleanGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final SampleBooleanAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + public SampleBooleanGroupingAggregatorFunction(List<Integer> channels, + SampleBooleanAggregator.GroupingState state, DriverContext driverContext, int limit) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + } + + public static SampleBooleanGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit) { + return new SampleBooleanGroupingAggregatorFunction(channels, SampleBooleanAggregator.initGrouping(driverContext.bigArrays(), limit), driverContext, limit); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + BooleanBlock valuesBlock = page.getBlock(channels.get(0)); + BooleanVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleBooleanAggregator.combine(state, groupId, values.getBoolean(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleBooleanAggregator.combine(state, groupId, values.getBoolean(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleBooleanAggregator.combine(state, groupId, values.getBoolean(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset)); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleBooleanAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + + @Override + public void 
addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + SampleBooleanAggregator.GroupingState inState = ((SampleBooleanGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + SampleBooleanAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext evaluatorContext) { + blocks[offset] = SampleBooleanAggregator.evaluateFinal(state, selected, evaluatorContext.driverContext()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunction.java new file mode 100644 index 0000000000000..9b3a7718d9898 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunction.java @@ -0,0 +1,171 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link SampleBytesRefAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
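+ * <p>Note (editorial gloss on the generated shape): {@code addRawInput} short-circuits on the position mask, skipping an all-false page outright and taking the unmasked loops for an all-true mask, so only a mixed mask pays a per-position {@code mask.getBoolean} check; a single {@link BytesRef} scratch instance is reused across {@code getBytesRef} calls to avoid per-value allocation.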
+ */ +public final class SampleBytesRefAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final DriverContext driverContext; + + private final SampleBytesRefAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + public SampleBytesRefAggregatorFunction(DriverContext driverContext, List<Integer> channels, + SampleBytesRefAggregator.SingleState state, int limit) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + } + + public static SampleBytesRefAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit) { + return new SampleBytesRefAggregatorFunction(driverContext, channels, SampleBytesRefAggregator.initSingle(driverContext.bigArrays(), limit), limit); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + BytesRefBlock block = page.getBlock(channels.get(0)); + BytesRefVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + BytesRefBlock block = page.getBlock(channels.get(0)); + BytesRefVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(BytesRefVector vector) { + BytesRef scratch = new BytesRef(); + for (int i = 0; i < vector.getPositionCount(); i++) { + SampleBytesRefAggregator.combine(state, vector.getBytesRef(i, scratch)); + } + } + + private void addRawVector(BytesRefVector vector, BooleanVector mask) { + BytesRef scratch = new BytesRef(); + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + SampleBytesRefAggregator.combine(state, vector.getBytesRef(i, scratch)); + } + } + + private void addRawBlock(BytesRefBlock block) { + BytesRef scratch = new BytesRef(); + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleBytesRefAggregator.combine(state, block.getBytesRef(i, scratch)); + } + } + } + + private void addRawBlock(BytesRefBlock block, BooleanVector mask) { + BytesRef scratch = new BytesRef(); + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleBytesRefAggregator.combine(state, block.getBytesRef(i, scratch)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + assert sample.getPositionCount() == 1; + BytesRef
scratch = new BytesRef(); + SampleBytesRefAggregator.combineIntermediate(state, sample); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = SampleBytesRefAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..26a4fdefb4236 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunctionSupplier.java @@ -0,0 +1,50 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link SampleBytesRefAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class SampleBytesRefAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + public SampleBytesRefAggregatorFunctionSupplier(int limit) { + this.limit = limit; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return SampleBytesRefAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return SampleBytesRefGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public SampleBytesRefAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return SampleBytesRefAggregatorFunction.create(driverContext, channels, limit); + } + + @Override + public SampleBytesRefGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return SampleBytesRefGroupingAggregatorFunction.create(channels, driverContext, limit); + } + + @Override + public String describe() { + return "sample of bytes"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..60e38edd06d1f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleBytesRefGroupingAggregatorFunction.java @@ -0,0 +1,265 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements.
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link SampleBytesRefAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class SampleBytesRefGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final SampleBytesRefAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + public SampleBytesRefGroupingAggregatorFunction(List<Integer> channels, + SampleBytesRefAggregator.GroupingState state, DriverContext driverContext, int limit) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + } + + public static SampleBytesRefGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit) { + return new SampleBytesRefGroupingAggregatorFunction(channels, SampleBytesRefAggregator.initGrouping(driverContext.bigArrays(), limit), driverContext, limit); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); + BytesRefVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void
addRawInput(int positionOffset, IntArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, BytesRefVector values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefVector values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + 
SampleBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleBytesRefAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + SampleBytesRefAggregator.GroupingState inState = ((SampleBytesRefGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + SampleBytesRefAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext evaluatorContext) { + blocks[offset] = SampleBytesRefAggregator.evaluateFinal(state, selected, evaluatorContext.driverContext()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..b308b4cff4860 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunction.java @@ -0,0 +1,168 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link SampleDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class SampleDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final DriverContext driverContext; + + private final SampleDoubleAggregator.SingleState state; + + private final List channels; + + private final int limit; + + public SampleDoubleAggregatorFunction(DriverContext driverContext, List channels, + SampleDoubleAggregator.SingleState state, int limit) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + } + + public static SampleDoubleAggregatorFunction create(DriverContext driverContext, + List channels, int limit) { + return new SampleDoubleAggregatorFunction(driverContext, channels, SampleDoubleAggregator.initSingle(driverContext.bigArrays(), limit), limit); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + DoubleBlock block = page.getBlock(channels.get(0)); + DoubleVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(DoubleVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + SampleDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawVector(DoubleVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + SampleDoubleAggregator.combine(state, vector.getDouble(i)); + } + } + + private void addRawBlock(DoubleBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + private void addRawBlock(DoubleBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + 
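// a position can be multi-valued; [start, end) spans every value stored at position p +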
int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleDoubleAggregator.combine(state, block.getDouble(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + assert sample.getPositionCount() == 1; + BytesRef scratch = new BytesRef(); + SampleDoubleAggregator.combineIntermediate(state, sample); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = SampleDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..4c98a3a6fac32 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,50 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link SampleDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
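+ * <p>
+ * The {@code limit} passed to the constructor caps the sample size; for example, {@code SAMPLE(field, 5)}
+ * would correspond to {@code limit = 5} (an illustrative mapping, not spelled out in this file).
+ * </p>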
+ */ +public final class SampleDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + public SampleDoubleAggregatorFunctionSupplier(int limit) { + this.limit = limit; + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return SampleDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return SampleDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public SampleDoubleAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return SampleDoubleAggregatorFunction.create(driverContext, channels, limit); + } + + @Override + public SampleDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return SampleDoubleGroupingAggregatorFunction.create(channels, driverContext, limit); + } + + @Override + public String describe() { + return "sample of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..cd76527394432 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleDoubleGroupingAggregatorFunction.java @@ -0,0 +1,260 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link SampleDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. 
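+ * <p>
+ * The intermediate state is a single {@code "sample"} block of type {@code BYTES_REF}: per group it holds the
+ * random-key-prefixed encodings of the sampled values, which is why this double aggregator exchanges
+ * {@link BytesRefBlock} intermediates.
+ * </p>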
+ */ +public final class SampleDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final SampleDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + public SampleDoubleGroupingAggregatorFunction(List channels, + SampleDoubleAggregator.GroupingState state, DriverContext driverContext, int limit) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + } + + public static SampleDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit) { + return new SampleDoubleGroupingAggregatorFunction(channels, SampleDoubleAggregator.initGrouping(driverContext.bigArrays(), limit), driverContext, limit); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock valuesBlock = page.getBlock(channels.get(0)); + DoubleVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { 
+ continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleDoubleAggregator.combine(state, groupId, values.getDouble(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleDoubleAggregator.combine(state, groupId, values.getDouble(groupPosition + positionOffset)); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleDoubleAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction 
input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + SampleDoubleAggregator.GroupingState inState = ((SampleDoubleGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + SampleDoubleAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext evaluatorContext) { + blocks[offset] = SampleDoubleAggregator.evaluateFinal(state, selected, evaluatorContext.driverContext()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunction.java new file mode 100644 index 0000000000000..97f31295e829e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunction.java @@ -0,0 +1,168 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link SampleIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
+ */ +public final class SampleIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final DriverContext driverContext; + + private final SampleIntAggregator.SingleState state; + + private final List channels; + + private final int limit; + + public SampleIntAggregatorFunction(DriverContext driverContext, List channels, + SampleIntAggregator.SingleState state, int limit) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + } + + public static SampleIntAggregatorFunction create(DriverContext driverContext, + List channels, int limit) { + return new SampleIntAggregatorFunction(driverContext, channels, SampleIntAggregator.initSingle(driverContext.bigArrays(), limit), limit); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + IntBlock block = page.getBlock(channels.get(0)); + IntVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(IntVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + SampleIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawVector(IntVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + SampleIntAggregator.combine(state, vector.getInt(i)); + } + } + + private void addRawBlock(IntBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleIntAggregator.combine(state, block.getInt(i)); + } + } + } + + private void addRawBlock(IntBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleIntAggregator.combine(state, block.getInt(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + assert sample.getPositionCount() == 1; + BytesRef scratch = new BytesRef(); + SampleIntAggregator.combineIntermediate(state, sample); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] 
blocks, int offset, DriverContext driverContext) { + blocks[offset] = SampleIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..5798cf9860f6c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunctionSupplier.java @@ -0,0 +1,50 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link SampleIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class SampleIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + public SampleIntAggregatorFunctionSupplier(int limit) { + this.limit = limit; + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return SampleIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return SampleIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public SampleIntAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return SampleIntAggregatorFunction.create(driverContext, channels, limit); + } + + @Override + public SampleIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return SampleIntGroupingAggregatorFunction.create(channels, driverContext, limit); + } + + @Override + public String describe() { + return "sample of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..b2cf3114fa951 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleIntGroupingAggregatorFunction.java @@ -0,0 +1,259 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link SampleIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class SampleIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final SampleIntAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + public SampleIntGroupingAggregatorFunction(List channels, + SampleIntAggregator.GroupingState state, DriverContext driverContext, int limit) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + } + + public static SampleIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit) { + return new SampleIntGroupingAggregatorFunction(channels, SampleIntAggregator.initGrouping(driverContext.bigArrays(), limit), driverContext, limit); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock valuesBlock = page.getBlock(channels.get(0)); + IntVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if 
(groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleIntAggregator.combine(state, groupId, values.getInt(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleIntAggregator.combine(state, groupId, values.getInt(groupPosition + positionOffset)); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + 
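// evaluation may select group ids this function never received input for; enable tracking so unseen groups can be detected +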
state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleIntAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + SampleIntAggregator.GroupingState inState = ((SampleIntGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + SampleIntAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext evaluatorContext) { + blocks[offset] = SampleIntAggregator.evaluateFinal(state, selected, evaluatorContext.driverContext()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunction.java new file mode 100644 index 0000000000000..269f0f14f166a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunction.java @@ -0,0 +1,168 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link SampleLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. 
+ */ +public final class SampleLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final DriverContext driverContext; + + private final SampleLongAggregator.SingleState state; + + private final List channels; + + private final int limit; + + public SampleLongAggregatorFunction(DriverContext driverContext, List channels, + SampleLongAggregator.SingleState state, int limit) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + } + + public static SampleLongAggregatorFunction create(DriverContext driverContext, + List channels, int limit) { + return new SampleLongAggregatorFunction(driverContext, channels, SampleLongAggregator.initSingle(driverContext.bigArrays(), limit), limit); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(LongVector vector) { + for (int i = 0; i < vector.getPositionCount(); i++) { + SampleLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawVector(LongVector vector, BooleanVector mask) { + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + SampleLongAggregator.combine(state, vector.getLong(i)); + } + } + + private void addRawBlock(LongBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleLongAggregator.combine(state, block.getLong(i)); + } + } + } + + private void addRawBlock(LongBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + SampleLongAggregator.combine(state, block.getLong(i)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + assert sample.getPositionCount() == 1; + BytesRef scratch = new BytesRef(); + SampleLongAggregator.combineIntermediate(state, sample); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public 
void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = SampleLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..09ca5ab1c1d3b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunctionSupplier.java @@ -0,0 +1,50 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link SampleLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class SampleLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + public SampleLongAggregatorFunctionSupplier(int limit) { + this.limit = limit; + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return SampleLongAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return SampleLongGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public SampleLongAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return SampleLongAggregatorFunction.create(driverContext, channels, limit); + } + + @Override + public SampleLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return SampleLongGroupingAggregatorFunction.create(channels, driverContext, limit); + } + + @Override + public String describe() { + return "sample of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..afb1e94a23f5a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SampleLongGroupingAggregatorFunction.java @@ -0,0 +1,260 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link SampleLongAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class SampleLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sample", ElementType.BYTES_REF) ); + + private final SampleLongAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + public SampleLongGroupingAggregatorFunction(List channels, + SampleLongAggregator.GroupingState state, DriverContext driverContext, int limit) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + } + + public static SampleLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit) { + return new SampleLongGroupingAggregatorFunction(channels, SampleLongAggregator.initGrouping(driverContext.bigArrays(), limit), driverContext, limit); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valuesBlock = page.getBlock(channels.get(0)); + LongVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + SampleLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + SampleLongAggregator.combine(state, groupId, values.getLong(v)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleLongAggregator.combine(state, groupId, values.getLong(groupPosition + positionOffset)); + } + } + + @Override + 
public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sampleUncast = page.getBlock(channels.get(0)); + if (sampleUncast.areAllValuesNull()) { + return; + } + BytesRefBlock sample = (BytesRefBlock) sampleUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + SampleLongAggregator.combineIntermediate(state, groupId, sample, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + SampleLongAggregator.GroupingState inState = ((SampleLongGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + SampleLongAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext evaluatorContext) { + blocks[offset] = SampleLongAggregator.evaluateFinal(state, selected, evaluatorContext.driverContext()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st new file mode 100644 index 0000000000000..f13857002e848 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-SampleAggregator.java.st @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.$Type$Block; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.topn.DefaultUnsortableTopNEncoder; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +import org.elasticsearch.common.Randomness; +import java.util.random.RandomGenerator; +// end generated imports + +/** + * Sample N field values for $type$. + *
<p>
+ * This class is generated. Edit `X-SampleAggregator.java.st` to edit this file.
+ * </p>
+ * <p>
+ * This works by prepending a random long to the value, and then collecting the
+ * top values. This gives a uniform random sample of the values. See also:
+ * <a href="https://en.wikipedia.org/wiki/Reservoir_sampling">Wikipedia Reservoir Sampling</a>
+ * </p>
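+ * <p>
+ * A minimal sketch of the idea (illustrative only; these names are not part of the generated code):
+ * <pre>{@code
+ * // tag each value with a random key, keep the n smallest keys, then drop the keys;
+ * // the surviving values form a uniform random sample of size min(n, values.size())
+ * static List<Double> sample(List<Double> values, int n) {
+ *     record Keyed(long key, double value) {}
+ *     return values.stream()
+ *         .map(v -> new Keyed(ThreadLocalRandom.current().nextLong(), v))
+ *         .sorted(Comparator.comparingLong(Keyed::key))
+ *         .limit(n)
+ *         .map(Keyed::value)
+ *         .toList();
+ * }
+ * }</pre>
+ * </p>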
+ */ +@Aggregator({ @IntermediateState(name = "sample", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class Sample$Type$Aggregator { + private static final DefaultUnsortableTopNEncoder ENCODER = new DefaultUnsortableTopNEncoder(); + + public static SingleState initSingle(BigArrays bigArrays, int limit) { + return new SingleState(bigArrays, limit); + } + + public static void combine(SingleState state, $type$ value) { + state.add(value); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.internalState.sort.collect(values.getBytesRef(i, scratch), 0); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory())); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit) { + return new GroupingState(bigArrays, limit); + } + + public static void combine(GroupingState state, int groupId, $type$ value) { + state.add(groupId, value); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + BytesRef scratch = new BytesRef(); + for (int i = start; i < end; i++) { + state.sort.collect(values.getBytesRef(i, scratch), groupId); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return stripWeights(driverContext, state.toBlock(driverContext.blockFactory(), selected)); + } + + private static Block stripWeights(DriverContext driverContext, Block block) { + if (block.areAllValuesNull()) { + return block; + } + BytesRefBlock bytesRefBlock = (BytesRefBlock) block; + try ($Type$Block.Builder $type$Block = driverContext.blockFactory().new$Type$BlockBuilder(bytesRefBlock.getPositionCount())) { + BytesRef scratch = new BytesRef(); + for (int position = 0; position < block.getPositionCount(); position++) { + if (bytesRefBlock.isNull(position)) { + $type$Block.appendNull(); + } else { + int valueCount = bytesRefBlock.getValueCount(position); + if (valueCount > 1) { + $type$Block.beginPositionEntry(); + } + int start = bytesRefBlock.getFirstValueIndex(position); + int end = start + valueCount; + for (int i = start; i < end; i++) { + BytesRef value = bytesRefBlock.getBytesRef(i, scratch).clone(); + ENCODER.decodeLong(value); + $type$Block.append$Type$(ENCODER.decode$Type$(value$if(BytesRef)$, scratch$endif$)); + } + if (valueCount > 1) { + $type$Block.endPositionEntry(); + } + } + } + block.close(); + return $type$Block.build(); + } + } + + public static class GroupingState implements GroupingAggregatorState { + private final BytesRefBucketedSort sort; + private final BreakingBytesRefBuilder bytesRefBuilder; + + private GroupingState(BigArrays bigArrays, int limit) { + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "sample", bigArrays, SortOrder.ASC, limit); + boolean success = false; + try { + this.bytesRefBuilder = new BreakingBytesRefBuilder(breaker, "sample"); + 
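// if this allocation trips the circuit breaker, the finally block releases sort so it does not leak +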
success = true; + } finally { + if (success == false) { + Releasables.closeExpectNoException(sort); + } + } + } + + public void add(int groupId, $type$ value) { + ENCODER.encodeLong(Randomness.get().nextLong(), bytesRefBuilder); + ENCODER.encode$Type$(value, bytesRefBuilder); + sort.collect(bytesRefBuilder.bytesRefView(), groupId); + bytesRefBuilder.clear(); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort, bytesRefBuilder); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit) { + this.internalState = new GroupingState(bigArrays, limit); + } + + public void add($type$ value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/DefaultUnsortableTopNEncoder.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/DefaultUnsortableTopNEncoder.java index f1ae4cab8a4bd..df1025f89c8eb 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/DefaultUnsortableTopNEncoder.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/DefaultUnsortableTopNEncoder.java @@ -18,7 +18,7 @@ * A {@link TopNEncoder} that doesn't encode values so they are sortable but is * capable of encoding any values. */ -final class DefaultUnsortableTopNEncoder implements TopNEncoder { +public final class DefaultUnsortableTopNEncoder implements TopNEncoder { public static final VarHandle LONG = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.nativeOrder()); public static final VarHandle INT = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.nativeOrder()); public static final VarHandle FLOAT = MethodHandles.byteArrayViewVarHandle(float[].class, ByteOrder.nativeOrder()); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunctionTests.java new file mode 100644 index 0000000000000..80a547cdb7774 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleBooleanAggregatorFunctionTests.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import com.carrotsearch.randomizedtesting.annotations.SeedDecorators; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.SequenceBooleanBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.test.MixWithIncrement; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; + +@SeedDecorators(MixWithIncrement.class) +public class SampleBooleanAggregatorFunctionTests extends AggregatorFunctionTestCase { + private static final int LIMIT = 50; + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceBooleanBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToObj(l -> randomBoolean())); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new SampleBooleanAggregatorFunctionSupplier(LIMIT); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sample of booleans"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + List inputValues = input.stream().flatMap(AggregatorFunctionTestCase::allBooleans).collect(Collectors.toList()); + Boolean[] resultValues = AggregatorFunctionTestCase.allBooleans(result).toArray(Boolean[]::new); + assertThat(resultValues, arrayWithSize(Math.min(inputValues.size(), LIMIT))); + } + + public void testDistribution() { + // Sample from an input of 50x true and 50x false. + int N = 100; + Aggregator.Factory aggregatorFactory = aggregatorFunction().aggregatorFactory(AggregatorMode.SINGLE, List.of(0)); + AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory( + List.of(aggregatorFactory), + AggregatorMode.SINGLE + ); + + // Repeat 1000x, count how often each value is sampled. + int trueCount = 0; + int falseCount = 0; + for (int iteration = 0; iteration < 1000; iteration++) { + List input = CannedSourceOperator.collectPages( + new SequenceBooleanBlockSourceOperator(driverContext().blockFactory(), IntStream.range(0, N).mapToObj(i -> i % 2 == 0)) + ); + List results = drive(operatorFactory.get(driverContext()), input.iterator(), driverContext()); + for (Page page : results) { + BooleanBlock block = page.getBlock(0); + for (int i = 0; i < block.getTotalValueCount(); i++) { + if (block.getBoolean(i)) { + trueCount++; + } else { + falseCount++; + } + } + } + MixWithIncrement.next(); + } + + // On average, both boolean values should be sampled 25000x. + // The interval [23000,27000] is at least 10 sigma, so this should never fail.
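+        // (The per-iteration true count is hypergeometric: mean 25, standard deviation ~2.5.
+        // Summed over 1000 independent iterations the standard deviation is ~80, so the
+        // +/-2000 margin is roughly 25 sigma.)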
+ assertThat(trueCount, both(greaterThan(23000)).and(lessThan(27000))); + assertThat(falseCount, both(greaterThan(23000)).and(lessThan(27000))); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunctionTests.java new file mode 100644 index 0000000000000..17d20d83b7879 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleBytesRefAggregatorFunctionTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import com.carrotsearch.randomizedtesting.annotations.SeedDecorators; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.test.MixWithIncrement; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.lessThan; + +@SeedDecorators(MixWithIncrement.class) +public class SampleBytesRefAggregatorFunctionTests extends AggregatorFunctionTestCase { + private static final int LIMIT = 50; + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceBytesRefBlockSourceOperator( + blockFactory, + IntStream.range(0, size).mapToObj(l -> new BytesRef(randomAlphanumericOfLength(100))) + ); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new SampleBytesRefAggregatorFunctionSupplier(LIMIT); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sample of bytes"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + Set inputValues = input.stream().flatMap(AggregatorFunctionTestCase::allBytesRefs).collect(Collectors.toSet()); + BytesRef[] resultValues = AggregatorFunctionTestCase.allBytesRefs(result).toArray(BytesRef[]::new); + assertThat(resultValues, arrayWithSize(Math.min(inputValues.size(), LIMIT))); + assertThat(inputValues, hasItems(resultValues)); + } + + public void testDistribution() { + // Sample from the numbers 0...99. + int N = 100; + Aggregator.Factory aggregatorFactory = aggregatorFunction().aggregatorFactory(AggregatorMode.SINGLE, List.of(0)); + AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory( + List.of(aggregatorFactory), + AggregatorMode.SINGLE + ); + + // Repeat 1000x, count how often each number is sampled. 
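+        // Each iteration keeps LIMIT=50 of the 100 distinct strings, so every string
+        // is included with probability 1/2 per iteration.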
+ int[] sampledCounts = new int[N]; + for (int iteration = 0; iteration < 1000; iteration++) { + List input = CannedSourceOperator.collectPages( + new SequenceBytesRefBlockSourceOperator( + driverContext().blockFactory(), + IntStream.range(0, N).mapToObj(i -> new BytesRef(Integer.toString(i))) + ) + ); + List results = drive(operatorFactory.get(driverContext()), input.iterator(), driverContext()); + for (Page page : results) { + BytesRefBlock block = page.getBlock(0); + BytesRef scratch = new BytesRef(); + for (int i = 0; i < block.getTotalValueCount(); i++) { + sampledCounts[Integer.parseInt(block.getBytesRef(i, scratch).utf8ToString())]++; + } + } + MixWithIncrement.next(); + } + + // On average, each string should be sampled 500x. + // The interval [300,700] is approx. 10 sigma, so this should never fail. + for (int i = 0; i < N; i++) { + assertThat(sampledCounts[i], both(greaterThan(300)).and(lessThan(700))); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunctionTests.java new file mode 100644 index 0000000000000..73e7ed8eee5e1 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleDoubleAggregatorFunctionTests.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import com.carrotsearch.randomizedtesting.annotations.SeedDecorators; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.SequenceDoubleBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.test.MixWithIncrement; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.lessThan; + +@SeedDecorators(MixWithIncrement.class) +public class SampleDoubleAggregatorFunctionTests extends AggregatorFunctionTestCase { + private static final int LIMIT = 50; + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceDoubleBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToDouble(l -> randomDouble())); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new SampleDoubleAggregatorFunctionSupplier(LIMIT); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sample of doubles"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + Set inputValues = input.stream() + .flatMapToDouble(AggregatorFunctionTestCase::allDoubles) + .boxed() + .collect(Collectors.toSet()); + Double[] resultValues = 
AggregatorFunctionTestCase.allDoubles(result).boxed().toArray(Double[]::new); + assertThat(resultValues, arrayWithSize(Math.min(inputValues.size(), LIMIT))); + assertThat(inputValues, hasItems(resultValues)); + } + + public void testDistribution() { + // Sample from the numbers 0...99. + int N = 100; + Aggregator.Factory aggregatorFactory = aggregatorFunction().aggregatorFactory(AggregatorMode.SINGLE, List.of(0)); + AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory( + List.of(aggregatorFactory), + AggregatorMode.SINGLE + ); + + // Repeat 1000x, count how often each number is sampled. + int[] sampledCounts = new int[N]; + for (int iteration = 0; iteration < 1000; iteration++) { + List input = CannedSourceOperator.collectPages( + new SequenceDoubleBlockSourceOperator(driverContext().blockFactory(), IntStream.range(0, N).asDoubleStream()) + ); + List results = drive(operatorFactory.get(driverContext()), input.iterator(), driverContext()); + for (Page page : results) { + DoubleBlock block = page.getBlock(0); + for (int i = 0; i < block.getTotalValueCount(); i++) { + sampledCounts[(int) block.getDouble(i)]++; + } + } + MixWithIncrement.next(); + } + + // On average, each number should be sampled 500x. + // The interval [300,700] is approx. 10 sigma, so this should never fail. + for (int i = 0; i < N; i++) { + assertThat(sampledCounts[i], both(greaterThan(300)).and(lessThan(700))); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunctionTests.java new file mode 100644 index 0000000000000..84b4fec44c289 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleIntAggregatorFunctionTests.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import com.carrotsearch.randomizedtesting.annotations.SeedDecorators; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.SequenceIntBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.test.MixWithIncrement; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.lessThan; + +@SeedDecorators(MixWithIncrement.class) +public class SampleIntAggregatorFunctionTests extends AggregatorFunctionTestCase { + private static final int LIMIT = 50; + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceIntBlockSourceOperator(blockFactory, IntStream.range(0, size).map(l -> randomInt())); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new SampleIntAggregatorFunctionSupplier(LIMIT); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sample of ints"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + Set inputValues = input.stream().flatMapToInt(AggregatorFunctionTestCase::allInts).boxed().collect(Collectors.toSet()); + Integer[] resultValues = AggregatorFunctionTestCase.allInts(result).boxed().toArray(Integer[]::new); + assertThat(resultValues, arrayWithSize(Math.min(inputValues.size(), LIMIT))); + assertThat(inputValues, hasItems(resultValues)); + } + + public void testDistribution() { + // Sample from the numbers 0...99. + int N = 100; + Aggregator.Factory aggregatorFactory = aggregatorFunction().aggregatorFactory(AggregatorMode.SINGLE, List.of(0)); + AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory( + List.of(aggregatorFactory), + AggregatorMode.SINGLE + ); + + // Repeat 1000x, count how often each number is sampled. + int[] sampledCounts = new int[N]; + for (int iteration = 0; iteration < 1000; iteration++) { + List input = CannedSourceOperator.collectPages( + new SequenceIntBlockSourceOperator(driverContext().blockFactory(), IntStream.range(0, N)) + ); + List results = drive(operatorFactory.get(driverContext()), input.iterator(), driverContext()); + for (Page page : results) { + IntBlock block = page.getBlock(0); + for (int i = 0; i < block.getTotalValueCount(); i++) { + sampledCounts[block.getInt(i)]++; + } + } + MixWithIncrement.next(); + } + + // On average, each number should be sampled 500x. + // The interval [300,700] is approx. 10 sigma, so this should never fail. 
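+        // (Each count is a sum of 1000 inclusion indicators with P(include) = 1/2:
+        // mean 500, standard deviation at most ~16, so a deviation of 200 is over 10 sigma.)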
+ for (int i = 0; i < N; i++) { + assertThat(sampledCounts[i], both(greaterThan(300)).and(lessThan(700))); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunctionTests.java new file mode 100644 index 0000000000000..54c92dfb10cb7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SampleLongAggregatorFunctionTests.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import com.carrotsearch.randomizedtesting.annotations.SeedDecorators; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.AggregationOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.test.CannedSourceOperator; +import org.elasticsearch.compute.test.SequenceLongBlockSourceOperator; +import org.elasticsearch.test.MixWithIncrement; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.arrayWithSize; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.lessThan; + +@SeedDecorators(MixWithIncrement.class) +public class SampleLongAggregatorFunctionTests extends AggregatorFunctionTestCase { + + private static final int LIMIT = 50; + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLong())); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction() { + return new SampleLongAggregatorFunctionSupplier(LIMIT); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "sample of longs"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + Set inputValues = input.stream().flatMapToLong(AggregatorFunctionTestCase::allLongs).boxed().collect(Collectors.toSet()); + Long[] resultValues = AggregatorFunctionTestCase.allLongs(result).boxed().toArray(Long[]::new); + assertThat(resultValues, arrayWithSize(Math.min(inputValues.size(), LIMIT))); + assertThat(inputValues, hasItems(resultValues)); + } + + public void testDistribution() { + // Sample from the numbers 0...99. + int N = 100; + Aggregator.Factory aggregatorFactory = aggregatorFunction().aggregatorFactory(AggregatorMode.SINGLE, List.of(0)); + AggregationOperator.AggregationOperatorFactory operatorFactory = new AggregationOperator.AggregationOperatorFactory( + List.of(aggregatorFactory), + AggregatorMode.SINGLE + ); + + // Repeat 1000x, count how often each number is sampled. 
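+        // MixWithIncrement.next() below changes the seed decoration, so that each
+        // iteration draws a different random sample.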
+ int[] sampledCounts = new int[N]; + for (int iteration = 0; iteration < 1000; iteration++) { + List input = CannedSourceOperator.collectPages( + new SequenceLongBlockSourceOperator(driverContext().blockFactory(), LongStream.range(0, N)) + ); + List results = drive(operatorFactory.get(driverContext()), input.iterator(), driverContext()); + for (Page page : results) { + LongBlock block = page.getBlock(0); + for (int i = 0; i < block.getTotalValueCount(); i++) { + sampledCounts[(int) block.getLong(i)]++; + } + } + MixWithIncrement.next(); + } + + // On average, each number should be sampled 500x. + // The interval [300,700] is approx. 10 sigma, so this should never fail. + for (int i = 0; i < N; i++) { + assertThat(sampledCounts[i], both(greaterThan(300)).and(lessThan(700))); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBooleanBlockSourceOperator.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBooleanBlockSourceOperator.java index 97fb380b6aac1..bb1fda7b2223b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBooleanBlockSourceOperator.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/SequenceBooleanBlockSourceOperator.java @@ -13,6 +13,7 @@ import org.elasticsearch.compute.test.AbstractBlockSourceOperator; import java.util.List; +import java.util.stream.Stream; /** * A source operator whose output is the given boolean values. This operator produces pages @@ -24,6 +25,10 @@ public class SequenceBooleanBlockSourceOperator extends AbstractBlockSourceOpera private final boolean[] values; + public SequenceBooleanBlockSourceOperator(BlockFactory blockFactory, Stream values) { + this(blockFactory, values.toList()); + } + public SequenceBooleanBlockSourceOperator(BlockFactory blockFactory, List values) { this(blockFactory, values, DEFAULT_MAX_PAGE_POSITIONS); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sample.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sample.csv-spec new file mode 100644 index 0000000000000..d464a57319592 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_sample.csv-spec @@ -0,0 +1,249 @@ +// Tests focused on the SAMPLE aggregation function +// Note: this tests only basic behavior, because of the non-deterministic +// behavior of SAMPLE and limitations of the CSV tests. + + +documentation +required_capability: agg_sample + +// tag::doc[] +FROM employees +| STATS sample = SAMPLE(gender, 5) +// end::doc[] +// Hardcode the sample values to work around the limitations of the CSV tests in the +// presence of randomness, and be able to specify an expected result for the docs. 
+| EVAL sample = ["F", "M", "M", "F", "M"] +; + +// tag::doc-result[] +sample:keyword +[F, M, M, F, M] +// end::doc-result[] +; + + +sample size +required_capability: agg_sample + +FROM employees +| STATS sample_boolean = SAMPLE(still_hired, 1), + sample_datetime = SAMPLE(hire_date, 2), + sample_double = SAMPLE(height, 3), + sample_integer = SAMPLE(emp_no, 4), + sample_keyword = SAMPLE(first_name, 5), + sample_long = SAMPLE(languages.long, 6) +| EVAL count_boolean = MV_COUNT(sample_boolean), + count_datetime = MV_COUNT(sample_datetime), + count_double = MV_COUNT(sample_double), + count_integer = MV_COUNT(sample_integer), + count_keyword = MV_COUNT(sample_keyword), + count_long = MV_COUNT(sample_long) +| KEEP count_* +; + +count_boolean:integer | count_datetime:integer | count_double:integer | count_integer:integer | count_keyword:integer | count_long:integer +1 | 2 | 3 | 4 | 5 | 6 +; + + +sample values (boolean, datetime, double, integer, keyword, long) +required_capability: agg_sample + +FROM employees +| SORT emp_no +| LIMIT 3 +| STATS sample_boolean = MV_SORT(SAMPLE(still_hired, 99)), + sample_datetime = MV_SORT(SAMPLE(hire_date, 99)), + sample_double = MV_SORT(SAMPLE(height, 99)), + sample_integer = MV_SORT(SAMPLE(emp_no, 99)), + sample_keyword = MV_SORT(SAMPLE(first_name, 99)), + sample_long = MV_SORT(SAMPLE(languages.long, 99)) +; + +sample_boolean:boolean | sample_datetime:datetime | sample_double:double | sample_integer:integer | sample_keyword:keyword | sample_long:long +[false, true, true] | [1985-11-21T00:00:00.000Z, 1986-06-26T00:00:00.000Z, 1986-08-28T00:00:00.000Z] | [1.83, 2.03, 2.08] | [10001, 10002, 10003] | [Bezalel, Georgi, Parto] | [2, 4, 5] +; + + +multivalued +required_capability: agg_sample + +FROM mv_sample_data +| STATS sample = SAMPLE(message, 20) +| EVAL sample = MV_SORT(sample) +; + +sample:keyword +[Banana, Banana, Banana, Banana, Banana, Banana, Banana, Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3, Connection error, Connection error, Connection error, Disconnected] +; + + +some null input +required_capability: agg_sample + +FROM employees +| SORT emp_no +| LIMIT 15 +| STATS sample = MV_SORT(SAMPLE(gender, 999)) +; + +sample:keyword +[F, F, F, F, M, M, M, M, M] +; + + +some null output +required_capability: agg_sample + +FROM employees +| WHERE emp_no >= 10008 AND emp_no <= 10011 +| STATS sample = SAMPLE(gender, 1) BY emp_no +| SORT emp_no +; + +sample:keyword | emp_no:integer +M | 10008 +F | 10009 +null | 10010 +null | 10011 +; + + +stats by +required_capability: agg_sample + +FROM employees +| STATS sample_keyword = MV_SORT(SAMPLE(gender, 999)), + sample_integer = MV_SORT(SAMPLE(salary, 999)) BY job_positions +| SORT job_positions +; + +sample_keyword:keyword | sample_integer:integer | job_positions:keyword +[F, F, F, F, F, M, M, M, M, M, M, M, M, M, M, M] | [25976, 31897, 35742, 37691, 39356, 39728, 39878, 43026, 43602, 47411, 47896, 48942, 50128, 57305, 58121, 61358, 66817, 74970] | Accountant +[F, F, F, F, F, F, F, M, M, M, M] | [28941, 30404, 31120, 37716, 42716, 43889, 44307, 44817, 45797, 54518, 62233, 62405, 69904] | Architect +[F, F, F, F, M, M, M, M, M, M, M] | [29175, 30404, 35742, 36051, 37853, 39638, 39878, 40612, 41933, 50249, 58121] | Business Analyst +[F, M, M, M, M, M, M, M, M, M, M] | [25945, 29175, 31897, 34341, 37137, 39878, 42716, 48233, 50249, 56415, 58715, 67492, 74999] | Data Scientist +[F, F, M, M, M, M] | [25324, 27215, 36174, 37137, 39110, 48942, 49281, 50064, 56415, 58715] | Head Human Resources +[F, F, F, F, 
F, F, M, M, M, M, M, M, M, M, M] | [26436, 30404, 31897, 32272, 39356, 43906, 44817, 46595, 48233, 49281, 50064, 50128, 56415, 66174, 69904] | Internship +[F, F, F, F, F, F, F, M, M, M, M, M] | [25324, 25976, 30404, 32272, 32568, 41933, 43026, 43602, 43906, 50064, 56760, 62233, 64675, 74970] | Junior Developer +[F, F, F, F, F, F, M, M, M, M, M, M, M, M, M, M, M, M, M] | [25324, 28035, 32568, 36051, 37112, 38376, 39728, 42716, 44307, 45656, 49818, 50064, 50249, 52044, 60335, 65367, 66817, 69904, 74970, 74999] | Principal Support Engineer +[F, F, F, F, F, F, M, M, M, M, M, M, M, M] | [32568, 33956, 37716, 41933, 43906, 44307, 45656, 45797, 47896, 49095, 51956, 58121, 58715, 61358, 62233, 68431, 73717, 74970] | Purchase Manager +[F, F, F, M, M, M, M, M, M, M, M, M] | [27215, 32568, 34341, 35222, 36051, 38645, 38992, 39356, 39878, 48233, 54518, 61358, 65030] | Python Developer +[F, M, M, M, M, M, M, M, M, M] | [28336, 31120, 36174, 37137, 38645, 39638, 40612, 43026, 43889, 45656, 45797, 48233, 48735, 61358, 71165] | Reporting Analyst +[F, F, F, F, F, F, F, M, M, M, M, M, M, M, M, M, M, M, M, M] | [25945, 31897, 35222, 35742, 37691, 37716, 37853, 38992, 43906, 49281, 52833, 57305, 60781, 62233, 62405, 66174, 66817, 68547, 73851, 74999] | Senior Python Developer +[F, F, F, F, F, F, F, M, M, M, M, M, M, M] | [29175, 31120, 33370, 37716, 40612, 42716, 44307, 44817, 49095, 54518, 56371, 56415, 60335, 65030, 67492] | Senior Team Lead +[F, F, F, F, M, M, M, M, M, M] | [25324, 34341, 35222, 36174, 39728, 41933, 43026, 47896, 49281, 54462, 60408] | Support Engineer +[F, F, F, F, M, M, M, M, M, M, M, M, M] | [31120, 35742, 36174, 37691, 39356, 39638, 39728, 40031, 45656, 45797, 52044, 54518, 60335, 67492, 71165] | Tech Lead +[F, F, F, F, M, M, M, M, M, M, M] | [32263, 37702, 44956, 52121, 54329, 55360, 61805, 63528, 70011, 73578, 74572] | null +; + + +multiple samples are different +required_capability: agg_sample + +FROM employees +| STATS sample1 = MV_SORT(SAMPLE(last_name, 50)), + sample2 = MV_SORT(SAMPLE(last_name, 50)) +| EVAL samples = MV_ZIP(sample1, sample2, "|") +| KEEP samples +| MV_EXPAND samples +| EVAL tokens = SPLIT(samples, "|"), + token_different = MV_SLICE(tokens, 0) != MV_SLICE(tokens, 1) +| WHERE token_different == true +| STATS token_different_count = COUNT() +| EVAL samples_different = token_different_count > 0 +| KEEP samples_different +; + +samples_different:boolean +true +; + + +sample cartesian_point +required_capability: agg_sample + +FROM airports_web | SORT abbrev | LIMIT 3 | STATS sample = SAMPLE(location, 999) | EVAL sample = MV_SORT(sample) +; + +sample:cartesian_point +[POINT (809321.6344269889 1006514.3393965173), POINT (-1.1868515102256078E7 4170563.5012235222), POINT (-437732.64923689933 585738.5549131387)] +; + + +sample cartesian_shape +required_capability: agg_sample + +FROM cartesian_multipolygons | SORT id | LIMIT 1 | STATS sample = SAMPLE(shape, 999) | MV_EXPAND sample +; + +sample:cartesian_shape +MULTIPOLYGON (((0.0 0.0, 1.0 0.0, 1.0 1.0, 0.0 1.0, 0.0 0.0)),((2.0 0.0, 3.0 0.0, 3.0 1.0, 2.0 1.0, 2.0 0.0)),((2.0 2.0, 3.0 2.0, 3.0 3.0, 2.0 3.0, 2.0 2.0)),((0.0 2.0, 1.0 2.0, 1.0 3.0, 0.0 3.0, 0.0 2.0))) +; + + +sample date_nanos +required_capability: agg_sample + +FROM date_nanos | STATS sample = SAMPLE(nanos,999) | EVAL sample = MV_SORT(sample) +; + +sample:date_nanos +[2023-01-23T13:55:01.543123456Z, 2023-02-23T13:33:34.937193Z, 2023-03-23T12:15:03.360103847Z, 2023-03-23T12:15:03.360103847Z, 2023-03-23T12:15:03.360103847Z, 2023-03-23T12:15:03.360103847Z, 
2023-10-23T12:15:03.360103847Z, 2023-10-23T12:15:03.360103847Z, 2023-10-23T12:27:28.948Z, 2023-10-23T13:33:34.937193Z, 2023-10-23T13:51:54.732102837Z, 2023-10-23T13:52:55.015787878Z, 2023-10-23T13:53:55.832987654Z, 2023-10-23T13:55:01.543123456Z] +; + + +sample geo_point +required_capability: agg_sample + +FROM airports | SORT abbrev | LIMIT 2 | STATS sample = SAMPLE(location, 999) | EVAL sample = MV_SORT(sample) +; + +sample:geo_point +[POINT (-106.6166851616 35.0491578018276), POINT (-3.93221929167636 5.2543984451492)] +; + + +sample geo_shape +required_capability: agg_sample + +FROM countries_bbox | SORT id | LIMIT 1 | STATS sample = SAMPLE(shape, 999) +; + +sample:geo_shape +BBOX (-70.059664, -69.874864, 12.627773, 12.411109) +; + + +sample ip +required_capability: agg_sample + +FROM k8s | SORT @timestamp | LIMIT 5 | STATS sample = SAMPLE(client.ip,999) | EVAL sample = MV_SORT(sample) +; + +sample:ip +[10.10.20.30, 10.10.20.30, 10.10.20.31, 10.10.20.34, 10.10.20.34] +; + + +sample text +required_capability: agg_sample + +FROM books | SORT book_no | LIMIT 3 | STATS sample = SAMPLE(title,999) | EVAL sample = MV_SORT(sample) +; + +sample:keyword +[Realms of Tolkien: Images of Middle-earth, Selected Passages from Correspondence with Friends, The brothers Karamazov] +; + + + +sample version +required_capability: agg_sample + +FROM apps | STATS sample = SAMPLE(version,999) | EVAL sample = MV_SORT(sample) +; + +sample:version +[1, 1.2.3.4, 1.2.3.4, 1.11.0, 2.1, 2.3.4, 2.12.0, 5.2.9-SNAPSHOT, 5.2.9, 5.2.9, 5.2.9, bad] +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index fe22f1737a7f9..07990a72e99cc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1069,7 +1069,12 @@ public enum Cap { /** * Resolve groupings before resolving references to groupings in the aggregations. 
*/ - RESOLVE_GROUPINGS_BEFORE_RESOLVING_REFERENCES_TO_GROUPINGS_IN_AGGREGATIONS; + RESOLVE_GROUPINGS_BEFORE_RESOLVING_REFERENCES_TO_GROUPINGS_IN_AGGREGATIONS, + + /** + * Support for the SAMPLE aggregation function + */ + AGG_SAMPLE; private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index b5b509ef7ec11..b2d85d809c058 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.MinOverTime; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Sample; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent; import org.elasticsearch.xpack.esql.expression.function.aggregate.StdDev; @@ -296,6 +297,7 @@ private static FunctionDefinition[][] functions() { def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"), def(Min.class, uni(Min::new), "min"), def(Percentile.class, bi(Percentile::new), "percentile"), + def(Sample.class, bi(Sample::new), "sample"), def(StdDev.class, uni(StdDev::new), "std_dev"), def(Sum.class, uni(Sum::new), "sum"), def(Top.class, tri(Top::new), "top"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java index 776111ca6bb08..7387d3bb8ecb2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateWritables.java @@ -24,6 +24,7 @@ public static List getNamedWriteables() { Min.ENTRY, Percentile.ENTRY, Rate.ENTRY, + Sample.ENTRY, SpatialCentroid.ENTRY, SpatialExtent.ENTRY, StdDev.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sample.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sample.java new file mode 100644 index 0000000000000..781f9ad67c05c --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sample.java @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SampleBooleanAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SampleBytesRefAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SampleDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SampleIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.SampleLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.FunctionType; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; +import org.elasticsearch.xpack.esql.planner.ToAggregator; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; + +public class Sample extends AggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sample", Sample::new); + + @FunctionInfo( + returnType = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "version" }, + description = "Collects sample values for a field.", + type = FunctionType.AGGREGATE, + examples = @Example(file = "stats_sample", tag = "doc") + ) + public Sample( + Source source, + @Param( + name = "field", + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "integer", + "ip", + "keyword", + "long", + "text", + "version" }, + description = "The field to collect sample values for." + ) Expression field, + @Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit + ) { + this(source, field, Literal.TRUE, limit); + } + + public Sample(Source source, Expression field, Expression filter, Expression limit) { + this(source, field, filter, limit, new Literal(Source.EMPTY, Randomness.get().nextLong(), DataType.LONG)); + } + + /** + * The query "FROM data | STATS s1=SAMPLE(x,N), s2=SAMPLE(x,N)" should give two different + * samples of size N. 
The uuid is used to ensure that the optimizer does not optimize both + * expressions to one, resulting in identical samples. + */ + public Sample(Source source, Expression field, Expression filter, Expression limit, Expression uuid) { + super(source, field, filter, List.of(limit, uuid)); + } + + private Sample(StreamInput in) throws IOException { + super(in); + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + var typeResolution = isType(field(), dt -> dt != DataType.UNSIGNED_LONG, sourceText(), FIRST, "any type except unsigned_long").and( + isNotNullAndFoldable(limitField(), sourceText(), SECOND) + ).and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer")); + if (typeResolution.unresolved()) { + return typeResolution; + } + int limit = limitValue(); + if (limit <= 0) { + return new TypeResolution(format(null, "Limit must be greater than 0 in [{}], found [{}]", sourceText(), limit)); + } + return TypeResolution.TYPE_RESOLVED; + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public DataType dataType() { + return field().dataType().noText(); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Sample::new, field(), filter(), limitField(), uuid()); + } + + @Override + public Sample replaceChildren(List newChildren) { + return new Sample(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3)); + } + + @Override + public AggregatorFunctionSupplier supplier() { + return switch (PlannerUtils.toElementType(field().dataType())) { + case BOOLEAN -> new SampleBooleanAggregatorFunctionSupplier(limitValue()); + case BYTES_REF -> new SampleBytesRefAggregatorFunctionSupplier(limitValue()); + case DOUBLE -> new SampleDoubleAggregatorFunctionSupplier(limitValue()); + case INT -> new SampleIntAggregatorFunctionSupplier(limitValue()); + case LONG -> new SampleLongAggregatorFunctionSupplier(limitValue()); + default -> throw EsqlIllegalArgumentException.illegalDataType(field().dataType()); + }; + } + + @Override + public Sample withFilter(Expression filter) { + return new Sample(source(), field(), filter, limitField(), uuid()); + } + + Expression limitField() { + return parameters().get(0); + } + + private int limitValue() { + return (int) limitField().fold(FoldContext.small() /* TODO remove me */); + } + + Expression uuid() { + return parameters().get(1); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SampleTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SampleTests.java new file mode 100644 index 0000000000000..61ad7345a4a7b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SampleTests.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.search.Multiset; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; + +public class SampleTests extends AbstractAggregationTestCase { + public SampleTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 100, false)) { + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000), + MultiRowTestCaseSupplier.dateNanosCases(1, 1000), + MultiRowTestCaseSupplier.booleanCases(1, 1000), + MultiRowTestCaseSupplier.ipCases(1, 1000), + MultiRowTestCaseSupplier.versionCases(1, 1000), + MultiRowTestCaseSupplier.stringCases(1, 1000, DataType.KEYWORD), + MultiRowTestCaseSupplier.stringCases(1, 1000, DataType.TEXT), + MultiRowTestCaseSupplier.geoPointCases(1, 1000, MultiRowTestCaseSupplier.IncludingAltitude.NO), + MultiRowTestCaseSupplier.cartesianPointCases(1, 1000, MultiRowTestCaseSupplier.IncludingAltitude.NO), + MultiRowTestCaseSupplier.geoShapeCasesWithoutCircle(1, 100, MultiRowTestCaseSupplier.IncludingAltitude.NO), + MultiRowTestCaseSupplier.cartesianShapeCasesWithoutCircle(1, 100, MultiRowTestCaseSupplier.IncludingAltitude.NO) + ) + .flatMap(List::stream) + .map(fieldCaseSupplier -> makeSupplier(fieldCaseSupplier, limitCaseSupplier)) + .collect(Collectors.toCollection(() -> suppliers)); + } + return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Sample(source, args.get(0), args.get(1)); + } + + private static TestCaseSupplier makeSupplier( + TestCaseSupplier.TypedDataSupplier fieldSupplier, + TestCaseSupplier.TypedDataSupplier limitCaseSupplier + ) { + return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type(), limitCaseSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var limitTypedData = limitCaseSupplier.get().forceLiteral(); + var limit = (int) limitTypedData.getValue(); + + var rows = fieldTypedData.multiRowData(); + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData, limitTypedData), + "Sample[field=Attribute[channel=0], limit=Attribute[channel=1]]", + fieldSupplier.type(), + 
subsetOfSize(rows, limit) + ); + }); + } + + private static Matcher subsetOfSize(Collection data, int size) { + if (data == null || data.isEmpty()) { + return nullValue(); + } + if (data.size() == 1) { + return equalTo(data.iterator().next()); + } + // New Matcher, as `containsInAnyOrder` returns `Matcher<Iterable<? extends T>>` instead of a Matcher of a single value + return new BaseMatcher<>() { + @Override + public void describeTo(Description description) { + description.appendText("subset of size ").appendValue(size).appendText(" of ").appendValue(data); + } + + @Override + public boolean matches(Object object) { + Iterable items = object instanceof Iterable ? (Iterable) object : List.of(object); + Multiset dataSet = new Multiset<>(); + dataSet.addAll(data); + for (Object item : items) { + if (dataSet.remove(item) == false) { + return false; + } + } + return true; + } + }; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java index 3b83beebc9c04..11c7476eb4b03 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesTests.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.HashSet; import java.util.List; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -67,27 +68,21 @@ protected Expression build(Source source, List args) { return new Values(source, args.get(0)); } - @SuppressWarnings("unchecked") private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type()), () -> { var fieldTypedData = fieldSupplier.get(); - - var expected = fieldTypedData.multiRowData() - .stream() - .map(v -> (Comparable>) v) - .collect(Collectors.toSet()); - + var expected = new HashSet<>(fieldTypedData.multiRowData()); return new TestCaseSupplier.TestCase( List.of(fieldTypedData), "Values[field=Attribute[channel=0]]", fieldSupplier.type(), - expected.isEmpty() ? nullValue() : valuesInAnyOrder(expected) + valuesInAnyOrder(expected) ); }); } - private static Matcher valuesInAnyOrder(Collection data) { - if (data == null) { + private static Matcher valuesInAnyOrder(Collection data) { + if (data == null || data.isEmpty()) { + return nullValue(); + } + if (data.size() == 1) { + return equalTo(data.iterator().next()); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 61eee781bd7cc..18ea2903d1d86 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -122,8 +122,8 @@ setup: - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} - # Testing for the entire function set isn't feasbile, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 140} # check the "sister" test below for a likely update to the same esql.functions length check + # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation.
+ - length: {esql.functions: 141} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version":