diff --git a/docs/changelog/135434.yaml b/docs/changelog/135434.yaml new file mode 100644 index 0000000000000..0a1506087a427 --- /dev/null +++ b/docs/changelog/135434.yaml @@ -0,0 +1,6 @@ +pr: 135434 +summary: Support extra field in TOP function +area: ES|QL +type: enhancement +issues: + - 128630 diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/top.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/top.md index 057bd91855a1a..c56bf558809d1 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/top.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/top.md @@ -11,3 +11,6 @@ `order` : The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted. +`outputField` +: The extra field that, if present, will be the output of the TOP call instead of `field`. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/top.md b/docs/reference/query-languages/esql/_snippets/functions/types/top.md index 559a779cd1d9d..063fb157980d0 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/types/top.md +++ b/docs/reference/query-languages/esql/_snippets/functions/types/top.md @@ -2,22 +2,38 @@ **Supported types** -| field | limit | order | result | -| --- | --- | --- | --- | -| boolean | integer | keyword | boolean | -| boolean | integer | | boolean | -| date | integer | keyword | date | -| date | integer | | date | -| double | integer | keyword | double | -| double | integer | | double | -| integer | integer | keyword | integer | -| integer | integer | | integer | -| ip | integer | keyword | ip | -| ip | integer | | ip | -| keyword | integer | keyword | keyword | -| keyword | integer | | keyword | -| long | integer | keyword | long | -| long | integer | | long | -| text | integer | keyword | keyword | -| text | integer | | keyword | +| field | limit | order | outputField | result | +| --- | --- | --- | --- | --- | +| boolean | integer | keyword | | boolean | +| boolean | integer | | | boolean | +| date | integer | keyword | date | date | +| date | integer | keyword | double | double | +| date | integer | keyword | integer | integer | +| date | integer | keyword | long | long | +| date | integer | keyword | | date | +| date | integer | | | date | +| double | integer | keyword | date | date | +| double | integer | keyword | double | double | +| double | integer | keyword | integer | integer | +| double | integer | keyword | long | long | +| double | integer | keyword | | double | +| double | integer | | | double | +| integer | integer | keyword | date | date | +| integer | integer | keyword | double | double | +| integer | integer | keyword | integer | integer | +| integer | integer | keyword | long | long | +| integer | integer | keyword | | integer | +| integer | integer | | | integer | +| ip | integer | keyword | | ip | +| ip | integer | | | ip | +| keyword | integer | keyword | | keyword | +| keyword | integer | | | keyword | +| long | integer | keyword | date | date | +| long | integer | keyword | double | double | +| long | integer | keyword | integer | integer | +| long | integer | keyword | long | long | +| long | integer | keyword | | long | +| long | integer | | | long | +| text | integer | keyword | | keyword | +| text | integer | | | keyword | diff --git a/docs/reference/query-languages/esql/images/functions/top.svg b/docs/reference/query-languages/esql/images/functions/top.svg index 947890a49f31c..1987d050f981a 100644 --- 
a/docs/reference/query-languages/esql/images/functions/top.svg +++ b/docs/reference/query-languages/esql/images/functions/top.svg @@ -1 +1 @@ -TOP(field,limit,order) \ No newline at end of file +TOP(field,limit,order,outputField) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/top.json b/docs/reference/query-languages/esql/kibana/definition/functions/top.json index 8cef8d534764e..6308eed3bff61 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/top.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/top.json @@ -88,6 +88,126 @@ "variadic" : false, "returnType" : "date" }, + { + "params" : [ + { + "name" : "field", + "type" : "date", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "date", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "date" + }, + { + "params" : [ + { + "name" : "field", + "type" : "date", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "double", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "date", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "integer", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "date", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "long", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." 
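The four-argument signatures above capture the semantic change in this PR: when `outputField` is present, TOP still sorts rows by `field`, but returns the paired `outputField` values instead, so a query such as `TOP(salary, 3, "desc", emp_name)` (hypothetical field names) would return the names from the three rows with the highest salaries. Below is a minimal, self-contained Java sketch of that pair-based top-N semantics using a bounded heap; it is illustrative only, since the generated `Top*Aggregator` classes later in this diff implement the same idea with paired `BucketedSort` structures instead.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Illustrative sketch (not the actual aggregator code): keep the rows whose
 * sort key is in the top N, then emit the paired output values.
 */
public class TopWithOutputFieldSketch {
    record Row(double sortKey, long output) {} // hypothetical field/output types

    static List<Long> top(List<Row> rows, int limit, boolean ascending) {
        Comparator<Row> bySortKey = Comparator.comparingDouble(Row::sortKey);
        // For descending top-N we keep the largest keys, so evict from a
        // min-heap; ascending top-N is the mirror image.
        PriorityQueue<Row> heap = new PriorityQueue<>(ascending ? bySortKey.reversed() : bySortKey);
        for (Row row : rows) {
            heap.offer(row);
            if (heap.size() > limit) {
                heap.poll();
            }
        }
        List<Row> kept = new ArrayList<>(heap);
        kept.sort(ascending ? bySortKey : bySortKey.reversed());
        return kept.stream().map(Row::output).toList(); // emit output, not the sort key
    }

    public static void main(String[] args) {
        List<Row> rows = List.of(new Row(70_000, 10), new Row(90_000, 11), new Row(85_000, 12), new Row(60_000, 13));
        System.out.println(top(rows, 3, false)); // [11, 12, 10]: outputs paired with the 3 largest keys
    }
}
```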
+ } + ], + "variadic" : false, + "returnType" : "long" + }, { "params" : [ { @@ -130,6 +250,126 @@ "variadic" : false, "returnType" : "double" }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "date", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "date" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "double", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "integer", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "long", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "long" + }, { "params" : [ { @@ -172,6 +412,126 @@ "variadic" : false, "returnType" : "integer" }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. 
Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "date", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "date" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "double", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "integer", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "long", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "long" + }, { "params" : [ { @@ -298,6 +658,126 @@ "variadic" : false, "returnType" : "long" }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "date", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "date" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." 
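The remaining entries in this definition file follow a simple pattern: the new four-argument signatures are the cross-product of the sortable numeric/date `field` types with the same set of `outputField` types, and the return type tracks `outputField`, for 16 combinations in total. An illustrative snippet (assumed class name, not part of the PR) that prints that matrix; the nested `forEach` loops added to `compute/build.gradle` further down generate the matching aggregator classes from the same cross-product.

```java
import java.util.List;

// Prints the 4x4 matrix of new TOP signatures added to top.json: every
// sortable numeric/date field type paired with every numeric/date
// outputField type, with the result type following outputField.
public class TopSignatureMatrix {
    public static void main(String[] args) {
        List<String> types = List.of("date", "double", "integer", "long");
        for (String field : types) {
            for (String outputField : types) {
                System.out.printf("TOP(%s, integer, keyword, %s) -> %s%n", field, outputField, outputField);
            }
        }
    }
}
```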
+ }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "double", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "integer", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : true, + "description" : "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." + }, + { + "name" : "outputField", + "type" : "long", + "optional" : true, + "description" : "The extra field that, if present, will be the output of the TOP call instead of `field`." + } + ], + "variadic" : false, + "returnType" : "long" + }, { "params" : [ { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java index 56bbf69c76e55..50a8c46018b08 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java @@ -211,6 +211,18 @@ public static TypeResolution isType( ParamOrdinal paramOrd, boolean allowUnionTypes, String... acceptedTypes + ) { + return isType(e, predicate, null, operationName, paramOrd, allowUnionTypes, acceptedTypes); + } + + public static TypeResolution isType( + Expression e, + Predicate predicate, + String errorMessagePrefix, + String operationName, + ParamOrdinal paramOrd, + boolean allowUnionTypes, + String... acceptedTypes ) { if (predicate.test(e.dataType()) || e.dataType() == NULL) { return TypeResolution.TYPE_RESOLVED; @@ -225,11 +237,19 @@ public static TypeResolution isType( } return new TypeResolution( - errorStringIncompatibleTypes(operationName, paramOrd, name(e), e.dataType(), acceptedTypesForErrorMsg(acceptedTypes)) + errorStringIncompatibleTypes( + errorMessagePrefix, + operationName, + paramOrd, + name(e), + e.dataType(), + acceptedTypesForErrorMsg(acceptedTypes) + ) ); } private static String errorStringIncompatibleTypes( + String errorMessagePrefix, String operationName, ParamOrdinal paramOrd, String argumentName, @@ -237,7 +257,7 @@ private static String errorStringIncompatibleTypes( String... 
acceptedTypes ) { return format( - null, + errorMessagePrefix, "{}argument of [{}] must be [{}], found value [{}] type [{}]", paramOrd == null || paramOrd == DEFAULT ? "" : paramOrd.name().toLowerCase(Locale.ROOT) + " ", operationName, diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index bd4bb33873be5..44f9bdd331238 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -51,7 +51,7 @@ spotless { * Generated files go here. */ toggleOffOn('begin generated imports', 'end generated imports') - targetExclude "src/main/generated/**/*.java" + targetExclude "src/main/generated*/**/*.java" } } @@ -81,6 +81,26 @@ def prop(Name, Type, type, Wrapper, TYPE, BYTES, Array, Hash) { ] } +def propWithoutExtra(prop1, extraPrefix) { + def res = [ ("has" + extraPrefix): "" ] + for ( e in prop1 ) { + res.put(e.key, e.value) + res.put(extraPrefix + e.key, "") + } + return res +} + +def propWithExtra(prop1, prop2, extraPrefix) { + def res = [ ("has" + extraPrefix): "true" ] + for ( e in prop1 ) { + res.put(e.key, e.value) + } + for ( e in prop2 ) { + res.put(extraPrefix + e.key, e.value) + } + return res +} + def addOccurrence(props, Occurrence) { def newProps = props.collectEntries { [(it.key): it.value] } newProps["Occurrence"] = Occurrence @@ -723,40 +743,29 @@ tasks.named('stringTemplates').configure { } File topAggregatorInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st") - template { - it.properties = intProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopIntAggregator.java" - } - template { - it.properties = longProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopLongAggregator.java" - } - template { - it.properties = floatProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopFloatAggregator.java" - } - template { - it.properties = doubleProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopDoubleAggregator.java" - } - template { - it.properties = booleanProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopBooleanAggregator.java" - } - template { - it.properties = bytesRefProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java" + // Simple TOP when the sort field and the output field are the same field + [intProperties, longProperties, floatProperties, doubleProperties, booleanProperties, bytesRefProperties, ipProperties].forEach { props -> + { + template { + it.properties = propWithoutExtra(props, "OutputField") + it.inputFile = topAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/Top${props.Name}Aggregator.java" + } + } } - template { - it.properties = ipProperties - it.inputFile = topAggregatorInputFile - it.outputFile = "org/elasticsearch/compute/aggregation/TopIpAggregator.java" + // TOP when the sort field and the output field can be *different* fields + [intProperties, longProperties, floatProperties, doubleProperties].forEach { props1 -> + { + [intProperties, longProperties, floatProperties, doubleProperties].forEach { props2 -> + { + template { + it.properties = propWithExtra(props1, props2, "OutputField") + it.inputFile = topAggregatorInputFile + 
it.outputFile = "org/elasticsearch/compute/aggregation/Top${props1.Name}${props2.Name}Aggregator.java" + } + } + } + } } File multivalueDedupeInputFile = file("src/main/java/org/elasticsearch/compute/operator/mvdedupe/X-MultivalueDedupe.java.st") @@ -896,25 +905,28 @@ tasks.named('stringTemplates').configure { } File bucketedSortInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st") - template { - it.properties = intProperties - it.inputFile = bucketedSortInputFile - it.outputFile = "org/elasticsearch/compute/data/sort/IntBucketedSort.java" - } - template { - it.properties = longProperties - it.inputFile = bucketedSortInputFile - it.outputFile = "org/elasticsearch/compute/data/sort/LongBucketedSort.java" - } - template { - it.properties = floatProperties - it.inputFile = bucketedSortInputFile - it.outputFile = "org/elasticsearch/compute/data/sort/FloatBucketedSort.java" + [intProperties, longProperties, floatProperties, doubleProperties].forEach { props -> + { + template { + it.properties = propWithoutExtra(props, "Extra") + it.inputFile = bucketedSortInputFile + it.outputFile = "org/elasticsearch/compute/data/sort/${props.Name}BucketedSort.java" + } + } } - template { - it.properties = doubleProperties - it.inputFile = bucketedSortInputFile - it.outputFile = "org/elasticsearch/compute/data/sort/DoubleBucketedSort.java" + // TOP when the sort field and the output field can be *different* fields + [intProperties, longProperties, floatProperties, doubleProperties].forEach { props1 -> + { + [intProperties, longProperties, floatProperties, doubleProperties].forEach { props2 -> + { + template { + it.properties = propWithExtra(props1, props2, "Extra") + it.inputFile = bucketedSortInputFile + it.outputFile = "org/elasticsearch/compute/data/sort/${props1.Name}${props2.Name}BucketedSort.java" + } + } + } + } } File enrichResultBuilderInput = file("src/main/java/org/elasticsearch/compute/operator/lookup/X-EnrichResultBuilder.java.st") diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java index f93e3095524c4..05964998003ea 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.sort.BooleanBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -118,7 +117,9 @@ public void add(boolean value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java index ecc68e7d8a992..88a5d5021983f 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -122,7 +121,9 @@ public void add(BytesRef value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java index e9e1803e36fff..df1294c610ed5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.sort.DoubleBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -118,7 +117,9 @@ public void add(double value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregator.java new file mode 100644 index 0000000000000..45cd12e43c182 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.DoubleDoubleBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for double. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "DOUBLE_BLOCK"), @IntermediateState(name = "output", type = "DOUBLE_BLOCK") }) +@GroupingAggregator +class TopDoubleDoubleAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, double v, double outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, DoubleBlock values, DoubleBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getDouble(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, double v, double outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, DoubleBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getDouble(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final DoubleDoubleBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new DoubleDoubleBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, double value, double outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(double value, double outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregator.java new file mode 100644 index 0000000000000..da103d70c0a54 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.DoubleFloatBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for double. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "DOUBLE_BLOCK"), @IntermediateState(name = "output", type = "FLOAT_BLOCK") }) +@GroupingAggregator +class TopDoubleFloatAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, double v, float outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, DoubleBlock values, FloatBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getDouble(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, double v, float outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, FloatBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getDouble(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final DoubleFloatBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new DoubleFloatBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, double value, float outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(double value, float outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleIntAggregator.java new file mode 100644 index 0000000000000..2b0790c303e59 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleIntAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.DoubleIntBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for double. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "DOUBLE_BLOCK"), @IntermediateState(name = "output", type = "INT_BLOCK") }) +@GroupingAggregator +class TopDoubleIntAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, double v, int outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, DoubleBlock values, IntBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getDouble(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, double v, int outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, IntBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getDouble(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final DoubleIntBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new DoubleIntBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, double value, int outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(double value, int outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleLongAggregator.java new file mode 100644 index 0000000000000..57dd898fce1e4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleLongAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.DoubleLongBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for double. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "DOUBLE_BLOCK"), @IntermediateState(name = "output", type = "LONG_BLOCK") }) +@GroupingAggregator +class TopDoubleLongAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, double v, long outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, DoubleBlock values, LongBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getDouble(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, double v, long outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, DoubleBlock values, LongBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getDouble(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final DoubleLongBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new DoubleLongBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, double value, long outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(double value, long outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java index 1b5fddc0b0038..25e20ec26ffa0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.sort.FloatBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -118,7 +117,9 @@ public void add(float value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregator.java new file mode 100644 index 0000000000000..ae4ba17da2529 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.FloatDoubleBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for float. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "FLOAT_BLOCK"), @IntermediateState(name = "output", type = "DOUBLE_BLOCK") }) +@GroupingAggregator +class TopFloatDoubleAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, float v, double outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, FloatBlock values, DoubleBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getFloat(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, float v, double outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, DoubleBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getFloat(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final FloatDoubleBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new FloatDoubleBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, float value, double outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(float value, double outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatFloatAggregator.java new file mode 100644 index 0000000000000..e2876b5cddffa --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatFloatAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.FloatFloatBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for float. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "FLOAT_BLOCK"), @IntermediateState(name = "output", type = "FLOAT_BLOCK") }) +@GroupingAggregator +class TopFloatFloatAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, float v, float outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, FloatBlock values, FloatBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getFloat(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, float v, float outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, FloatBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getFloat(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final FloatFloatBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new FloatFloatBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, float value, float outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(float value, float outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatIntAggregator.java new file mode 100644 index 0000000000000..0de205c508713 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatIntAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.FloatIntBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for float. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "FLOAT_BLOCK"), @IntermediateState(name = "output", type = "INT_BLOCK") }) +@GroupingAggregator +class TopFloatIntAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, float v, int outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, FloatBlock values, IntBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getFloat(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, float v, int outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, IntBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getFloat(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final FloatIntBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new FloatIntBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, float value, int outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(float value, int outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatLongAggregator.java new file mode 100644 index 0000000000000..580d649abae42 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatLongAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.FloatLongBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for float. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "FLOAT_BLOCK"), @IntermediateState(name = "output", type = "LONG_BLOCK") }) +@GroupingAggregator +class TopFloatLongAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, float v, long outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, FloatBlock values, LongBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getFloat(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, float v, long outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, FloatBlock values, LongBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getFloat(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final FloatLongBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new FloatLongBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, float value, long outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(float value, long outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java index aa8c5e8e1bf3f..7fc6053bf7681 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.sort.IntBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -118,7 +117,9 @@ public void add(int value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntDoubleAggregator.java new file mode 100644 index 0000000000000..10090a179b5bb --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntDoubleAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.IntDoubleBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for int. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "INT_BLOCK"), @IntermediateState(name = "output", type = "DOUBLE_BLOCK") }) +@GroupingAggregator +class TopIntDoubleAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, int v, double outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, IntBlock values, DoubleBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getInt(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, int v, double outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, DoubleBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getInt(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final IntDoubleBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new IntDoubleBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, int value, double outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(int value, double outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntFloatAggregator.java new file mode 100644 index 0000000000000..e007e66e8b526 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntFloatAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.IntFloatBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for int. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "INT_BLOCK"), @IntermediateState(name = "output", type = "FLOAT_BLOCK") }) +@GroupingAggregator +class TopIntFloatAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, int v, float outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, IntBlock values, FloatBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getInt(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, int v, float outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, FloatBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getInt(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final IntFloatBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new IntFloatBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, int value, float outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(int value, float outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntIntAggregator.java new file mode 100644 index 0000000000000..72b2065e1fe77 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntIntAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.IntIntBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for int. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "INT_BLOCK"), @IntermediateState(name = "output", type = "INT_BLOCK") }) +@GroupingAggregator +class TopIntIntAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, int v, int outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, IntBlock values, IntBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getInt(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, int v, int outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, IntBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getInt(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final IntIntBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new IntIntBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, int value, int outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(int value, int outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntLongAggregator.java new file mode 100644 index 0000000000000..a3eba01bd34cf --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntLongAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.IntLongBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for int. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "INT_BLOCK"), @IntermediateState(name = "output", type = "LONG_BLOCK") }) +@GroupingAggregator +class TopIntLongAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, int v, long outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, IntBlock values, LongBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getInt(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, int v, long outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, IntBlock values, LongBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getInt(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final IntLongBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new IntLongBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, int value, long outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(int value, long outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java index 831f573cb3cd0..118aa86b43dbe 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.sort.IpBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -120,7 +119,9 @@ public void add(BytesRef value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java index a31ee1afd8a07..a066ac3d779e0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java @@ -18,7 +18,6 @@ import 
org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.sort.LongBucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; @@ -118,7 +117,9 @@ public void add(long value) { @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongDoubleAggregator.java new file mode 100644 index 0000000000000..31c4f9096ca0e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongDoubleAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.LongDoubleBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for long. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "LONG_BLOCK"), @IntermediateState(name = "output", type = "DOUBLE_BLOCK") }) +@GroupingAggregator +class TopLongDoubleAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, long v, double outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, LongBlock values, DoubleBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getLong(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, long v, double outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, DoubleBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getLong(i), outputValues.getDouble(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final LongDoubleBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new LongDoubleBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, long value, double outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(long value, double outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongFloatAggregator.java new file mode 100644 index 0000000000000..66d699cfd3d32 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongFloatAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.LongFloatBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for long. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "LONG_BLOCK"), @IntermediateState(name = "output", type = "FLOAT_BLOCK") }) +@GroupingAggregator +class TopLongFloatAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, long v, float outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, LongBlock values, FloatBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getLong(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, long v, float outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, FloatBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getLong(i), outputValues.getFloat(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final LongFloatBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new LongFloatBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, long value, float outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(long value, float outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongIntAggregator.java new file mode 100644 index 0000000000000..d2f661348a146 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongIntAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.LongIntBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for long. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "LONG_BLOCK"), @IntermediateState(name = "output", type = "INT_BLOCK") }) +@GroupingAggregator +class TopLongIntAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, long v, int outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, LongBlock values, IntBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getLong(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, long v, int outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, IntBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getLong(i), outputValues.getInt(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final LongIntBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new LongIntBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, long value, int outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(long value, int outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongLongAggregator.java new file mode 100644 index 0000000000000..3db32f71905c0 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongLongAggregator.java @@ -0,0 +1,140 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +// begin generated imports +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.LongLongBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; +// end generated imports + +/** + * Aggregates the top N field values for long. + *
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file.
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "LONG_BLOCK"), @IntermediateState(name = "output", type = "LONG_BLOCK") }) +@GroupingAggregator +class TopLongLongAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, long v, long outputValue) { + state.add(v, outputValue); + } + + public static void combineIntermediate(SingleState state, LongBlock values, LongBlock outputValues) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + for (int i = start; i < end; i++) { + combine(state, values.getLong(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, long v, long outputValue) { + state.add(groupId, v, outputValue); + } + + public static void combineIntermediate(GroupingState state, int groupId, LongBlock values, LongBlock outputValues, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getLong(i), outputValues.getLong(i)); + } + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, GroupingAggregatorEvaluationContext ctx) { + return state.toBlock(ctx.blockFactory(), selected); + } + + public static class GroupingState implements GroupingAggregatorState { + private final LongLongBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + this.sort = new LongLongBucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, long value, long outputValue) { + sort.collect(value, outputValue, groupId); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements AggregatorState { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(long value, long outputValue) { + internalState.add(0, value, outputValue); + } + + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java index c8c6701e68e4a..995cf68cd6a50 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleBucketedSort.java @@ -23,7 +23,7 @@ import java.util.stream.IntStream; /** - * Aggregates the top N double values per bucket. + * Aggregates the top N {@code double} values per bucket. * See {@link BucketedSort} for more information. * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. 
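The two generated `Top*Aggregator` classes above follow the same shape: each top value is collected together with a same-position output value, and the single (non-grouping) state simply reuses the grouping state with a constant group 0. As a minimal, self-contained sketch of that pairing idea (all names below are illustrative, not part of this change):

```java
import java.util.Comparator;
import java.util.PriorityQueue;

// Keeps the top `limit` (value, output) pairs ranked by `value`, using a bounded
// min-heap whose root is the worst pair kept so far. This mirrors what
// TOP(field, limit, "desc", outputField) computes: rank by `field`, emit `outputField`.
final class TopPairsSketch {
    record Pair(long value, long output) {}

    private final PriorityQueue<Pair> heap = new PriorityQueue<>(Comparator.comparingLong(Pair::value));
    private final int limit;

    TopPairsSketch(int limit) {
        this.limit = limit;
    }

    void collect(long value, long output) {
        if (heap.size() < limit) {
            heap.offer(new Pair(value, output));      // still gathering
        } else if (heap.peek().value() < value) {     // better than the current worst
            heap.poll();
            heap.offer(new Pair(value, output));
        }
    }

    // The values TOP emits when outputField is present: the paired outputs,
    // ordered by their ranking value.
    long[] outputs() {
        return heap.stream()
            .sorted(Comparator.comparingLong(Pair::value).reversed())
            .mapToLong(Pair::output)
            .toArray();
    }
}
```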
*/ @@ -162,12 +162,7 @@ public void merge(int groupId, DoubleBucketedSort other, int otherGroupId) { */ public Block toBlock(BlockFactory blockFactory, IntVector selected) { // Check if the selected groups are all empty, to avoid allocating extra memory - if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { - var bounds = this.getBucketValuesIndexes(bucket); - var size = bounds.v2() - bounds.v1(); - - return size > 0; - })) { + if (allSelectedGroupsAreEmpty(selected)) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -185,7 +180,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } if (size == 1) { - builder.appendDouble(values.get(bounds.v1())); + builder.appendDouble(values.get(rootIndex)); continue; } @@ -197,7 +192,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { builder.beginPositionEntry(); for (int i = 0; i < size; i++) { - builder.appendDouble(values.get(bounds.v1() + i)); + builder.appendDouble(values.get(rootIndex + i)); } builder.endPositionEntry(); } @@ -205,6 +200,17 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } } + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + /** * Is this bucket a min heap {@code true} or in gathering mode {@code false}? */ @@ -234,7 +240,8 @@ private void setNextGatherOffset(long rootIndex, int offset) { * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. */ private boolean betterThan(double lhs, double rhs) { - return getOrder().reverseMul() * Double.compare(lhs, rhs) < 0; + int res = Double.compare(lhs, rhs); + return getOrder().reverseMul() * res < 0; } /** diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleDoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleDoubleBucketedSort.java new file mode 100644 index 0000000000000..6438cb1c99a62 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleDoubleBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code double} values per bucket. + * See {@link BucketedSort} for more information. 
+ * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class DoubleDoubleBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * - Gather mode: All buckets start in gather mode, and remain there while they have fewer
+ *   than bucketSize elements. In gather mode, the elements are stored in the array from
+ *   the highest index to the lowest index. The lowest index contains the offset to the
+ *   next slot to be filled. This allows us to insert elements in O(1) time.
+ *   When the bucketSize-th element is collected, the bucket transitions to heap mode by
+ *   heapifying its contents.
+ * - Heap mode: The bucket slots are organized as a min heap structure. The root of the
+ *   heap is the minimum value in the bucket, which allows us to quickly discard new
+ *   values that are not in the top N.
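+ * For example, with bucketSize = 3 (showing only the values array; extraValues mirrors the
+ * same slots), a bucket that has gathered two values looks like [ nextOffset=0 | v1 | v0 ];
+ * collecting a third value fills slot 0 and the bucket is heapified into heap mode.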
+ * + */ + private DoubleArray values; + private DoubleArray extraValues; + + public DoubleDoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newDoubleArray(0, false); + extraValues = bigArrays.newDoubleArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ */ + public void collect(double value, double extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, DoubleDoubleBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendDouble(values.get(rootIndex)); + extraBuilder.appendDouble(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendDouble(values.get(rootIndex + i)); + extraBuilder.appendDouble(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(double lhs, double rhs, double lhsExtra, double rhsExtra) { + int res = Double.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Double.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.DOUBLE_PAGE_SIZE, Double.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *
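+ * A compact standalone illustration of the same bottom-up pass over a plain array
+ * (example only; the real implementation below works on a bucket's slice of a big array):
+ * <pre>
+ * static void heapifySketch(long[] heap) {
+ *     for (int parent = heap.length / 2 - 1; parent >= 0; parent--) {
+ *         int node = parent;                                    // sift this parent down
+ *         for (int child = 2 * node + 1; child < heap.length; child = 2 * node + 1) {
+ *             if (child + 1 < heap.length && heap[child + 1] < heap[child]) child++;
+ *             if (heap[node] <= heap[child]) break;             // heap property holds here
+ *             long tmp = heap[node]; heap[node] = heap[child]; heap[child] = tmp;
+ *             node = child;
+ *         }
+ *     }
+ * }
+ * </pre>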
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleFloatBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleFloatBucketedSort.java new file mode 100644 index 0000000000000..f4bd6d9fdbf92 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleFloatBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code double} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class DoubleFloatBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * - Gather mode: All buckets start in gather mode, and remain there while they have fewer
+ *   than bucketSize elements. In gather mode, the elements are stored in the array from
+ *   the highest index to the lowest index. The lowest index contains the offset to the
+ *   next slot to be filled. This allows us to insert elements in O(1) time.
+ *   When the bucketSize-th element is collected, the bucket transitions to heap mode by
+ *   heapifying its contents.
+ * - Heap mode: The bucket slots are organized as a min heap structure. The root of the
+ *   heap is the minimum value in the bucket, which allows us to quickly discard new
+ *   values that are not in the top N.
+ * + */ + private DoubleArray values; + private FloatArray extraValues; + + public DoubleFloatBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newDoubleArray(0, false); + extraValues = bigArrays.newFloatArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ */ + public void collect(double value, float extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, DoubleFloatBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendDouble(values.get(rootIndex)); + extraBuilder.appendFloat(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendDouble(values.get(rootIndex + i)); + extraBuilder.appendFloat(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(double lhs, double rhs, float lhsExtra, float rhsExtra) { + int res = Double.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Float.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.DOUBLE_PAGE_SIZE, Double.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleIntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleIntBucketedSort.java new file mode 100644 index 0000000000000..c20fb8f9ce1a4 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleIntBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code double} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class DoubleIntBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * - Gather mode: All buckets start in gather mode, and remain there while they have fewer
+ *   than bucketSize elements. In gather mode, the elements are stored in the array from
+ *   the highest index to the lowest index. The lowest index contains the offset to the
+ *   next slot to be filled. This allows us to insert elements in O(1) time.
+ *   When the bucketSize-th element is collected, the bucket transitions to heap mode by
+ *   heapifying its contents.
+ * - Heap mode: The bucket slots are organized as a min heap structure. The root of the
+ *   heap is the minimum value in the bucket, which allows us to quickly discard new
+ *   values that are not in the top N.
+ */ + private DoubleArray values; + private IntArray extraValues; + + public DoubleIntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newDoubleArray(0, false); + extraValues = bigArrays.newIntArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ */ + public void collect(double value, int extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, DoubleIntBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newIntBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendDouble(values.get(rootIndex)); + extraBuilder.appendInt(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendDouble(values.get(rootIndex + i)); + extraBuilder.appendInt(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(double lhs, double rhs, int lhsExtra, int rhsExtra) { + int res = Double.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Integer.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.DOUBLE_PAGE_SIZE, Double.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleLongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleLongBucketedSort.java new file mode 100644 index 0000000000000..6f84c762c3258 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/DoubleLongBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code double} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class DoubleLongBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * - Gather mode: All buckets start in gather mode, and remain there while they have fewer
+ *   than bucketSize elements. In gather mode, the elements are stored in the array from
+ *   the highest index to the lowest index. The lowest index contains the offset to the
+ *   next slot to be filled. This allows us to insert elements in O(1) time.
+ *   When the bucketSize-th element is collected, the bucket transitions to heap mode by
+ *   heapifying its contents.
+ * - Heap mode: The bucket slots are organized as a min heap structure. The root of the
+ *   heap is the minimum value in the bucket, which allows us to quickly discard new
+ *   values that are not in the top N.
+ */ + private DoubleArray values; + private LongArray extraValues; + + public DoubleLongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newDoubleArray(0, false); + extraValues = bigArrays.newLongArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ */ + public void collect(double value, long extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, DoubleLongBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newLongBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendDouble(values.get(rootIndex)); + extraBuilder.appendLong(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendDouble(values.get(rootIndex + i)); + extraBuilder.appendLong(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(double lhs, double rhs, long lhsExtra, long rhsExtra) { + int res = Double.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Long.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.DOUBLE_PAGE_SIZE, Double.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatBucketedSort.java index 4afaa818855e4..5dee9e1555526 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatBucketedSort.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatBucketedSort.java @@ -23,7 +23,7 @@ import java.util.stream.IntStream; /** - * Aggregates the top N float values per bucket. + * Aggregates the top N {@code float} values per bucket. * See {@link BucketedSort} for more information. * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. 
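All of the paired sorts above specialize the same two-key comparison: the sort key decides, and the extra column is consulted only to break ties, with the order's sign flip applied last. A condensed standalone sketch (illustrative, not the generated code; `reverseMul` stands in for SortOrder#reverseMul):

```java
// "Better" means lower for ascending order and higher for descending order.
static boolean betterThan(double lhs, double rhs, long lhsExtra, long rhsExtra, int reverseMul) {
    int res = Double.compare(lhs, rhs);
    if (res == 0) {
        res = Long.compare(lhsExtra, rhsExtra); // tie-break on the paired extra value
    }
    return reverseMul * res < 0;
}
```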
*/ @@ -162,12 +162,7 @@ public void merge(int groupId, FloatBucketedSort other, int otherGroupId) { */ public Block toBlock(BlockFactory blockFactory, IntVector selected) { // Check if the selected groups are all empty, to avoid allocating extra memory - if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { - var bounds = this.getBucketValuesIndexes(bucket); - var size = bounds.v2() - bounds.v1(); - - return size > 0; - })) { + if (allSelectedGroupsAreEmpty(selected)) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -185,7 +180,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } if (size == 1) { - builder.appendFloat(values.get(bounds.v1())); + builder.appendFloat(values.get(rootIndex)); continue; } @@ -197,7 +192,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { builder.beginPositionEntry(); for (int i = 0; i < size; i++) { - builder.appendFloat(values.get(bounds.v1() + i)); + builder.appendFloat(values.get(rootIndex + i)); } builder.endPositionEntry(); } @@ -205,6 +200,17 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } } + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + /** * Is this bucket a min heap {@code true} or in gathering mode {@code false}? */ @@ -234,7 +240,8 @@ private void setNextGatherOffset(long rootIndex, int offset) { * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. */ private boolean betterThan(float lhs, float rhs) { - return getOrder().reverseMul() * Float.compare(lhs, rhs) < 0; + int res = Float.compare(lhs, rhs); + return getOrder().reverseMul() * res < 0; } /** diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatDoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatDoubleBucketedSort.java new file mode 100644 index 0000000000000..9e66583b2373e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatDoubleBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code float} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. 
Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class FloatDoubleBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ * </li>
+ * </ul>
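+ * <p>
+ * Editorial sketch (hypothetical values, not part of the original javadoc): with
+ * {@code bucketSize = 4}, after gathering {@code 7.0} and then {@code 3.0} the slots hold
+ * {@code [1, unused, 3.0, 7.0]}: slot 3 was filled first, slot 2 second, and slot 0 stores
+ * the next gather offset (1). Two more collects fill slots 1 and 0, at which point the
+ * bucket is heapified and switches to heap mode.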
+ */ + private FloatArray values; + private DoubleArray extraValues; + + public FloatDoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newFloatArray(0, false); + extraValues = bigArrays.newDoubleArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * <p>
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
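+ * <p>
+ * Editorial example (hypothetical values, not part of the original javadoc): with
+ * {@code order = DESC} and a full bucket of size 3 holding {@code 9.0, 7.0, 5.0}, the heap
+ * root is the worst kept value, {@code 5.0}. Collecting {@code 6.0} replaces the root and
+ * re-heapifies; collecting {@code 4.0} is rejected after a single comparison.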
+ */ + public void collect(float value, double extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, FloatDoubleBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendFloat(values.get(rootIndex)); + extraBuilder.appendDouble(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendFloat(values.get(rootIndex + i)); + extraBuilder.appendDouble(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(float lhs, float rhs, double lhsExtra, double rhsExtra) { + int res = Float.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Double.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.FLOAT_PAGE_SIZE, Float.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatFloatBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatFloatBucketedSort.java new file mode 100644 index 0000000000000..c3ad1e47f1698 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatFloatBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code float} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file. + */ +public class FloatFloatBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ * </li>
+ * </ul>
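+ * <p>
+ * Editorial sketch (hypothetical values, not part of the original javadoc): with
+ * {@code order = DESC} and {@code bucketSize = 4} keeping {@code 9.0, 7.0, 5.0, 1.0}, the
+ * slots could be laid out as {@code [1.0, 5.0, 7.0, 9.0]}: slot 0 is the root (the worst
+ * kept value) and the children of slot {@code i} sit at slots {@code 2 * i + 1} and
+ * {@code 2 * i + 2}.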
+ */ + private FloatArray values; + private FloatArray extraValues; + + public FloatFloatBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newFloatArray(0, false); + extraValues = bigArrays.newFloatArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * <p>
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
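+ * <p>
+ * Editorial note (not part of the original javadoc): "better" compares the main values
+ * first and only falls back to the extra values on a tie, as implemented by
+ * {@code betterThan} below. For example, with {@code order = ASC}, {@code (5.0, 1.0)} beats
+ * {@code (5.0, 2.0)} because the main values tie and the lower extra value wins.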
+ */ + public void collect(float value, float extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, FloatFloatBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendFloat(values.get(rootIndex)); + extraBuilder.appendFloat(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendFloat(values.get(rootIndex + i)); + extraBuilder.appendFloat(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(float lhs, float rhs, float lhsExtra, float rhsExtra) { + int res = Float.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Float.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.FLOAT_PAGE_SIZE, Float.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatIntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatIntBucketedSort.java new file mode 100644 index 0000000000000..fefe511a88586 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatIntBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code float} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class FloatIntBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ * </li>
+ * </ul>
+ */ + private FloatArray values; + private IntArray extraValues; + + public FloatIntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newFloatArray(0, false); + extraValues = bigArrays.newIntArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * <p>
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
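+ * <p>
+ * Editorial note (not part of the original javadoc): in gather mode the pair lands at
+ * {@code rootIndex + next}, where {@code next} is the stored gather offset, and the offset
+ * is then decremented; once slot 0 is filled the bucket is heapified and flips to heap mode.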
+ */ + public void collect(float value, int extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, FloatIntBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newIntBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendFloat(values.get(rootIndex)); + extraBuilder.appendInt(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendFloat(values.get(rootIndex + i)); + extraBuilder.appendInt(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(float lhs, float rhs, int lhsExtra, int rhsExtra) { + int res = Float.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Integer.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.FLOAT_PAGE_SIZE, Float.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatLongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatLongBucketedSort.java new file mode 100644 index 0000000000000..32aa8e9fdb191 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/FloatLongBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code float} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. + */ +public class FloatLongBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the minimum value in the bucket,
+ * which allows us to quickly discard new values that are not in the top N.
+ * </li>
+ * </ul>
+ */ + private FloatArray values; + private LongArray extraValues; + + public FloatLongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newFloatArray(0, false); + extraValues = bigArrays.newLongArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
+ * <p>
+ * It may or may not be inserted in the heap, depending on if it is better than the current root.
+ */ + public void collect(float value, long extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, FloatLongBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newLongBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendFloat(values.get(rootIndex)); + extraBuilder.appendLong(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendFloat(values.get(rootIndex + i)); + extraBuilder.appendLong(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(float lhs, float rhs, long lhsExtra, long rhsExtra) { + int res = Float.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Long.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.FLOAT_PAGE_SIZE, Float.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * <a href="https://en.wikipedia.org/wiki/Binary_heap">wikipedia</a>
+ * entry on binary heaps for more about this.
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java index 5ba1a3f7138a3..7dcec4461f0a5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntBucketedSort.java @@ -23,7 +23,7 @@ import java.util.stream.IntStream; /** - * Aggregates the top N int values per bucket. + * Aggregates the top N {@code int} values per bucket. * See {@link BucketedSort} for more information. * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. 
*/ @@ -162,12 +162,7 @@ public void merge(int groupId, IntBucketedSort other, int otherGroupId) { */ public Block toBlock(BlockFactory blockFactory, IntVector selected) { // Check if the selected groups are all empty, to avoid allocating extra memory - if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { - var bounds = this.getBucketValuesIndexes(bucket); - var size = bounds.v2() - bounds.v1(); - - return size > 0; - })) { + if (allSelectedGroupsAreEmpty(selected)) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -185,7 +180,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } if (size == 1) { - builder.appendInt(values.get(bounds.v1())); + builder.appendInt(values.get(rootIndex)); continue; } @@ -197,7 +192,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { builder.beginPositionEntry(); for (int i = 0; i < size; i++) { - builder.appendInt(values.get(bounds.v1() + i)); + builder.appendInt(values.get(rootIndex + i)); } builder.endPositionEntry(); } @@ -205,6 +200,17 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } } + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + /** * Is this bucket a min heap {@code true} or in gathering mode {@code false}? */ @@ -234,7 +240,8 @@ private void setNextGatherOffset(long rootIndex, int offset) { * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. */ private boolean betterThan(int lhs, int rhs) { - return getOrder().reverseMul() * Integer.compare(lhs, rhs) < 0; + int res = Integer.compare(lhs, rhs); + return getOrder().reverseMul() * res < 0; } /** diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntDoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntDoubleBucketedSort.java new file mode 100644 index 0000000000000..9692a3490add6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntDoubleBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code int} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. 
Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class IntDoubleBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the "worst" value in the bucket (the minimum for descending sorts, the
+ * maximum for ascending ones), which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
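+ * <p>
+ * For example, with an illustrative bucketSize of 3, a bucket that has gathered the single value 7 is laid
+ * out as [1, _, 7]: slot 2 was filled first, and slot 0 holds the offset (1) of the next slot to fill.
+ * Once slot 0 itself receives a value, the three slots are heapified in place and slot 0 becomes the heap root.
+ * </p>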
+ */ + private IntArray values; + private DoubleArray extraValues; + + public IntDoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newIntArray(0, false); + extraValues = bigArrays.newDoubleArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
<p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(int value, double extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, IntDoubleBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendInt(values.get(rootIndex)); + extraBuilder.appendDouble(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
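+ // A bucket that is still gathering holds its occupied slots in insertion order rather than
+ // heap order, so the heap invariant has to be established first; heapSort below relies on it.
+ // Buckets already in heap mode can be sorted directly.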
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendInt(values.get(rootIndex + i)); + extraBuilder.appendDouble(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(int lhs, int rhs, double lhsExtra, double rhsExtra) { + int res = Integer.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Double.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.INT_PAGE_SIZE, Integer.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
<p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntFloatBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntFloatBucketedSort.java new file mode 100644 index 0000000000000..756f9894d887a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntFloatBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.FloatArray;
+import org.elasticsearch.common.util.IntArray;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N {@code int} values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class IntFloatBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the "worst" value in the bucket (the minimum for descending sorts, the
+ * maximum for ascending ones), which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
+ */ + private IntArray values; + private FloatArray extraValues; + + public IntFloatBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newIntArray(0, false); + extraValues = bigArrays.newFloatArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
<p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(int value, float extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, IntFloatBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendInt(values.get(rootIndex)); + extraBuilder.appendFloat(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
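+ // The occupied slots of a gathering bucket are unordered, so establish the heap invariant
+ // over them before heap-sorting.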
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendInt(values.get(rootIndex + i)); + extraBuilder.appendFloat(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(int lhs, int rhs, float lhsExtra, float rhsExtra) { + int res = Integer.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Float.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.INT_PAGE_SIZE, Integer.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
<p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntIntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntIntBucketedSort.java new file mode 100644 index 0000000000000..13cb37e886ce6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntIntBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.IntArray;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N {@code int} values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class IntIntBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the "worst" value in the bucket (the minimum for descending sorts, the
+ * maximum for ascending ones), which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
+ */ + private IntArray values; + private IntArray extraValues; + + public IntIntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newIntArray(0, false); + extraValues = bigArrays.newIntArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
<p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(int value, int extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, IntIntBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newIntBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendInt(values.get(rootIndex)); + extraBuilder.appendInt(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
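+ // heapSort assumes a valid heap, so a bucket still in gather mode is heapified over its
+ // occupied slots first.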
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendInt(values.get(rootIndex + i)); + extraBuilder.appendInt(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(int lhs, int rhs, int lhsExtra, int rhsExtra) { + int res = Integer.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Integer.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.INT_PAGE_SIZE, Integer.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
<p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntLongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntLongBucketedSort.java new file mode 100644 index 0000000000000..dae223f67dc8d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/IntLongBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.IntArray;
+import org.elasticsearch.common.util.LongArray;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.search.sort.BucketedSort;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N {@code int} values per bucket.
+ * See {@link BucketedSort} for more information.
+ * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class IntLongBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the "worst" value in the bucket (the minimum for descending sorts, the
+ * maximum for ascending ones), which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
+ */ + private IntArray values; + private LongArray extraValues; + + public IntLongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newIntArray(0, false); + extraValues = bigArrays.newLongArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
<p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(int value, long extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, IntLongBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newLongBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendInt(values.get(rootIndex)); + extraBuilder.appendLong(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
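+ // Only the occupied tail of a gathering bucket is heapified here: size counts the collected
+ // values and excludes the unfilled slots tracked by the gather offset.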
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendInt(values.get(rootIndex + i)); + extraBuilder.appendLong(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(int lhs, int rhs, long lhsExtra, long rhsExtra) { + int res = Integer.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Long.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.INT_PAGE_SIZE, Integer.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
<p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java index ac472cc411668..e10ce05017d23 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongBucketedSort.java @@ -23,7 +23,7 @@ import java.util.stream.IntStream; /** - * Aggregates the top N long values per bucket. + * Aggregates the top N {@code long} values per bucket. * See {@link BucketedSort} for more information. * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. 
*/ @@ -162,12 +162,7 @@ public void merge(int groupId, LongBucketedSort other, int otherGroupId) { */ public Block toBlock(BlockFactory blockFactory, IntVector selected) { // Check if the selected groups are all empty, to avoid allocating extra memory - if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { - var bounds = this.getBucketValuesIndexes(bucket); - var size = bounds.v2() - bounds.v1(); - - return size > 0; - })) { + if (allSelectedGroupsAreEmpty(selected)) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -185,7 +180,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } if (size == 1) { - builder.appendLong(values.get(bounds.v1())); + builder.appendLong(values.get(rootIndex)); continue; } @@ -197,7 +192,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { builder.beginPositionEntry(); for (int i = 0; i < size; i++) { - builder.appendLong(values.get(bounds.v1() + i)); + builder.appendLong(values.get(rootIndex + i)); } builder.endPositionEntry(); } @@ -205,6 +200,17 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } } + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + /** * Is this bucket a min heap {@code true} or in gathering mode {@code false}? */ @@ -234,7 +240,8 @@ private void setNextGatherOffset(long rootIndex, int offset) { * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. */ private boolean betterThan(long lhs, long rhs) { - return getOrder().reverseMul() * Long.compare(lhs, rhs) < 0; + int res = Long.compare(lhs, rhs); + return getOrder().reverseMul() * res < 0; } /** diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongDoubleBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongDoubleBucketedSort.java new file mode 100644 index 0000000000000..dd7bf77eaf6f6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongDoubleBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code long} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. 
Edit {@code X-BucketedSort.java.st} instead of this file.
+ */
+public class LongDoubleBucketedSort implements Releasable {
+
+ private final BigArrays bigArrays;
+ private final SortOrder order;
+ private final int bucketSize;
+ /**
+ * {@code true} if the bucket is in heap mode, {@code false} if
+ * it is still gathering.
+ */
+ private final BitArray heapMode;
+ /**
+ * An array containing all the values on all buckets. The structure is as follows:
+ * <p>
+ * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ * Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ * <li>
+ * Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ * In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ * The lowest index contains the offset to the next slot to be filled.
+ * <p>
+ * This allows us to insert elements in O(1) time.
+ * </p>
+ * <p>
+ * When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ * </p>
+ * </li>
+ * <li>
+ * Heap mode: The bucket slots are organized as a min heap structure.
+ * <p>
+ * The root of the heap is the "worst" value in the bucket (the minimum for descending sorts, the
+ * maximum for ascending ones), which allows us to quickly discard new values that are not in the top N.
+ * </p>
+ * </li>
+ * </ul>
+ */ + private LongArray values; + private DoubleArray extraValues; + + public LongDoubleBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newLongArray(0, false); + extraValues = bigArrays.newDoubleArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * <p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
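+ * <p>
+ *     Editorial sketch, not part of the original patch: the heap root holds the value that would be
+ *     discarded first, so for an ascending sort it is the largest kept value. Ties on {@code value}
+ *     are broken by {@code extraValue} in {@code betterThan}.
+ * </p>
+ * <pre>{@code
+ * // assumed: ascending order, bucketSize = 3, bucket 0 already holds values {2, 5, 7}
+ * sort.collect(4L, 0.25, 0);  // 4 is better than the root 7: the pair replaces it and sifts down
+ * sort.collect(9L, 0.50, 0);  // 9 is not better than the new root 5: discarded in O(1)
+ * }</pre>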
+ */ + public void collect(long value, double extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, LongDoubleBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendLong(values.get(rootIndex)); + extraBuilder.appendDouble(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendLong(values.get(rootIndex + i)); + extraBuilder.appendDouble(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(long lhs, long rhs, double lhsExtra, double rhsExtra) { + int res = Long.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Double.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.LONG_PAGE_SIZE, Long.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * Wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
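+ * <p>
+ *     Editorial note, not part of the original patch: {@code maxParent = heapSize / 2 - 1} is the
+ *     last node that has a child, e.g. for {@code heapSize = 5} only nodes 1 and 0 are sifted down,
+ *     since nodes 2, 3 and 4 are leaves and already valid heaps on their own.
+ * </p>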
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongFloatBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongFloatBucketedSort.java new file mode 100644 index 0000000000000..db4d25aecfa74 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongFloatBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code long} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file. + */ +public class LongFloatBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * <p>
+ *     For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ *     Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ *     <li>
+ *         Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ *         The lowest index contains the offset to the next slot to be filled.
+ *         <p>
+ *             This allows us to insert elements in O(1) time.
+ *         </p>
+ *         <p>
+ *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ *         </p>
+ *     </li>
+ *     <li>
+ *         Heap mode: The bucket slots are organized as a min heap structure.
+ *         <p>
+ *             The root of the heap is the minimum value in the bucket,
+ *             which allows us to quickly discard new values that are not in the top N.
+ *         </p>
+ *     </li>
+ * </ul>
+ */ + private LongArray values; + private FloatArray extraValues; + + public LongFloatBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newLongArray(0, false); + extraValues = bigArrays.newFloatArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * <p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(long value, float extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, LongFloatBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newFloatBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendLong(values.get(rootIndex)); + extraBuilder.appendFloat(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendLong(values.get(rootIndex + i)); + extraBuilder.appendFloat(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(long lhs, long rhs, float lhsExtra, float rhsExtra) { + int res = Long.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Float.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.LONG_PAGE_SIZE, Long.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * Wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
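+ * <pre>{@code
+ * // Editorial sketch, not part of the original patch: ascending order, heap [7, 2, 5] (root 7).
+ * // heapSort repeatedly swaps the root (the worst kept value) behind the shrinking heap:
+ * //   [7, 2, 5] -> swap root/last -> [5, 2 | 7] -> swap root/last -> [2 | 5, 7]
+ * // so the bucket ends fully sorted, [2, 5, 7], with extraValues swapped alongside.
+ * }</pre>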
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongIntBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongIntBucketedSort.java new file mode 100644 index 0000000000000..8e22b27791a2c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongIntBucketedSort.java @@ -0,0 +1,406 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code long} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file. + */ +public class LongIntBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * <p>
+ *     For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ *     Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ *     <li>
+ *         Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ *         The lowest index contains the offset to the next slot to be filled.
+ *         <p>
+ *             This allows us to insert elements in O(1) time.
+ *         </p>
+ *         <p>
+ *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ *         </p>
+ *     </li>
+ *     <li>
+ *         Heap mode: The bucket slots are organized as a min heap structure.
+ *         <p>
+ *             The root of the heap is the minimum value in the bucket,
+ *             which allows us to quickly discard new values that are not in the top N.
+ *         </p>
+ *     </li>
+ * </ul>
+ */ + private LongArray values; + private IntArray extraValues; + + public LongIntBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newLongArray(0, false); + extraValues = bigArrays.newIntArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * <p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(long value, int extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, LongIntBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newIntBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendLong(values.get(rootIndex)); + extraBuilder.appendInt(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendLong(values.get(rootIndex + i)); + extraBuilder.appendInt(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(long lhs, long rhs, int lhsExtra, int rhsExtra) { + int res = Long.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Integer.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.LONG_PAGE_SIZE, Long.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * Wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongLongBucketedSort.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongLongBucketedSort.java new file mode 100644 index 0000000000000..49a824e0c88b0 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/sort/LongLongBucketedSort.java @@ -0,0 +1,405 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.common.util.PageCacheRecycler; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.stream.IntStream; + +/** + * Aggregates the top N {@code long} values per bucket. + * See {@link BucketedSort} for more information. + * This class is generated. Edit {@code X-BucketedSort.java.st} instead of this file. + */ +public class LongLongBucketedSort implements Releasable { + + private final BigArrays bigArrays; + private final SortOrder order; + private final int bucketSize; + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + /** + * An array containing all the values on all buckets. The structure is as follows: + *

+ * <p>
+ *     For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...).
+ *     Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ *     <li>
+ *         Gather mode: All buckets start in gather mode, and remain here while they have less than bucketSize elements.
+ *         In gather mode, the elements are stored in the array from the highest index to the lowest index.
+ *         The lowest index contains the offset to the next slot to be filled.
+ *         <p>
+ *             This allows us to insert elements in O(1) time.
+ *         </p>
+ *         <p>
+ *             When the bucketSize-th element is collected, the bucket transitions to heap mode, by heapifying its contents.
+ *         </p>
+ *     </li>
+ *     <li>
+ *         Heap mode: The bucket slots are organized as a min heap structure.
+ *         <p>
+ *             The root of the heap is the minimum value in the bucket,
+ *             which allows us to quickly discard new values that are not in the top N.
+ *         </p>
+ *     </li>
+ * </ul>
+ */ + private LongArray values; + private LongArray extraValues; + + public LongLongBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + heapMode = new BitArray(0, bigArrays); + + boolean success = false; + try { + values = bigArrays.newLongArray(0, false); + extraValues = bigArrays.newLongArray(0, false); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *

+ * <p>
+ * It may or may not be inserted in the heap, depending on whether it is better than the current root.
+ * </p>
+ */ + public void collect(long value, long extraValue, int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex), extraValue, extraValues.get(rootIndex))) { + values.set(rootIndex, value); + extraValues.set(rootIndex, extraValue); + downHeap(rootIndex, 0, bucketSize); + } + return; + } + // Gathering mode + long requiredSize = rootIndex + bucketSize; + if (values.size() < requiredSize) { + grow(bucket); + } + int next = getNextGatherOffset(rootIndex); + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + long index = next + rootIndex; + values.set(index, value); + extraValues.set(index, extraValue); + if (next == 0) { + heapMode.set(bucket); + heapify(rootIndex, bucketSize); + } else { + setNextGatherOffset(rootIndex, next - 1); + } + } + + /** + * The order of the sort. + */ + public SortOrder getOrder() { + return order; + } + + /** + * The number of values to store per bucket. + */ + public int getBucketSize() { + return bucketSize; + } + + /** + * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. + * Returns [0, 0] if the bucket has never been collected. + */ + private Tuple getBucketValuesIndexes(int bucket) { + long rootIndex = (long) bucket * bucketSize; + if (rootIndex >= values.size()) { + // We've never seen this bucket. + return Tuple.tuple(0L, 0L); + } + long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); + long end = rootIndex + bucketSize; + return Tuple.tuple(start, end); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int groupId, LongLongBucketedSort other, int otherGroupId) { + var otherBounds = other.getBucketValuesIndexes(otherGroupId); + + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { + collect(other.values.get(i), other.extraValues.get(i), groupId); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } + + try ( + var builder = blockFactory.newLongBlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.newLongBlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.appendLong(values.get(rootIndex)); + extraBuilder.appendLong(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.appendLong(values.get(rootIndex + i)); + extraBuilder.appendLong(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } + + /** + * Is this bucket a min heap {@code true} or in gathering mode {@code false}? + */ + private boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private int getNextGatherOffset(long rootIndex) { + return (int) values.get(rootIndex); + } + + /** + * Set the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + */ + private void setNextGatherOffset(long rootIndex, int offset) { + values.set(rootIndex, offset); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(long lhs, long rhs, long lhsExtra, long rhsExtra) { + int res = Long.compare(lhs, rhs); + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = Long.compare(lhsExtra, rhsExtra); + return getOrder().reverseMul() * res < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + var tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. We always grow the storage by whole bucket's + * worth of slots at a time. We never allocate space for partial buckets. + */ + private void grow(int bucket) { + long oldMax = values.size(); + assert oldMax % bucketSize == 0; + + long newSize = BigArrays.overSize(((long) bucket + 1) * bucketSize, PageCacheRecycler.LONG_PAGE_SIZE, Long.BYTES); + // Round up to the next full bucket. + newSize = (newSize + bucketSize - 1) / bucketSize; + values = bigArrays.resize(values, newSize * bucketSize); + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + int nextOffset = getBucketSize() - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += getBucketSize()) { + setNextGatherOffset(bucketRoot, nextOffset); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *

+ * <p>
+ * This works by validating the heap property on each node, iterating
+ * "upwards", pushing any out of order parents "down". Check out the
+ * Wikipedia entry on binary heaps for more about this.
+ * </p>
+ * <p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is
+ * a fairly well studied algorithm attributed to Floyd. There's
+ * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst
+ * case.
+ * </p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex, int heapSize) { + int maxParent = heapSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent, heapSize); + } + } + + /** + * Sorts all the values in the heap using heap sort algorithm. + * This runs in {@code O(n log n)} time. + * @param rootIndex index of the start of the bucket + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void heapSort(long rootIndex, int heapSize) { + while (heapSize > 0) { + swap(rootIndex, rootIndex + heapSize - 1); + heapSize--; + downHeap(rootIndex, 0, heapSize); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + * @param heapSize Number of values that belong to the heap. + * Can be less than bucketSize. + * In such a case, the remaining values in range + * (rootIndex + heapSize, rootIndex + bucketSize) + * are *not* considered part of the heap. + */ + private void downHeap(long rootIndex, int parent, int heapSize) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < heapSize) { + if (betterThan(values.get(worstIndex), values.get(leftIndex), extraValues.get(worstIndex), extraValues.get(leftIndex))) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + @Override + public final void close() { + Releasables.close(values, extraValues, heapMode); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..111f7ce2d3a6b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregatorFunction.java @@ -0,0 +1,218 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
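Before the generated implementation, a minimal usage sketch may help: the generated (field, outputField) aggregators all share this shape, reading one input channel per argument and emitting the top values as a single multi-valued position. The sketch below is illustrative only, not part of the patch; it assumes a DriverContext and an input Page are in scope, that the usual compute imports are available, and that BlockFactory#newConstantBooleanVector behaves as in the surrounding codebase.

// Minimal sketch, not part of the patch. Channel 0 carries the values being
// sorted on; channel 1 carries the extra values emitted for TOP's outputField.
static Block topThreeAscending(DriverContext driverContext, Page page) {
    try (var fn = TopDoubleDoubleAggregatorFunction.create(
            driverContext, List.of(0, 1), /* limit */ 3, /* ascending */ true)) {
        try (BooleanVector mask = driverContext.blockFactory()
                .newConstantBooleanVector(true, page.getPositionCount())) {
            fn.addRawInput(page, mask); // an all-true mask keeps every row
        }
        Block[] out = new Block[1];
        fn.evaluateFinal(out, 0, driverContext); // one multi-valued result position
        return out[0];
    }
}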
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopDoubleDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopDoubleDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final DriverContext driverContext; + + private final TopDoubleDoubleAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopDoubleDoubleAggregatorFunction(DriverContext driverContext, List channels, + TopDoubleDoubleAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleDoubleAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopDoubleDoubleAggregatorFunction(driverContext, channels, TopDoubleDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(DoubleVector vVector, DoubleVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + double vValue = vVector.getDouble(valuesPosition); + double outputValueValue = 
outputValueVector.getDouble(valuesPosition); + TopDoubleDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(DoubleVector vVector, DoubleVector outputValueVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double vValue = vVector.getDouble(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopDoubleDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(DoubleBlock vBlock, DoubleBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopDoubleDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(DoubleBlock vBlock, DoubleBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopDoubleDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert output.getPositionCount() == 1; + TopDoubleDoubleAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopDoubleDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public 
String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..0c5a1aea81c7b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopDoubleDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopDoubleDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopDoubleDoubleAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return TopDoubleDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return TopDoubleDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopDoubleDoubleAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return TopDoubleDoubleAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopDoubleDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return TopDoubleDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_double of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..210ad4e23a6e1 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleDoubleGroupingAggregatorFunction.java @@ -0,0 +1,395 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
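Similarly, a minimal sketch of the supplier wiring, illustrative only and assuming a DriverContext named driverContext is in scope: the supplier is the planner-facing factory, and the grouping variant below keeps one top-N bucket per group (for example per STATS ... BY key).

// Minimal sketch, not part of the patch.
AggregatorFunctionSupplier supplier =
    new TopDoubleDoubleAggregatorFunctionSupplier(/* limit */ 10, /* ascending */ false);
String description = supplier.describe(); // "top_double of doubles"

// Ungrouped (one global TOP) vs. grouped (one TOP per bucket):
TopDoubleDoubleAggregatorFunction fn = supplier.aggregator(driverContext, List.of(0, 1));
TopDoubleDoubleGroupingAggregatorFunction grouped =
    supplier.groupingAggregator(driverContext, List.of(0, 1));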
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopDoubleDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopDoubleDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final TopDoubleDoubleAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopDoubleDoubleGroupingAggregatorFunction(List<Integer> channels, + TopDoubleDoubleAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleDoubleGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopDoubleDoubleGroupingAggregatorFunction(channels, TopDoubleDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock
groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopDoubleDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopDoubleDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopDoubleDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopDoubleDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopDoubleDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + 
positionOffset; + TopDoubleDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopDoubleDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + double vValue = vVector.getDouble(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopDoubleDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopDoubleDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock vBlock, + DoubleBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopDoubleDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + 
sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregatorFunction.java new file mode 100644 index 0000000000000..a1c451e229c8c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregatorFunction.java @@ -0,0 +1,220 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopDoubleFloatAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopDoubleFloatAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final DriverContext driverContext; + + private final TopDoubleFloatAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopDoubleFloatAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopDoubleFloatAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleFloatAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopDoubleFloatAggregatorFunction(driverContext, channels, TopDoubleFloatAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock,
outputValueBlock, mask); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(DoubleVector vVector, FloatVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + double vValue = vVector.getDouble(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopDoubleFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(DoubleVector vVector, FloatVector outputValueVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double vValue = vVector.getDouble(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopDoubleFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(DoubleBlock vBlock, FloatBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopDoubleFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(DoubleBlock vBlock, FloatBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopDoubleFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void 
addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert output.getPositionCount() == 1; + TopDoubleFloatAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopDoubleFloatAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..db99d9135f9b8 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopDoubleFloatAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
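+ * <p>A note for readers (editorial gloss on the generated naming): the paired type names read as (field type, outputField type), so this supplier wires TOP aggregators that order values by the {@code double} field but collect and return the paired {@code float} outputField values.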
+ */ +public final class TopDoubleFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopDoubleFloatAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopDoubleFloatAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopDoubleFloatGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopDoubleFloatAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopDoubleFloatAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopDoubleFloatGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopDoubleFloatGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_double of floats"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..e3bb5b23405d6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleFloatGroupingAggregatorFunction.java @@ -0,0 +1,397 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopDoubleFloatAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead.
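+ * <p>In the grouping flavor, each (field value, outputField value) pair is routed into a per-group state that keeps at most {@code limit} entries, ordered ascending or descending according to {@code ascending}.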
+ */ +public final class TopDoubleFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final TopDoubleFloatAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopDoubleFloatGroupingAggregatorFunction(List<Integer> channels, + TopDoubleFloatAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleFloatGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopDoubleFloatGroupingAggregatorFunction(channels, TopDoubleFloatAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void
addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopDoubleFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopDoubleFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopDoubleFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if 
(outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopDoubleFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopDoubleFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopDoubleFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = 
outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopDoubleFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + double vValue = vVector.getDouble(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopDoubleFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopDoubleFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock vBlock, + FloatBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopDoubleFloatAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntAggregatorFunction.java new file mode 100644 index 0000000000000..bf2a4eb912b82 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopDoubleIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopDoubleIntAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final DriverContext driverContext; + + private final TopDoubleIntAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopDoubleIntAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopDoubleIntAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleIntAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopDoubleIntAggregatorFunction(driverContext, channels, TopDoubleIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(DoubleVector vVector, IntVector outputValueVector) { + for (int
valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + double vValue = vVector.getDouble(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopDoubleIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(DoubleVector vVector, IntVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double vValue = vVector.getDouble(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopDoubleIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(DoubleBlock vBlock, IntBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopDoubleIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(DoubleBlock vBlock, IntBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopDoubleIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert output.getPositionCount() == 1; + TopDoubleIntAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { 
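+ // The final pass emits only the collected outputField values (per the docs types table, an integer result here); the double sort keys are carried solely in the intermediate "top" block.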
+ blocks[offset] = TopDoubleIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..4bb7632d5db7d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopDoubleIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopDoubleIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopDoubleIntAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopDoubleIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopDoubleIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopDoubleIntAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopDoubleIntAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopDoubleIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopDoubleIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_double of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..817ac9746c37b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleIntGroupingAggregatorFunction.java @@ -0,0 +1,396 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0.
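+// Illustrative usage sketch (hypothetical index and field names; not part of the generated file): an +// ES|QL query such as `FROM sales | STATS TOP(price, 3, "desc", quantity) BY region` would drive this +// grouping variant when `price` maps to double and `quantity` to integer, returning per region the +// `quantity` values paired with the three largest prices.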
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopDoubleIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopDoubleIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final TopDoubleIntAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopDoubleIntGroupingAggregatorFunction(List<Integer> channels, + TopDoubleIntAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleIntGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopDoubleIntGroupingAggregatorFunction(channels, TopDoubleIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset,
IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopDoubleIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopDoubleIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopDoubleIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopDoubleIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopDoubleIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + 
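// Fold the serialized (top, output) state at this position into the running state for groupId. + 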
TopDoubleIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopDoubleIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + double vValue = vVector.getDouble(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopDoubleIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopDoubleIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock vBlock, + IntBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopDoubleIntAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + 
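// Only the channel assignments are reported here; limit and ascending are not part of the string form. + 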
sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongAggregatorFunction.java new file mode 100644 index 0000000000000..9951bd7dd5c2b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongAggregatorFunction.java @@ -0,0 +1,220 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopDoubleLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopDoubleLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final TopDoubleLongAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopDoubleLongAggregatorFunction(DriverContext driverContext, List channels, + TopDoubleLongAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleLongAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopDoubleLongAggregatorFunction(driverContext, channels, TopDoubleLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + LongVector outputValueVector = 
outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(DoubleVector vVector, LongVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + double vValue = vVector.getDouble(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopDoubleLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(DoubleVector vVector, LongVector outputValueVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + double vValue = vVector.getDouble(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopDoubleLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(DoubleBlock vBlock, LongBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopDoubleLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(DoubleBlock vBlock, LongBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopDoubleLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert 
page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert output.getPositionCount() == 1; + TopDoubleLongAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopDoubleLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..948e23035f622 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopDoubleLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
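+ * <p>Supplies both the single-state and grouping variants, with {@code limit} and the sort direction fixed at construction.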
+ */ +public final class TopDoubleLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopDoubleLongAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopDoubleLongAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopDoubleLongGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopDoubleLongAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopDoubleLongAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopDoubleLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopDoubleLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_double of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..04cd530e8dd41 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopDoubleLongGroupingAggregatorFunction.java @@ -0,0 +1,397 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopDoubleLongAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. 
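+ * <p>For each group, keeps the top {@code limit} {@code double} values in the configured order together with the {@code long} output value gathered from the same row.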
+ */ +public final class TopDoubleLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.DOUBLE), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final TopDoubleLongAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopDoubleLongGroupingAggregatorFunction(List<Integer> channels, + TopDoubleLongAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopDoubleLongGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopDoubleLongGroupingAggregatorFunction(channels, TopDoubleLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + DoubleBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + DoubleVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void 
addRawInput(int positionOffset, IntArrayBlock groups, DoubleBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopDoubleLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, DoubleVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopDoubleLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopDoubleLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if 
(outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopDoubleLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, DoubleVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + double vValue = vVector.getDouble(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopDoubleLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopDoubleLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + double vValue = vBlock.getDouble(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = 
outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopDoubleLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, DoubleVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + double vValue = vVector.getDouble(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopDoubleLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + DoubleBlock top = (DoubleBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopDoubleLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, DoubleBlock vBlock, + LongBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopDoubleLongAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..3fc172d2a2dfe --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregatorFunction.java @@ -0,0 +1,220 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopFloatDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopFloatDoubleAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final DriverContext driverContext; + + private final TopFloatDoubleAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopFloatDoubleAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopFloatDoubleAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatDoubleAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopFloatDoubleAggregatorFunction(driverContext, channels, TopFloatDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(FloatVector vVector, DoubleVector 
outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + float vValue = vVector.getFloat(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopFloatDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(FloatVector vVector, DoubleVector outputValueVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + float vValue = vVector.getFloat(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopFloatDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(FloatBlock vBlock, DoubleBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopFloatDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(FloatBlock vBlock, DoubleBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopFloatDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert output.getPositionCount() == 1; + TopFloatDoubleAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void 
evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopFloatDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..913f2087b2562 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopFloatDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopFloatDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopFloatDoubleAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopFloatDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopFloatDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopFloatDoubleAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopFloatDoubleAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopFloatDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopFloatDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_float of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..a4fe8f7a107e9 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatDoubleGroupingAggregatorFunction.java @@ -0,0 +1,397 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
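+// Illustrative sketch, not part of the generated file: one plausible way a supplier is wired
+// into a grouping aggregation. The channel indices and the driverContext instance are
+// assumptions for the example; channel 0 carries the float field that TOP ranks by and
+// channel 1 carries the double outputField that is emitted in its place.
+//
+//   AggregatorFunctionSupplier supplier = new TopFloatDoubleAggregatorFunctionSupplier(10, true);
+//   GroupingAggregatorFunction topFn = supplier.groupingAggregator(driverContext, List.of(0, 1));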
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopFloatDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopFloatDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final TopFloatDoubleAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopFloatDoubleGroupingAggregatorFunction(List<Integer> channels, + TopFloatDoubleAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatDoubleGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopFloatDoubleGroupingAggregatorFunction(channels, TopFloatDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, 
outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopFloatDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopFloatDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert 
top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopFloatDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopFloatDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; 
g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopFloatDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + float vValue = vVector.getFloat(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopFloatDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopFloatDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, FloatBlock vBlock, + DoubleBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopFloatDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String 
toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatAggregatorFunction.java new file mode 100644 index 0000000000000..7a59c9060c727 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatAggregatorFunction.java @@ -0,0 +1,218 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopFloatFloatAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopFloatFloatAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final DriverContext driverContext; + + private final TopFloatFloatAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopFloatFloatAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopFloatFloatAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatFloatAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopFloatFloatAggregatorFunction(driverContext, channels, TopFloatFloatAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + FloatVector outputValueVector
= outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(FloatVector vVector, FloatVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + float vValue = vVector.getFloat(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopFloatFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(FloatVector vVector, FloatVector outputValueVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + float vValue = vVector.getFloat(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopFloatFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(FloatBlock vBlock, FloatBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopFloatFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(FloatBlock vBlock, FloatBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopFloatFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert 
page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert output.getPositionCount() == 1; + TopFloatFloatAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopFloatFloatAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..673b05432a165 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopFloatFloatAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
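The two-block intermediate layout used above ("top" holding the float sort keys, "output" holding the values the TOP call will actually return) is easiest to see outside the generated plumbing. The following is a minimal, self-contained sketch of the same idea; `TopPairSketch` and its members are illustrative names, not code from this PR, and a plain list stands in for the big-array state:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Illustrative stand-in for the (sort key, output value) pairing that the
// generated combine/combineIntermediate/evaluateFinal calls maintain.
final class TopPairSketch {
    record Pair(float sortKey, float outputValue) {}

    private final int limit;
    private final boolean ascending;
    private final List<Pair> pairs = new ArrayList<>();

    TopPairSketch(int limit, boolean ascending) {
        this.limit = limit;
        this.ascending = ascending;
    }

    // Spirit of combine(state, vValue, outputValueValue): rank by the sort
    // key, but remember the companion output value. (Re-sorting on every
    // call is deliberately naive; the real state avoids this.)
    void combine(float sortKey, float outputValue) {
        pairs.add(new Pair(sortKey, outputValue));
        Comparator<Pair> bySortKey = Comparator.comparingDouble(Pair::sortKey);
        pairs.sort(ascending ? bySortKey : bySortKey.reversed());
        if (pairs.size() > limit) {
            pairs.subList(limit, pairs.size()).clear();
        }
    }

    // Spirit of combineIntermediate(state, top, output): replay another
    // shard's paired state into this one, pair by pair.
    void combineIntermediate(TopPairSketch other) {
        for (Pair p : other.pairs) {
            combine(p.sortKey, p.outputValue);
        }
    }

    // Spirit of evaluateFinal: only the output values survive, ordered by
    // their sort keys.
    List<Float> evaluateFinal() {
        return pairs.stream().map(Pair::outputValue).toList();
    }
}
```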
+ */ +public final class TopFloatFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopFloatFloatAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopFloatFloatAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopFloatFloatGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopFloatFloatAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopFloatFloatAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopFloatFloatGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopFloatFloatGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_float of floats"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..6d955b01f352c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatFloatGroupingAggregatorFunction.java @@ -0,0 +1,395 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopFloatFloatAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead.
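The grouping flavor that follows keeps one such top-N state per group id, which is why every combine call in it carries a `groupId` argument. A toy sketch of that shape, using a HashMap and a min-heap where the generated code uses off-heap big arrays (all names here are illustrative):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical per-group top-N bookkeeping; shows the shape of
// combine(state, groupId, vValue, outputValueValue) for the descending case.
final class GroupedTopSketch {
    private final int limit;
    private final Map<Integer, PriorityQueue<float[]>> perGroup = new HashMap<>();

    GroupedTopSketch(int limit) {
        this.limit = limit;
    }

    // Each entry is {sortKey, outputValue}. A min-heap on the sort key keeps
    // the `limit` largest entries per group: the smallest survivor is always
    // the cheapest to evict.
    void combine(int groupId, float sortKey, float outputValue) {
        PriorityQueue<float[]> queue = perGroup.computeIfAbsent(
            groupId, g -> new PriorityQueue<>((a, b) -> Float.compare(a[0], b[0])));
        queue.offer(new float[] { sortKey, outputValue });
        if (queue.size() > limit) {
            queue.poll(); // evict the entry with the smallest sort key
        }
    }
}
```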
+ */ +public final class TopFloatFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final TopFloatFloatAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopFloatFloatGroupingAggregatorFunction(List<Integer> channels, + TopFloatFloatAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatFloatGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopFloatFloatGroupingAggregatorFunction(channels, TopFloatFloatAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void
addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopFloatFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopFloatFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if 
(outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopFloatFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopFloatFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = 
outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopFloatFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + float vValue = vVector.getFloat(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopFloatFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopFloatFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, FloatBlock vBlock, + FloatBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopFloatFloatAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntAggregatorFunction.java new file mode 100644 index 0000000000000..3b190c8a069ce --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. 
Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopFloatIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopFloatIntAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final DriverContext driverContext; + + private final TopFloatIntAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopFloatIntAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopFloatIntAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatIntAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopFloatIntAggregatorFunction(driverContext, channels, TopFloatIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(FloatVector vVector, IntVector outputValueVector) { + for (int valuesPosition =
0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + float vValue = vVector.getFloat(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopFloatIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(FloatVector vVector, IntVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + float vValue = vVector.getFloat(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopFloatIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(FloatBlock vBlock, IntBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopFloatIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(FloatBlock vBlock, IntBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopFloatIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert output.getPositionCount() == 1; + TopFloatIntAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = 
TopFloatIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..2851342c630c7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopFloatIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopFloatIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopFloatIntAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopFloatIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopFloatIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopFloatIntAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopFloatIntAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopFloatIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopFloatIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_float of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..af35589123794 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatIntGroupingAggregatorFunction.java @@ -0,0 +1,396 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0.
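Each non-grouping class above opens its addRawInput with the same three-way mask dispatch: skip the page entirely, take an unmasked fast path, or fall back to per-row checks. Reduced to its essentials below, with a plain boolean[] standing in for BooleanVector, whose allTrue()/allFalse() the real code gets without a scan:

```java
import java.util.function.IntConsumer;

// Sketch of the allFalse / allTrue / partial-mask dispatch, assuming the
// mask has one entry per row of the page.
final class MaskDispatchSketch {
    static void addRawInput(boolean[] mask, IntConsumer combineRow) {
        boolean allTrue = true;
        boolean allFalse = true;
        for (boolean m : mask) {
            allTrue &= m;
            allFalse &= m == false;
        }
        if (allFalse) {
            return; // entire page masked away: no work at all
        }
        if (allTrue) {
            for (int row = 0; row < mask.length; row++) {
                combineRow.accept(row); // fast path, no per-row mask check
            }
            return;
        }
        for (int row = 0; row < mask.length; row++) {
            if (mask[row]) {
                combineRow.accept(row); // slow path, per-row check
            }
        }
    }
}
```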
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopFloatIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopFloatIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final TopFloatIntAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopFloatIntGroupingAggregatorFunction(List<Integer> channels, + TopFloatIntAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatIntGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopFloatIntGroupingAggregatorFunction(channels, TopFloatIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock
groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopFloatIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopFloatIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) 
{ + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopFloatIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopFloatIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatIntAggregator.combineIntermediate(state, groupId, top, output, 
valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopFloatIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + float vValue = vVector.getFloat(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopFloatIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopFloatIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, FloatBlock vBlock, + IntBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopFloatIntAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + 
public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongAggregatorFunction.java new file mode 100644 index 0000000000000..39006ef2ad955 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopFloatLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopFloatLongAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final TopFloatLongAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopFloatLongAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopFloatLongAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatLongAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopFloatLongAggregatorFunction(driverContext, channels, TopFloatLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + }
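The asVector() null checks just above implement the other recurring dispatch: a Vector is the dense special case of a Block with exactly one non-null value per position, so when both columns convert, all null and multi-value bookkeeping can be skipped. A sketch of the same split over a hypothetical column interface, not the real Block API:

```java
// Hypothetical column type illustrating why the generated code prefers the
// vector path: isDense() stands in for asVector() != null.
final class VectorFastPathSketch {
    interface FloatColumn {
        int positionCount();
        boolean isDense();                 // one non-null value per position?
        boolean isNull(int position);
        int valueCount(int position);
        float value(int position, int offset);
    }

    static double sum(FloatColumn column) {
        double total = 0;
        if (column.isDense()) {
            // fast path: no null checks, no inner value loop
            for (int p = 0; p < column.positionCount(); p++) {
                total += column.value(p, 0);
            }
            return total;
        }
        // slow path: honor nulls and multi-valued positions
        for (int p = 0; p < column.positionCount(); p++) {
            if (column.isNull(p)) {
                continue;
            }
            for (int i = 0; i < column.valueCount(p); i++) {
                total += column.value(p, i);
            }
        }
        return total;
    }
}
```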
+ addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(FloatVector vVector, LongVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + float vValue = vVector.getFloat(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopFloatLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(FloatVector vVector, LongVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + float vValue = vVector.getFloat(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopFloatLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(FloatBlock vBlock, LongBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopFloatLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(FloatBlock vBlock, LongBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopFloatLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if 
(topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert output.getPositionCount() == 1; + TopFloatLongAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopFloatLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..d8667fddcb2e9 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopFloatLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
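+ * <p>Configured once with {@code limit} and sort direction, then used to create
+ * {@link TopFloatLongAggregatorFunction} (non-grouping) and
+ * {@link TopFloatLongGroupingAggregatorFunction} (grouping) instances from a
+ * {@link DriverContext}. For an ES|QL call such as
+ * {@code TOP(aFloatField, 3, "desc", aLongField)} (field names illustrative), the
+ * float column supplies the sort keys and the long column the emitted values.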
+ */
+public final class TopFloatLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+  private final int limit;
+
+  private final boolean ascending;
+
+  public TopFloatLongAggregatorFunctionSupplier(int limit, boolean ascending) {
+    this.limit = limit;
+    this.ascending = ascending;
+  }
+
+  @Override
+  public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {
+    return TopFloatLongAggregatorFunction.intermediateStateDesc();
+  }
+
+  @Override
+  public List<IntermediateStateDesc> groupingIntermediateStateDesc() {
+    return TopFloatLongGroupingAggregatorFunction.intermediateStateDesc();
+  }
+
+  @Override
+  public TopFloatLongAggregatorFunction aggregator(DriverContext driverContext,
+      List<Integer> channels) {
+    return TopFloatLongAggregatorFunction.create(driverContext, channels, limit, ascending);
+  }
+
+  @Override
+  public TopFloatLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext,
+      List<Integer> channels) {
+    return TopFloatLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
+  }
+
+  @Override
+  public String describe() {
+    return "top_float of longs";
+  }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongGroupingAggregatorFunction.java
new file mode 100644
index 0000000000000..6c04a6df03f2e
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopFloatLongGroupingAggregatorFunction.java
@@ -0,0 +1,397 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.FloatVector;
+import org.elasticsearch.compute.data.IntArrayBlock;
+import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.LongVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link TopFloatLongAggregator}.
+ * This class is generated. Edit {@code GroupingAggregatorImplementer} instead.
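+ * <p>State is tracked per group: a bounded top-N of float sort keys, each paired
+ * with the long {@code outputField} value that the aggregation ultimately emits
+ * instead of the sort key itself.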
+ */ +public final class TopFloatLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.FLOAT), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final TopFloatLongAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopFloatLongGroupingAggregatorFunction(List channels, + TopFloatLongAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopFloatLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopFloatLongGroupingAggregatorFunction(channels, TopFloatLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + FloatBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + FloatVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int 
positionOffset, IntArrayBlock groups, FloatBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopFloatLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, FloatVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopFloatLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + 
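+        // A null outputField here leaves nothing to pair with the sort values; skip the position.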
continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopFloatLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, FloatVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + float vValue = vVector.getFloat(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopFloatLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopFloatLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + float vValue = vBlock.getFloat(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + 
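+          // Multi-valued positions contribute the full cross product of (field, outputField) pairs.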
for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopFloatLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, FloatVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + float vValue = vVector.getFloat(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopFloatLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + FloatBlock top = (FloatBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopFloatLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, FloatBlock vBlock, + LongBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopFloatLongAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..4f661b865cb9c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
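+// Same shape as TopFloatLongAggregatorFunction above, but sorting ints and emitting doubles.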
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopIntDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopIntDoubleAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final DriverContext driverContext; + + private final TopIntDoubleAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopIntDoubleAggregatorFunction(DriverContext driverContext, List channels, + TopIntDoubleAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntDoubleAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopIntDoubleAggregatorFunction(driverContext, channels, TopIntDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + IntBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(IntVector vVector, DoubleVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + int vValue = 
vVector.getInt(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopIntDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(IntVector vVector, DoubleVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + int vValue = vVector.getInt(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopIntDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(IntBlock vBlock, DoubleBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopIntDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(IntBlock vBlock, DoubleBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopIntDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert output.getPositionCount() == 1; + TopIntDoubleAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopIntDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + 
public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..038447c14d235 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopIntDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopIntDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopIntDoubleAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return TopIntDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return TopIntDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopIntDoubleAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return TopIntDoubleAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopIntDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return TopIntDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_int of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..d75f26dfb2338 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntDoubleGroupingAggregatorFunction.java @@ -0,0 +1,396 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
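+// Grouping (per-bucket) counterpart of TopIntDoubleAggregatorFunction above.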
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopIntDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopIntDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final TopIntDoubleAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopIntDoubleGroupingAggregatorFunction(List channels, + TopIntDoubleAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntDoubleGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopIntDoubleGroupingAggregatorFunction(channels, TopIntDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, 
IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopIntDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopIntDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < 
groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopIntDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopIntDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + 
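+          // Merge the serialized top-N (sort key, output) entries at this position into groupId's state.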
TopIntDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopIntDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + int vValue = vVector.getInt(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopIntDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopIntDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, IntBlock vBlock, + DoubleBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopIntDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + 
sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatAggregatorFunction.java new file mode 100644 index 0000000000000..8f39a47ed99a7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopIntFloatAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopIntFloatAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final DriverContext driverContext; + + private final TopIntFloatAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopIntFloatAggregatorFunction(DriverContext driverContext, List channels, + TopIntFloatAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntFloatAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopIntFloatAggregatorFunction(driverContext, channels, TopIntFloatAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + IntBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if 
(outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(IntVector vVector, FloatVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + int vValue = vVector.getInt(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopIntFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(IntVector vVector, FloatVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + int vValue = vVector.getInt(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopIntFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(IntBlock vBlock, FloatBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopIntFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(IntBlock vBlock, FloatBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopIntFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block 
topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert output.getPositionCount() == 1; + TopIntFloatAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopIntFloatAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..a71fc5830b8fc --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopIntFloatAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
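+ * <p>Same contract as the suppliers above, specialized for an int sort column
+ * paired with a float output column.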
+ */
+public final class TopIntFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+  private final int limit;
+
+  private final boolean ascending;
+
+  public TopIntFloatAggregatorFunctionSupplier(int limit, boolean ascending) {
+    this.limit = limit;
+    this.ascending = ascending;
+  }
+
+  @Override
+  public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {
+    return TopIntFloatAggregatorFunction.intermediateStateDesc();
+  }
+
+  @Override
+  public List<IntermediateStateDesc> groupingIntermediateStateDesc() {
+    return TopIntFloatGroupingAggregatorFunction.intermediateStateDesc();
+  }
+
+  @Override
+  public TopIntFloatAggregatorFunction aggregator(DriverContext driverContext,
+      List<Integer> channels) {
+    return TopIntFloatAggregatorFunction.create(driverContext, channels, limit, ascending);
+  }
+
+  @Override
+  public TopIntFloatGroupingAggregatorFunction groupingAggregator(DriverContext driverContext,
+      List<Integer> channels) {
+    return TopIntFloatGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
+  }
+
+  @Override
+  public String describe() {
+    return "top_int of floats";
+  }
+}
diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatGroupingAggregatorFunction.java
new file mode 100644
index 0000000000000..b9b7abd9ff435
--- /dev/null
+++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntFloatGroupingAggregatorFunction.java
@@ -0,0 +1,396 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.FloatVector;
+import org.elasticsearch.compute.data.IntArrayBlock;
+import org.elasticsearch.compute.data.IntBigArrayBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link TopIntFloatAggregator}.
+ * This class is generated. Edit {@code GroupingAggregatorImplementer} instead.
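+ * <p>Per-group variant: each bucket keeps its own bounded top-N of int sort keys,
+ * each carrying the paired float output value that TOP ultimately emits.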
+ */ +public final class TopIntFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final TopIntFloatAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopIntFloatGroupingAggregatorFunction(List channels, + TopIntFloatAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntFloatGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopIntFloatGroupingAggregatorFunction(channels, TopIntFloatAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int 
positionOffset, IntArrayBlock groups, IntBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopIntFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopIntFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + 
} + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopIntFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopIntFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = 
outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopIntFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + int vValue = vVector.getInt(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopIntFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopIntFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, IntBlock vBlock, + FloatBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopIntFloatAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntAggregatorFunction.java new file mode 100644 index 0000000000000..8c9f4ed629f0f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntAggregatorFunction.java @@ -0,0 +1,217 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
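/*
 * How a TOP(field, limit, order, outputField) state behaves, as a minimal
 * hand-written sketch. The generated functions in this diff funnel
 * (value, outputValue) pairs into TopXxx.combine(...), which conceptually
 * keeps the `limit` best values by sort order and carries each value's
 * outputValue along with it. Everything named below (TopKSketch and its
 * members) is an illustrative assumption, not the real TopIntIntAggregator
 * state:
 *
 *   import java.util.Comparator;
 *   import java.util.PriorityQueue;
 *
 *   final class TopKSketch {
 *       private final int limit;
 *       // head of the queue is the entry evicted first once we exceed `limit`
 *       private final PriorityQueue<int[]> queue;
 *
 *       TopKSketch(int limit, boolean ascending) {
 *           this.limit = limit;
 *           Comparator<int[]> bySortValue = Comparator.comparingInt(e -> e[0]);
 *           // ascending keeps the smallest values, so evict from the large end
 *           this.queue = new PriorityQueue<>(ascending ? bySortValue.reversed() : bySortValue);
 *       }
 *
 *       void combine(int value, int outputValue) {
 *           queue.add(new int[] { value, outputValue }); // pair sort key with payload
 *           if (queue.size() > limit) {
 *               queue.poll();
 *           }
 *       }
 *   }
 */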
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopIntIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopIntIntAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final DriverContext driverContext; + + private final TopIntIntAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopIntIntAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopIntIntAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntIntAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopIntIntAggregatorFunction(driverContext, channels, TopIntIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + IntBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(IntVector vVector, IntVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + int vValue = vVector.getInt(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopIntIntAggregator.combine(state, vValue, outputValueValue); + } + } +
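  // The masked overload below mirrors addRawVector above. addRawInput first
  // checks mask.allFalse()/mask.allTrue() so that fully unmasked pages take
  // the vector fast path with no per-position mask reads, while the
  // addRawBlock overloads cover pages where either input still has nulls or
  // multivalued positions; only positions whose mask bit is true reach
  // TopIntIntAggregator.combine.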
+ private void addRawVector(IntVector vVector, IntVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + int vValue = vVector.getInt(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopIntIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(IntBlock vBlock, IntBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopIntIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(IntBlock vBlock, IntBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopIntIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert output.getPositionCount() == 1; + TopIntIntAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopIntIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + 
public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..98483a1480f3a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopIntIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopIntIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopIntIntAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopIntIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopIntIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopIntIntAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopIntIntAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopIntIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopIntIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_int of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..c4966d76012e9 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntIntGroupingAggregatorFunction.java @@ -0,0 +1,394 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0.
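/*
 * The grouping variant below keeps one top-K state per group id and routes
 * each (groupId, value, outputValue) triple to that group's state. A rough
 * standalone sketch of that shape, reusing the hypothetical TopKSketch from
 * the earlier sketch (names are illustrative; the generated code actually
 * uses a BigArrays-backed TopIntIntAggregator.GroupingState and receives
 * group ids as IntVector/IntBlock pages, not a hash map):
 *
 *   import java.util.HashMap;
 *   import java.util.Map;
 *
 *   final class GroupingTopKSketch {
 *       private final int limit;
 *       private final boolean ascending;
 *       private final Map<Integer, TopKSketch> states = new HashMap<>();
 *
 *       GroupingTopKSketch(int limit, boolean ascending) {
 *           this.limit = limit;
 *           this.ascending = ascending;
 *       }
 *
 *       void combine(int groupId, int value, int outputValue) {
 *           // lazily create the per-group state, then delegate to the plain top-K
 *           states.computeIfAbsent(groupId, g -> new TopKSketch(limit, ascending))
 *               .combine(value, outputValue);
 *       }
 *   }
 */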
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopIntIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopIntIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final TopIntIntAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopIntIntGroupingAggregatorFunction(List<Integer> channels, + TopIntIntAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntIntGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopIntIntGroupingAggregatorFunction(channels, TopIntIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int
positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopIntIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopIntIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = 
groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopIntIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopIntIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock vBlock, + IntBlock outputValueBlock) { + for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopIntIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + int vValue = vVector.getInt(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopIntIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopIntIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, IntBlock vBlock, + IntBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopIntIntAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongAggregatorFunction.java 
b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongAggregatorFunction.java new file mode 100644 index 0000000000000..9b44f260e3517 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopIntLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopIntLongAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final TopIntLongAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopIntLongAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopIntLongAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntLongAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopIntLongAggregatorFunction(driverContext, channels, TopIntLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + IntBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock =
page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(IntVector vVector, LongVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + int vValue = vVector.getInt(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopIntLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(IntVector vVector, LongVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + int vValue = vVector.getInt(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopIntLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(IntBlock vBlock, LongBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopIntLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(IntBlock vBlock, LongBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopIntLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock 
output = (LongBlock) outputUncast; + assert output.getPositionCount() == 1; + TopIntLongAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopIntLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..720cf8fd05269 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopIntLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopIntLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopIntLongAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopIntLongAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopIntLongGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopIntLongAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopIntLongAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopIntLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopIntLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_int of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..113d80aa98bdb --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopIntLongGroupingAggregatorFunction.java @@ -0,0 +1,396 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V.
under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopIntLongAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopIntLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.INT), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final TopIntLongAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopIntLongGroupingAggregatorFunction(List<Integer> channels, + TopIntLongAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopIntLongGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopIntLongGroupingAggregatorFunction(channels, TopIntLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + IntBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + IntVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset,
IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopIntLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, IntVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopIntLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) 
outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopIntLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, IntVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vValue = vVector.getInt(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopIntLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId 
= groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopIntLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + int vValue = vBlock.getInt(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopIntLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, IntVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + int vValue = vVector.getInt(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopIntLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + IntBlock top = (IntBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopIntLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, IntBlock vBlock, + LongBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopIntLongAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + 
sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleAggregatorFunction.java new file mode 100644 index 0000000000000..3ce085f84251d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleAggregatorFunction.java @@ -0,0 +1,220 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopLongDoubleAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopLongDoubleAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final DriverContext driverContext; + + private final TopLongDoubleAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopLongDoubleAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopLongDoubleAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongDoubleAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopLongDoubleAggregatorFunction(driverContext, channels, TopLongDoubleAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + LongBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); +
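/* Editor's note: Block.asVector() returns null when the block holds nulls or multi-valued positions, so this branch (and the matching one for the outputField block just below) falls back to the per-position addRawBlock(...) path; only when both channels are dense vectors does the faster addRawVector(...) path run. */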
return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(LongVector vVector, DoubleVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + long vValue = vVector.getLong(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopLongDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(LongVector vVector, DoubleVector outputValueVector, + BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + long vValue = vVector.getLong(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopLongDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(LongBlock vBlock, DoubleBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopLongDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(LongBlock vBlock, DoubleBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopLongDoubleAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() 
== intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert output.getPositionCount() == 1; + TopLongDoubleAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopLongDoubleAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..d57f610de8a50 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopLongDoubleAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
+ */ +public final class TopLongDoubleAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopLongDoubleAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopLongDoubleAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopLongDoubleGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopLongDoubleAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopLongDoubleAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopLongDoubleGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopLongDoubleGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_long of doubles"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..ee6db0949b81e --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongDoubleGroupingAggregatorFunction.java @@ -0,0 +1,397 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopLongDoubleAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. 
+ */ +public final class TopLongDoubleGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.DOUBLE) ); + + private final TopLongDoubleAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopLongDoubleGroupingAggregatorFunction(List<Integer> channels, + TopLongDoubleAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongDoubleGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopLongDoubleGroupingAggregatorFunction(channels, TopLongDoubleAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + DoubleBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + DoubleVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void 
addRawInput(int positionOffset, IntArrayBlock groups, LongBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopLongDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopLongDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if 
(outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopLongDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopLongDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock vBlock, + DoubleBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart 
+ outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + double outputValueValue = outputValueBlock.getDouble(outputValueOffset); + TopLongDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector vVector, + DoubleVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long vValue = vVector.getLong(valuesPosition); + double outputValueValue = outputValueVector.getDouble(valuesPosition); + TopLongDoubleAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + DoubleBlock output = (DoubleBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopLongDoubleAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock vBlock, + DoubleBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopLongDoubleAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatAggregatorFunction.java new file mode 100644 index 0000000000000..0de6c65d1e5e5 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
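Editor's note: the hand-written TopLongDouble/TopLongFloat/TopLongInt aggregators that these generated classes delegate to are not part of this hunk, so combine(state, vValue, outputValueValue) is opaque here. As a rough mental model only — the names and the on-heap PriorityQueue below are invented for illustration; the real state is kept off-heap via BigArrays — each state retains the `limit` best sort keys and carries the paired outputField value alongside each key, which is what TOP(field, limit, order, outputField) ultimately returns:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical sketch, assuming limit >= 1; not the PR's implementation.
final class BoundedTopSketch {
    record Entry(long sortKey, float output) {}

    private final int limit;
    private final boolean ascending;
    private final PriorityQueue<Entry> heap; // keeps the "worst" retained entry at the head

    BoundedTopSketch(int limit, boolean ascending) {
        this.limit = limit;
        this.ascending = ascending;
        Comparator<Entry> byKey = Comparator.comparingLong(Entry::sortKey);
        // For ascending (smallest N) the worst kept entry is the largest key, and vice versa.
        this.heap = new PriorityQueue<>(ascending ? byKey.reversed() : byKey);
    }

    // Mirrors the generated combine(state, vValue, outputValueValue) calls: the sort
    // key decides membership, but the paired output value is what survives.
    void combine(long sortKey, float output) {
        if (heap.size() < limit) {
            heap.add(new Entry(sortKey, output));
            return;
        }
        Entry worst = heap.peek();
        boolean better = ascending ? sortKey < worst.sortKey() : sortKey > worst.sortKey();
        if (better) {
            heap.poll();
            heap.add(new Entry(sortKey, output));
        }
    }

    // Stands in for evaluateFinal: emit the surviving output values in key order.
    List<Float> finish() {
        List<Entry> entries = new ArrayList<>(heap);
        Comparator<Entry> byKey = Comparator.comparingLong(Entry::sortKey);
        entries.sort(ascending ? byKey : byKey.reversed());
        return entries.stream().map(Entry::output).toList();
    }
}

The generated files that follow repeat one pattern per (field type, outputField type) pair; only the block/vector types and the delegated aggregator change.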
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopLongFloatAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopLongFloatAggregatorFunction implements AggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final DriverContext driverContext; + + private final TopLongFloatAggregator.SingleState state; + + private final List<Integer> channels; + + private final int limit; + + private final boolean ascending; + + public TopLongFloatAggregatorFunction(DriverContext driverContext, List<Integer> channels, + TopLongFloatAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongFloatAggregatorFunction create(DriverContext driverContext, + List<Integer> channels, int limit, boolean ascending) { + return new TopLongFloatAggregatorFunction(driverContext, channels, TopLongFloatAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + LongBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(LongVector vVector, FloatVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + long vValue = 
vVector.getLong(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopLongFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(LongVector vVector, FloatVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + long vValue = vVector.getLong(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopLongFloatAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(LongBlock vBlock, FloatBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopLongFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(LongBlock vBlock, FloatBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopLongFloatAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert output.getPositionCount() == 1; + TopLongFloatAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopLongFloatAggregator.evaluateFinal(state, driverContext); + } + + @Override + 
public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..57104208a686a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopLongFloatAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopLongFloatAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopLongFloatAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopLongFloatAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopLongFloatGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopLongFloatAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopLongFloatAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopLongFloatGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopLongFloatGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_long of floats"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..b91efe7571529 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongFloatGroupingAggregatorFunction.java @@ -0,0 +1,397 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
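Editor's note: a supplier like the one just above is all the rest of the compute engine needs to instantiate either flavor of the function. A hypothetical wiring sketch — the channel numbers, method name, and the DriverContext parameter are illustrative placeholders, not values from this PR:

import java.util.List;
import org.elasticsearch.compute.aggregation.AggregatorFunction;
import org.elasticsearch.compute.aggregation.TopLongFloatAggregatorFunctionSupplier;
import org.elasticsearch.compute.operator.DriverContext;

final class TopWiringSketch {
    // driverContext is handed to us by the compute engine at runtime.
    static AggregatorFunction topThreeDescending(DriverContext driverContext) {
        // limit = 3, ascending = false: keep the three largest sort keys.
        var supplier = new TopLongFloatAggregatorFunctionSupplier(3, false);
        // Channel 0 carries the sort field, channel 1 the outputField; these
        // indices feed channels.get(0) / channels.get(1) in the generated code.
        return supplier.aggregator(driverContext, List.of(0, 1));
    }
}

groupingAggregator(...) is wired the same way when the TOP call sits under a STATS ... BY grouping.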
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.FloatVector; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopLongFloatAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopLongFloatGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.FLOAT) ); + + private final TopLongFloatAggregator.GroupingState state; + + private final List<Integer> channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopLongFloatGroupingAggregatorFunction(List<Integer> channels, + TopLongFloatAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongFloatGroupingAggregatorFunction create(List<Integer> channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopLongFloatGroupingAggregatorFunction(channels, TopLongFloatAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List<IntermediateStateDesc> intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + FloatBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + FloatVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + 
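/* Editor's note: AddInput is specialized for every group-id representation the engine can hand back — IntVector for the dense single-valued case, IntArrayBlock and IntBigArrayBlock for group ids that may be null or multi-valued per position — so dispatch happens once per page rather than once per row. */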
@Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopLongFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopLongFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int 
groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopLongFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopLongFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + 
positionOffset; + TopLongFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock vBlock, + FloatBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + float outputValueValue = outputValueBlock.getFloat(outputValueOffset); + TopLongFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector vVector, + FloatVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long vValue = vVector.getLong(valuesPosition); + float outputValueValue = outputValueVector.getFloat(valuesPosition); + TopLongFloatAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + FloatBlock output = (FloatBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopLongFloatAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock vBlock, + FloatBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopLongFloatAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + 
sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntAggregatorFunction.java new file mode 100644 index 0000000000000..9186916c34e8d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntAggregatorFunction.java @@ -0,0 +1,219 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopLongIntAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopLongIntAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final DriverContext driverContext; + + private final TopLongIntAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopLongIntAggregatorFunction(DriverContext driverContext, List channels, + TopLongIntAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongIntAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopLongIntAggregatorFunction(driverContext, channels, TopLongIntAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + LongBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector 
== null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(LongVector vVector, IntVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + long vValue = vVector.getLong(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopLongIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawVector(LongVector vVector, IntVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + long vValue = vVector.getLong(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopLongIntAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(LongBlock vBlock, IntBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopLongIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(LongBlock vBlock, IntBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopLongIntAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = 
page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert output.getPositionCount() == 1; + TopLongIntAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopLongIntAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..a4d67d1b63a73 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopLongIntAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. 
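+ * The supplier captures {@code limit} and {@code ascending} once and hands out matching non-grouping and grouping implementations over the same input channels.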
+ */ +public final class TopLongIntAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopLongIntAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() { + return TopLongIntAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List<IntermediateStateDesc> groupingIntermediateStateDesc() { + return TopLongIntGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopLongIntAggregatorFunction aggregator(DriverContext driverContext, + List<Integer> channels) { + return TopLongIntAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopLongIntGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List<Integer> channels) { + return TopLongIntGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_long of ints"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..3a78089397a4c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongIntGroupingAggregatorFunction.java @@ -0,0 +1,396 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopLongIntAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead.
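+ * Unlike the plain variant, this implementation keeps one bounded top-N state per group id and routes each (value, outputValue) pair to its group's state.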
+ */ +public final class TopLongIntGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.INT) ); + + private final TopLongIntAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopLongIntGroupingAggregatorFunction(List channels, + TopLongIntAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongIntGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopLongIntGroupingAggregatorFunction(channels, TopLongIntAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + IntBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + IntVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, 
IntArrayBlock groups, LongBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopLongIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopLongIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = 
groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopLongIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopLongIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock vBlock, + IntBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; 
outputValueOffset < outputValueEnd; outputValueOffset++) { + int outputValueValue = outputValueBlock.getInt(outputValueOffset); + TopLongIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector vVector, + IntVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long vValue = vVector.getLong(valuesPosition); + int outputValueValue = outputValueVector.getInt(valuesPosition); + TopLongIntAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + IntBlock output = (IntBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopLongIntAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock vBlock, + IntBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopLongIntAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongAggregatorFunction.java new file mode 100644 index 0000000000000..fe9df560a0516 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongAggregatorFunction.java @@ -0,0 +1,217 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
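+// Long sort keys paired with long output values; generated from TopLongLongAggregator, which X-TopAggregator.java.st produces.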
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopLongLongAggregator}. + * This class is generated. Edit {@code AggregatorImplementer} instead. + */ +public final class TopLongLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final DriverContext driverContext; + + private final TopLongLongAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopLongLongAggregatorFunction(DriverContext driverContext, List channels, + TopLongLongAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongLongAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopLongLongAggregatorFunction(driverContext, channels, TopLongLongAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + } else if (mask.allTrue()) { + addRawInputNotMasked(page); + } else { + addRawInputMasked(page, mask); + } + } + + private void addRawInputMasked(Page page, BooleanVector mask) { + LongBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock, mask); + return; + } + addRawVector(vVector, outputValueVector, mask); + } + + private void addRawInputNotMasked(Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + addRawBlock(vBlock, outputValueBlock); + return; + } + addRawVector(vVector, outputValueVector); + } + + private void addRawVector(LongVector vVector, LongVector outputValueVector) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + long vValue = vVector.getLong(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopLongLongAggregator.combine(state, 
vValue, outputValueValue); + } + } + + private void addRawVector(LongVector vVector, LongVector outputValueVector, BooleanVector mask) { + for (int valuesPosition = 0; valuesPosition < vVector.getPositionCount(); valuesPosition++) { + if (mask.getBoolean(valuesPosition) == false) { + continue; + } + long vValue = vVector.getLong(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopLongLongAggregator.combine(state, vValue, outputValueValue); + } + } + + private void addRawBlock(LongBlock vBlock, LongBlock outputValueBlock) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopLongLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + private void addRawBlock(LongBlock vBlock, LongBlock outputValueBlock, BooleanVector mask) { + for (int p = 0; p < vBlock.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + int vValueCount = vBlock.getValueCount(p); + if (vValueCount == 0) { + continue; + } + int outputValueValueCount = outputValueBlock.getValueCount(p); + if (outputValueValueCount == 0) { + continue; + } + int vStart = vBlock.getFirstValueIndex(p); + int vEnd = vStart + vValueCount; + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(p); + int outputValueEnd = outputValueStart + outputValueValueCount; + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopLongLongAggregator.combine(state, vValue, outputValueValue); + } + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + assert top.getPositionCount() == 1; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert output.getPositionCount() == 1; + TopLongLongAggregator.combineIntermediate(state, top, output); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopLongLongAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + 
sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..ad17808763bed --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongAggregatorFunctionSupplier.java @@ -0,0 +1,53 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopLongLongAggregator}. + * This class is generated. Edit {@code AggregatorFunctionSupplierImplementer} instead. + */ +public final class TopLongLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final int limit; + + private final boolean ascending; + + public TopLongLongAggregatorFunctionSupplier(int limit, boolean ascending) { + this.limit = limit; + this.ascending = ascending; + } + + @Override + public List nonGroupingIntermediateStateDesc() { + return TopLongLongAggregatorFunction.intermediateStateDesc(); + } + + @Override + public List groupingIntermediateStateDesc() { + return TopLongLongGroupingAggregatorFunction.intermediateStateDesc(); + } + + @Override + public TopLongLongAggregatorFunction aggregator(DriverContext driverContext, + List channels) { + return TopLongLongAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopLongLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext, + List channels) { + return TopLongLongGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top_long of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..4d8f016593497 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopLongLongGroupingAggregatorFunction.java @@ -0,0 +1,395 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntArrayBlock; +import org.elasticsearch.compute.data.IntBigArrayBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopLongLongAggregator}. + * This class is generated. Edit {@code GroupingAggregatorImplementer} instead. + */ +public final class TopLongLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.LONG), + new IntermediateStateDesc("output", ElementType.LONG) ); + + private final TopLongLongAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopLongLongGroupingAggregatorFunction(List channels, + TopLongLongAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopLongLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopLongLongGroupingAggregatorFunction(channels, TopLongLongAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock vBlock = page.getBlock(channels.get(0)); + LongBlock outputValueBlock = page.getBlock(channels.get(1)); + LongVector vVector = vBlock.asVector(); + if (vVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + LongVector outputValueVector = outputValueBlock.asVector(); + if (outputValueVector == null) { + maybeEnableGroupIdTracking(seenGroupIds, vBlock, outputValueBlock); + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, 
vBlock, outputValueBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vBlock, outputValueBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntBigArrayBlock groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, vVector, outputValueVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopLongLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntArrayBlock groups, LongVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopLongLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; 
+ } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopLongLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + } + + private void addRawInput(int positionOffset, IntBigArrayBlock groups, LongVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int valuesPosition = groupPosition + positionOffset; + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + long vValue = vVector.getLong(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopLongLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + int valuesPosition = groupPosition + positionOffset; + TopLongLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + } + + private void 
addRawInput(int positionOffset, IntVector groups, LongBlock vBlock, + LongBlock outputValueBlock) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + if (vBlock.isNull(valuesPosition)) { + continue; + } + if (outputValueBlock.isNull(valuesPosition)) { + continue; + } + int groupId = groups.getInt(groupPosition); + int vStart = vBlock.getFirstValueIndex(valuesPosition); + int vEnd = vStart + vBlock.getValueCount(valuesPosition); + for (int vOffset = vStart; vOffset < vEnd; vOffset++) { + long vValue = vBlock.getLong(vOffset); + int outputValueStart = outputValueBlock.getFirstValueIndex(valuesPosition); + int outputValueEnd = outputValueStart + outputValueBlock.getValueCount(valuesPosition); + for (int outputValueOffset = outputValueStart; outputValueOffset < outputValueEnd; outputValueOffset++) { + long outputValueValue = outputValueBlock.getLong(outputValueOffset); + TopLongLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector vVector, + LongVector outputValueVector) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int valuesPosition = groupPosition + positionOffset; + int groupId = groups.getInt(groupPosition); + long vValue = vVector.getLong(valuesPosition); + long outputValueValue = outputValueVector.getLong(valuesPosition); + TopLongLongAggregator.combine(state, groupId, vValue, outputValueValue); + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + LongBlock top = (LongBlock) topUncast; + Block outputUncast = page.getBlock(channels.get(1)); + if (outputUncast.areAllValuesNull()) { + return; + } + LongBlock output = (LongBlock) outputUncast; + assert top.getPositionCount() == output.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + int valuesPosition = groupPosition + positionOffset; + TopLongLongAggregator.combineIntermediate(state, groupId, top, output, valuesPosition); + } + } + + private void maybeEnableGroupIdTracking(SeenGroupIds seenGroupIds, LongBlock vBlock, + LongBlock outputValueBlock) { + if (vBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + if (outputValueBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + GroupingAggregatorEvaluationContext ctx) { + blocks[offset] = TopLongLongAggregator.evaluateFinal(state, selected, ctx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} 
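The generated TopLongInt and TopLongLong classes above (plain and grouping, with their suppliers) all reduce to one primitive: a bounded heap of sort keys plus a parallel array of output values that follows the keys through every swap, so the final block can report the top rows in terms of the output field instead of the key. Below is a minimal, self-contained sketch of that mechanism for a single bucket, using plain long[] storage in place of the BigArrays-backed state; the class and method names are illustrative only, not taken from this change.

import java.util.Arrays;

/**
 * Illustrative single-bucket version of the key-plus-extra-value top-N kept by
 * the generated aggregators: keys holds the sort field, extras holds the
 * output field, and the two arrays move in lockstep on every heap swap.
 */
final class TopNWithExtraSketch {
    private final long[] keys;   // the TOP `field` values
    private final long[] extras; // the TOP `outputField` values, paired by index
    private final boolean ascending;
    private int size;

    TopNWithExtraSketch(int limit, boolean ascending) {
        this.keys = new long[limit];
        this.extras = new long[limit];
        this.ascending = ascending;
    }

    /** Collect one (key, extra) pair, keeping only the best `limit` keys. */
    void collect(long key, long extra) {
        if (size < keys.length) {
            // Gathering mode: append unsorted, heapify once the bucket is full.
            keys[size] = key;
            extras[size] = extra;
            if (++size == keys.length) {
                for (int parent = size / 2 - 1; parent >= 0; parent--) {
                    downHeap(parent);
                }
            }
            return;
        }
        // Heap mode: the root holds the worst value kept so far.
        if (betterThan(key, keys[0])) {
            keys[0] = key;
            extras[0] = extra;
            downHeap(0);
        }
    }

    /** The output-field values of the top keys, best key first. */
    long[] topExtras() {
        Integer[] order = new Integer[size];
        for (int i = 0; i < size; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> ascending ? Long.compare(keys[a], keys[b]) : Long.compare(keys[b], keys[a]));
        long[] result = new long[size];
        for (int i = 0; i < size; i++) {
            result[i] = extras[order[i]];
        }
        return result;
    }

    /** "Better" means lower for ascending order, higher for descending. */
    private boolean betterThan(long lhs, long rhs) {
        return ascending ? lhs < rhs : lhs > rhs;
    }

    /** Sift a parent down, pulling worse children up so the root stays worst. */
    private void downHeap(int parent) {
        while (true) {
            int worst = parent;
            int left = 2 * parent + 1;
            int right = left + 1;
            if (left < size && betterThan(keys[worst], keys[left])) {
                worst = left;
            }
            if (right < size && betterThan(keys[worst], keys[right])) {
                worst = right;
            }
            if (worst == parent) {
                return;
            }
            swap(parent, worst);
            parent = worst;
        }
    }

    private void swap(int a, int b) {
        long k = keys[a]; keys[a] = keys[b]; keys[b] = k;
        long e = extras[a]; extras[a] = extras[b]; extras[b] = e; // keep pairs aligned
    }
}

Once the heap is full, each further collect() costs O(log limit), and nothing is sorted until the results are drained; this mirrors how the generated addRaw* paths stream whole blocks through combine() and leave ordering to the evaluate phase.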
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st index f2ad759a6d57f..da5499eeed082 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st @@ -17,9 +17,11 @@ import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.$Type$Block; +$if(hasOutputField)$ +import org.elasticsearch.compute.data.$OutputFieldType$Block; +$endif$ import org.elasticsearch.compute.data.IntVector; -import org.elasticsearch.compute.data.$Type$Block; -import org.elasticsearch.compute.data.sort.$Name$BucketedSort; +import org.elasticsearch.compute.data.sort.$Name$$OutputFieldName$BucketedSort; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -31,28 +33,32 @@ import org.elasticsearch.search.sort.SortOrder; * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. *

*/ +$if(hasOutputField)$ +@Aggregator({ @IntermediateState(name = "top", type = "$TYPE$_BLOCK"), @IntermediateState(name = "output", type = "$OutputFieldTYPE$_BLOCK") }) +$else$ @Aggregator({ @IntermediateState(name = "top", type = "$TYPE$_BLOCK") }) +$endif$ @GroupingAggregator -class Top$Name$Aggregator { +class Top$Name$$OutputFieldName$Aggregator { public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { return new SingleState(bigArrays, limit, ascending); } - public static void combine(SingleState state, $type$ v) { - state.add(v); + public static void combine(SingleState state, $type$ v$if(hasOutputField)$, $OutputFieldtype$ outputValue$endif$) { + state.add(v$if(hasOutputField)$, outputValue$endif$); } - public static void combineIntermediate(SingleState state, $Type$Block values) { + public static void combineIntermediate(SingleState state, $Type$Block values$if(hasOutputField)$, $OutputFieldType$Block outputValues$endif$) { int start = values.getFirstValueIndex(0); int end = start + values.getValueCount(0); $if(BytesRef || Ip)$ var scratch = new BytesRef(); for (int i = start; i < end; i++) { - combine(state, values.get$Type$(i, scratch)); + combine(state, values.get$Type$(i, scratch)$if(hasOutputField)$, outputValues.get$OutputFieldType$(i)$endif$); } $else$ for (int i = start; i < end; i++) { - combine(state, values.get$Type$(i)); + combine(state, values.get$Type$(i)$if(hasOutputField)$, outputValues.get$OutputFieldType$(i)$endif$); } $endif$ } @@ -65,21 +71,21 @@ $endif$ return new GroupingState(bigArrays, limit, ascending); } - public static void combine(GroupingState state, int groupId, $type$ v) { - state.add(groupId, v); + public static void combine(GroupingState state, int groupId, $type$ v$if(hasOutputField)$, $OutputFieldtype$ outputValue$endif$) { + state.add(groupId, v$if(hasOutputField)$, outputValue$endif$); } - public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { + public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, $if(hasOutputField)$$OutputFieldType$Block outputValues, $endif$int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); $if(BytesRef || Ip)$ var scratch = new BytesRef(); for (int i = start; i < end; i++) { - combine(state, groupId, values.get$Type$(i, scratch)); + combine(state, groupId, values.get$Type$(i, scratch)$if(hasOutputField)$, outputValues.get$OutputFieldType$(i)$endif$); } $else$ for (int i = start; i < end; i++) { - combine(state, groupId, values.get$Type$(i)); + combine(state, groupId, values.get$Type$(i)$if(hasOutputField)$, outputValues.get$OutputFieldType$(i)$endif$); } $endif$ } @@ -89,7 +95,7 @@ $endif$ } public static class GroupingState implements GroupingAggregatorState { - private final $Name$BucketedSort sort; + private final $Name$$OutputFieldName$BucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { $if(BytesRef)$ @@ -97,22 +103,36 @@ $if(BytesRef)$ CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); this.sort = new BytesRefBucketedSort(breaker, "top", bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit); $else$ - this.sort = new $Name$BucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit); + this.sort = new $Name$$OutputFieldName$BucketedSort(bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); $endif$ } - public void add(int groupId, $type$ value) { - sort.collect(value, groupId); + public void add(int groupId, $type$ value$if(hasOutputField)$, $OutputFieldtype$ outputValue$endif$) { + sort.collect(value, $if(hasOutputField)$outputValue, $endif$groupId); } @Override public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { +$if(hasOutputField)$ + sort.toBlocks(driverContext.blockFactory(), blocks, offset, selected); +$else$ blocks[offset] = toBlock(driverContext.blockFactory(), selected); +$endif$ } +$if(hasOutputField)$ + Block toBlock(BlockFactory blockFactory, IntVector selected) { + Block[] blocks = new Block[2]; + sort.toBlocks(blockFactory, blocks, 0, selected); + Releasables.close(blocks[0]); + return blocks[1]; + } + +$else$ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } +$endif$ @Override public void enableGroupIdTracking(SeenGroupIds seen) { @@ -132,13 +152,15 @@ $endif$ this.internalState = new GroupingState(bigArrays, limit, ascending); } - public void add($type$ value) { - internalState.add(0, value); + public void add($type$ value$if(hasOutputField)$, $OutputFieldtype$ outputValue$endif$) { + internalState.add(0, value$if(hasOutputField)$, outputValue$endif$); } @Override public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { - blocks[offset] = toBlock(driverContext.blockFactory()); + try (var intValues = driverContext.blockFactory().newConstantIntVector(0, 1)) { + internalState.toIntermediate(blocks, offset, intValues, driverContext); + } } Block toBlock(BlockFactory blockFactory) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st index 39eead2ed3044..70d7c5fb3f5e7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/X-BucketedSort.java.st @@ -9,6 +9,9 @@ package org.elasticsearch.compute.data.sort; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BitArray; +$if(hasExtra && !(long && Extralong))$ +import org.elasticsearch.common.util.$ExtraType$Array; +$endif$ import org.elasticsearch.common.util.$Type$Array; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.compute.data.Block; @@ -23,11 +26,11 @@ import org.elasticsearch.search.sort.SortOrder; import java.util.stream.IntStream; /** - * Aggregates the top N $type$ values per bucket. + * Aggregates the top N {@code $type$} values per bucket. * See {@link BucketedSort} for more information. * This class is generated. Edit @{code X-BucketedSort.java.st} instead of this file. 
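* When {@code hasExtra} is set, each slot also stores an extra value that moves with its key on every swap and breaks ties in {@code betterThan}.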
*/ -public class $Type$BucketedSort implements Releasable { +public class $Name$$ExtraName$BucketedSort implements Releasable { private final BigArrays bigArrays; private final SortOrder order; @@ -64,9 +67,15 @@ public class $Type$BucketedSort implements Releasable { * * */ - private $Type$Array values; +$if(hasExtra)$ + private $Array$ values; + private $ExtraArray$ extraValues; + +$else$ + private $Array$ values; +$endif$ - public $Type$BucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { + public $Name$$ExtraName$BucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { this.bigArrays = bigArrays; this.order = order; this.bucketSize = bucketSize; @@ -75,6 +84,9 @@ public class $Type$BucketedSort implements Releasable { boolean success = false; try { values = bigArrays.new$Type$Array(0, false); +$if(hasExtra)$ + extraValues = bigArrays.new$ExtraType$Array(0, false); +$endif$ success = true; } finally { if (success == false) { @@ -89,11 +101,14 @@ public class $Type$BucketedSort implements Releasable { * It may or may not be inserted in the heap, depending on if it is better than the current root. *
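* In gathering mode the pair is written at the bucket's next gather offset; once the bucket fills, it switches to heap mode and later values only ever displace the root.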

*/ - public void collect($type$ value, int bucket) { + public void collect($type$ value, $if(hasExtra)$$Extratype$ extraValue, $endif$int bucket) { long rootIndex = (long) bucket * bucketSize; if (inHeapMode(bucket)) { - if (betterThan(value, values.get(rootIndex))) { + if (betterThan(value, values.get(rootIndex)$if(hasExtra)$, extraValue, extraValues.get(rootIndex)$endif$)) { values.set(rootIndex, value); +$if(hasExtra)$ + extraValues.set(rootIndex, extraValue); +$endif$ downHeap(rootIndex, 0, bucketSize); } return; @@ -108,6 +123,9 @@ public class $Type$BucketedSort implements Releasable { : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; long index = next + rootIndex; values.set(index, value); +$if(hasExtra)$ + extraValues.set(index, extraValue); +$endif$ if (next == 0) { heapMode.set(bucket); heapify(rootIndex, bucketSize); @@ -148,26 +166,79 @@ public class $Type$BucketedSort implements Releasable { /** * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. */ - public void merge(int groupId, $Type$BucketedSort other, int otherGroupId) { + public void merge(int groupId, $Name$$ExtraName$BucketedSort other, int otherGroupId) { var otherBounds = other.getBucketValuesIndexes(otherGroupId); // TODO: This can be improved for heapified buckets by making use of the heap structures for (long i = otherBounds.v1(); i < otherBounds.v2(); i++) { - collect(other.values.get(i), groupId); + collect(other.values.get(i), $if(hasExtra)$other.extraValues.get(i), $endif$groupId); } } +$if(hasExtra)$ /** * Creates a block with the values from the {@code selected} groups. */ - public Block toBlock(BlockFactory blockFactory, IntVector selected) { + public void toBlocks(BlockFactory blockFactory, Block[] blocks, int offset, IntVector selected) { // Check if the selected groups are all empty, to avoid allocating extra memory - if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { - var bounds = this.getBucketValuesIndexes(bucket); - var size = bounds.v2() - bounds.v1(); + if (allSelectedGroupsAreEmpty(selected)) { + Block constantNullBlock = blockFactory.newConstantNullBlock(selected.getPositionCount()); + constantNullBlock.incRef(); + blocks[offset] = constantNullBlock; + blocks[offset + 1] = constantNullBlock; + return; + } - return size > 0; - })) { + try ( + var builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount()); + var extraBuilder = blockFactory.new$ExtraType$BlockBuilder(selected.getPositionCount()) + ) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + + var bounds = getBucketValuesIndexes(bucket); + var rootIndex = bounds.v1(); + var size = bounds.v2() - bounds.v1(); + + if (size == 0) { + builder.appendNull(); + extraBuilder.appendNull(); + continue; + } + + if (size == 1) { + builder.append$Type$(values.get(rootIndex)); + extraBuilder.append$ExtraType$(extraValues.get(rootIndex)); + continue; + } + + // If we are in the gathering mode, we need to heapify before sorting. 
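+ // (Values were appended in arrival order while gathering, so the heap invariant must be built before heapSort runs.)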
+ if (inHeapMode(bucket) == false) { + heapify(rootIndex, (int) size); + } + heapSort(rootIndex, (int) size); + + builder.beginPositionEntry(); + extraBuilder.beginPositionEntry(); + for (int i = 0; i < size; i++) { + builder.append$Type$(values.get(rootIndex + i)); + extraBuilder.append$ExtraType$(extraValues.get(rootIndex + i)); + } + builder.endPositionEntry(); + extraBuilder.endPositionEntry(); + } + blocks[offset] = builder.build(); + blocks[offset + 1] = extraBuilder.build(); + } + } + +$else$ + /** + * Creates a block with the values from the {@code selected} groups. + */ + public Block toBlock(BlockFactory blockFactory, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (allSelectedGroupsAreEmpty(selected)) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -185,7 +256,7 @@ public class $Type$BucketedSort implements Releasable { } if (size == 1) { - builder.append$Type$(values.get(bounds.v1())); + builder.append$Type$(values.get(rootIndex)); continue; } @@ -197,13 +268,25 @@ public class $Type$BucketedSort implements Releasable { builder.beginPositionEntry(); for (int i = 0; i < size; i++) { - builder.append$Type$(values.get(bounds.v1() + i)); + builder.append$Type$(values.get(rootIndex + i)); } builder.endPositionEntry(); } return builder.build(); } } +$endif$ + + /** + * Checks if the selected groups are all empty. + */ + private boolean allSelectedGroupsAreEmpty(IntVector selected) { + return IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + var bounds = this.getBucketValuesIndexes(bucket); + var size = bounds.v2() - bounds.v1(); + return size > 0; + }); + } /** * Is this bucket a min heap {@code true} or in gathering mode {@code false}? @@ -237,8 +320,15 @@ $endif$ * the entry at {@code rhs}. "Better" in this means "lower" for * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. */ - private boolean betterThan($type$ lhs, $type$ rhs) { - return getOrder().reverseMul() * $Wrapper$.compare(lhs, rhs) < 0; + private boolean betterThan($type$ lhs, $type$ rhs$if(hasExtra)$, $Extratype$ lhsExtra, $Extratype$ rhsExtra$endif$) { + int res = $Wrapper$.compare(lhs, rhs); +$if(hasExtra)$ + if (res != 0) { + return getOrder().reverseMul() * res < 0; + } + res = $ExtraWrapper$.compare(lhsExtra, rhsExtra); +$endif$ + return getOrder().reverseMul() * res < 0; } /** @@ -248,6 +338,11 @@ $endif$ var tmp = values.get(lhs); values.set(lhs, values.get(rhs)); values.set(rhs, tmp); +$if(hasExtra)$ + var tmpExtra = extraValues.get(lhs); + extraValues.set(lhs, extraValues.get(rhs)); + extraValues.set(rhs, tmpExtra); +$endif$ } /** @@ -263,6 +358,10 @@ $endif$ // Round up to the next full bucket. newSize = (newSize + bucketSize - 1) / bucketSize; values = bigArrays.resize(values, newSize * bucketSize); +$if(hasExtra)$ + // Round up to the next full bucket. + extraValues = bigArrays.resize(extraValues, newSize * bucketSize); +$endif$ // Set the next gather offsets for all newly allocated buckets. 
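+ // Gather offsets live in the key array only; extraValues slots are always written together with their keys on collect.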
fillGatherOffsets(oldMax); } @@ -344,13 +443,23 @@ $endif$ int leftChild = parent * 2 + 1; long leftIndex = rootIndex + leftChild; if (leftChild < heapSize) { - if (betterThan(values.get(worstIndex), values.get(leftIndex))) { + if (betterThan(values.get(worstIndex), values.get(leftIndex)$if(hasExtra)$, extraValues.get(worstIndex), extraValues.get(leftIndex)$endif$)) { worst = leftChild; worstIndex = leftIndex; } int rightChild = leftChild + 1; long rightIndex = rootIndex + rightChild; +$if(hasExtra)$ + if (rightChild < heapSize + && betterThan( + values.get(worstIndex), + values.get(rightIndex), + extraValues.get(worstIndex), + extraValues.get(rightIndex) + )) { +$else$ if (rightChild < heapSize && betterThan(values.get(worstIndex), values.get(rightIndex))) { +$endif$ worst = rightChild; worstIndex = rightIndex; } @@ -365,6 +474,6 @@ $endif$ @Override public final void close() { - Releasables.close(values, heapMode); + Releasables.close(values, $if(hasExtra)$extraValues, $endif$heapMode); } } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec index 8c6d157790d6f..b99a7b232bcf9 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec @@ -312,3 +312,96 @@ top_salaries:integer [25324, 25945, 25976] ; + +youngestEmployees +required_capability: agg_top_with_output_field +FROM employees +| STATS youngest_employees = TOP(birth_date, 3, "desc"), + youngest_employees_salaries = TOP(birth_date, 3, "desc", salary) +; + +youngest_employees:date | youngest_employees_salaries:integer +[1965-01-03T00:00:00.000Z, 1964-10-18T00:00:00.000Z, 1964-06-11T00:00:00.000Z] | [37702, 25976, 45656] +; + +youngestEmployeesByGender +required_capability: agg_top_with_output_field +FROM employees +| STATS youngest_employees = TOP(birth_date, 3, "desc"), + youngest_employees_salaries = TOP(birth_date, 3, "desc", salary) + BY gender +| SORT gender +| KEEP gender, youngest_employees, youngest_employees_salaries +; + +gender:keyword | youngest_employees:datetime | youngest_employees_salaries:integer +F | [1964-10-18T00:00:00.000Z, 1964-06-02T00:00:00.000Z, 1963-03-21T00:00:00.000Z] | [25976, 56371, 43602] +M | [1965-01-03T00:00:00.000Z, 1964-06-11T00:00:00.000Z, 1964-04-18T00:00:00.000Z] | [37702, 45656, 46595] +null | [1963-06-07T00:00:00.000Z, 1963-06-01T00:00:00.000Z, 1961-05-02T00:00:00.000Z] | [48735, 45797, 61358] +; + +oldestEmployees +required_capability: agg_top_with_output_field +FROM employees +| STATS oldest_employees = TOP(birth_date, 3, "asc"), + oldest_employees_salaries = TOP(birth_date, 3, "asc", salary) +; + +oldest_employees:date | oldest_employees_salaries:integer +[1952-02-27T00:00:00.000Z, 1952-04-19T00:00:00.000Z, 1952-05-15T00:00:00.000Z] | [71165, 66174, 54518] +; + +oldestEmployeesByGender +required_capability: agg_top_with_output_field +FROM employees +| STATS oldest_employees = TOP(birth_date, 3, "asc"), + oldest_employees_salaries = TOP(birth_date, 3, "asc", salary) + BY gender +| SORT gender +| KEEP gender, oldest_employees, oldest_employees_salaries +; + +gender:keyword | oldest_employees:datetime | oldest_employees_salaries:integer +F | [1952-04-19T00:00:00.000Z, 1952-05-15T00:00:00.000Z, 1952-06-13T00:00:00.000Z] | [66174, 54518, 62405] +M | [1952-02-27T00:00:00.000Z, 1952-07-08T00:00:00.000Z, 1952-11-13T00:00:00.000Z] | [71165, 48233, 31897] +null | [1953-01-23T00:00:00.000Z, 
1953-11-07T00:00:00.000Z, 1954-06-19T00:00:00.000Z] | [73717, 31120, 56760]
+;
+
+oldestEmployeesDuplicated
+required_capability: agg_top_with_output_field
+FROM employees
+| STATS oldest_employees = TOP(birth_date, 3, "asc"),
+        oldest_employees_dup = TOP(birth_date, 3, "asc", birth_date)
+;
+
+oldest_employees:date | oldest_employees_dup:date
+[1952-02-27T00:00:00.000Z, 1952-04-19T00:00:00.000Z, 1952-05-15T00:00:00.000Z] | [1952-02-27T00:00:00.000Z, 1952-04-19T00:00:00.000Z, 1952-05-15T00:00:00.000Z]
+;
+
+empNoByBirthYear
+required_capability: agg_top_with_output_field
+
+FROM employees
+| EVAL birth_year = DATE_EXTRACT("year", birth_date)
+| STATS v = TOP(birth_year, 20, "asc", emp_no) BY birth_year
+| KEEP birth_year, v
+| SORT birth_year
+;
+
+birth_year:long | v:integer
+1952 | [10009, 10020, 10022, 10063, 10066, 10072, 10076, 10097]
+1953 | [10001, 10006, 10011, 10019, 10023, 10026, 10035, 10051, 10059, 10067, 10100]
+1954 | [10004, 10018, 10053, 10057, 10058, 10073, 10088, 10096]
+1955 | [10005, 10070, 10074, 10091]
+1956 | [10014, 10029, 10033, 10055, 10099]
+1957 | [10007, 10054, 10080, 10094]
+1958 | [10008, 10017, 10024, 10025, 10030, 10050, 10071]
+1959 | [10003, 10015, 10031, 10036, 10039, 10064, 10078, 10083, 10087]
+1960 | [10012, 10021, 10032, 10038, 10069, 10075, 10081, 10084]
+1961 | [10016, 10052, 10056, 10060, 10062, 10079, 10090, 10098]
+1962 | [10027, 10034, 10061, 10068, 10085, 10086]
+1963 | [10010, 10013, 10028, 10037, 10065, 10082, 10089]
+1964 | [10002, 10077, 10092, 10093]
+1965 | 10095
+null | null
+;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/EsqlIllegalArgumentException.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/EsqlIllegalArgumentException.java
index d9a0694e98d2c..b615e7dfcbf49 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/EsqlIllegalArgumentException.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/EsqlIllegalArgumentException.java
@@ -42,4 +42,8 @@ public static EsqlIllegalArgumentException illegalDataType(DataType dataType) {
     public static EsqlIllegalArgumentException illegalDataType(String dataTypeName) {
         return new EsqlIllegalArgumentException("illegal data type [" + dataTypeName + "]");
     }
+
+    public static EsqlIllegalArgumentException illegalDataTypeCombination(DataType dataType1, DataType dataType2) {
+        return new EsqlIllegalArgumentException("illegal data type combination [" + dataType1 + ", " + dataType2 + "]");
+    }
 }
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 3acbb4d36899e..6ce54d483b9be 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -312,6 +312,11 @@ public enum Cap {
      */
     AGG_TOP_WITH_OPTIONAL_ORDER_FIELD,

+    /**
+     * Support for the extra {@code outputField} parameter in the {@code TOP} aggregation.
+     */
+    AGG_TOP_WITH_OUTPUT_FIELD,
+
     /**
      * {@code CASE} properly handling multivalue conditions.
*/ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 2f4d72338b4fc..c03eb08f5b054 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -356,7 +356,7 @@ private static FunctionDefinition[][] functions() { def(Sample.class, bi(Sample::new), "sample"), def(StdDev.class, uni(StdDev::new), "std_dev"), def(Sum.class, uni(Sum::new), "sum"), - def(Top.class, tri(Top::new), "top"), + def(Top.class, quad(Top::new), "top"), def(Values.class, uni(Values::new), "values"), def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg"), def(Present.class, uni(Present::new), "present"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java index 6a12f63ada117..aab0d4287d980 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java @@ -14,9 +14,27 @@ import org.elasticsearch.compute.aggregation.TopBooleanAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.TopBytesRefAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.TopDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopDoubleDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopDoubleFloatAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopDoubleIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopDoubleLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopFloatDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopFloatFloatAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopFloatIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopFloatLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.TopIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopIntDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopIntFloatAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopIntIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopIntLongAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.TopIpAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.TopLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopLongDoubleAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopLongFloatAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopLongIntAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.TopLongLongAggregatorFunctionSupplier; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; import 
org.elasticsearch.xpack.esql.common.Failures; @@ -31,18 +49,20 @@ import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.FunctionType; -import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.expression.function.Param; -import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.expression.function.TwoOptionalArguments; import org.elasticsearch.xpack.esql.planner.ToAggregator; import java.io.IOException; import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; import static java.util.Arrays.asList; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; @@ -53,7 +73,7 @@ public class Top extends AggregateFunction implements - OptionalArgument, + TwoOptionalArguments, ToAggregator, SurrogateExpression, PostOptimizationVerificationAware { @@ -81,27 +101,28 @@ public Top( name = "order", type = { "keyword" }, description = "The order to calculate the top values. Either `asc` or `desc`, and defaults to `asc` if omitted." - ) Expression order + ) Expression order, + @Param( + optional = true, + name = "outputField", + type = { "double", "integer", "long", "date" }, + description = "The extra field that, if present, will be the output of the TOP call instead of `field`." + ) Expression outputField ) { - this(source, field, Literal.TRUE, limit, order == null ? Literal.keyword(source, ORDER_ASC) : order); + this(source, field, Literal.TRUE, limit, order == null ? Literal.keyword(source, ORDER_ASC) : order, outputField); } - public Top(Source source, Expression field, Expression filter, Expression limit, Expression order) { - super(source, field, filter, asList(limit, order)); + public Top(Source source, Expression field, Expression filter, Expression limit, Expression order, @Nullable Expression outputField) { + super(source, field, filter, outputField != null ? asList(limit, order, outputField) : asList(limit, order)); } private Top(StreamInput in) throws IOException { - super( - Source.readFrom((PlanStreamInput) in), - in.readNamedWriteable(Expression.class), - in.readNamedWriteable(Expression.class), - in.readNamedWriteableCollectionAsList(Expression.class) - ); + super(in); } @Override public Top withFilter(Expression filter) { - return new Top(source(), field(), filter, limitField(), orderField()); + return new Top(source(), field(), filter, limitField(), orderField(), outputField()); } @Override @@ -117,6 +138,11 @@ Expression orderField() { return parameters().get(1); } + @Nullable + Expression outputField() { + return parameters().size() > 2 ? 
parameters().get(2) : null;
+    }
+
     private Integer limitValue() {
         return Foldables.limitValue(limitField(), sourceText());
     }
@@ -155,6 +181,30 @@ protected TypeResolution resolveType() {
             .and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer"))
             .and(isNotNull(orderField(), sourceText(), THIRD))
             .and(isString(orderField(), sourceText(), THIRD));
+        if (outputField() != null) {
+            typeResolution = typeResolution.and(
+                isType(
+                    outputField(),
+                    dt -> dt == DataType.DATETIME || (dt.isNumeric() && dt != DataType.UNSIGNED_LONG),
+                    sourceText(),
+                    FOURTH,
+                    "date",
+                    "numeric except unsigned_long or counter types"
+                )
+            )
+                .and(
+                    isType(
+                        field(),
+                        dt -> dt == DataType.DATETIME || (dt.isNumeric() && dt != DataType.UNSIGNED_LONG),
+                        "when fourth argument is set, ",
+                        sourceText(),
+                        FIRST,
+                        false,
+                        "date",
+                        "numeric except unsigned_long or counter types"
+                    )
+                );
+        }

         if (typeResolution.unresolved()) {
             return typeResolution;
@@ -242,46 +292,92 @@ private void postOptimizationVerificationOrder(Failures failures) {
     @Override
     public DataType dataType() {
-        return field().dataType().noText();
+        return outputField() == null ? field().dataType().noText() : outputField().dataType().noText();
     }

     @Override
     protected NodeInfo<Top> info() {
-        return NodeInfo.create(this, Top::new, field(), filter(), limitField(), orderField());
+        return NodeInfo.create(this, Top::new, field(), filter(), limitField(), orderField(), outputField());
     }

     @Override
     public Top replaceChildren(List<Expression> newChildren) {
-        return new Top(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), newChildren.get(3));
+        return new Top(
+            source(),
+            newChildren.get(0),
+            newChildren.get(1),
+            newChildren.get(2),
+            newChildren.get(3),
+            newChildren.size() > 4 ? newChildren.get(4) : null
+        );
     }

+    private static final Map<DataType, BiFunction<Integer, Boolean, AggregatorFunctionSupplier>> SUPPLIERS = Map.ofEntries(
+        Map.entry(DataType.LONG, TopLongAggregatorFunctionSupplier::new),
+        Map.entry(DataType.DATETIME, TopLongAggregatorFunctionSupplier::new),
+        Map.entry(DataType.INTEGER, TopIntAggregatorFunctionSupplier::new),
+        Map.entry(DataType.DOUBLE, TopDoubleAggregatorFunctionSupplier::new),
+        Map.entry(DataType.BOOLEAN, TopBooleanAggregatorFunctionSupplier::new),
+        Map.entry(DataType.IP, TopIpAggregatorFunctionSupplier::new),
+        Map.entry(DataType.KEYWORD, TopBytesRefAggregatorFunctionSupplier::new),
+        Map.entry(DataType.TEXT, TopBytesRefAggregatorFunctionSupplier::new)
+    );
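+    // Dispatch table for the four-argument form, keyed by (field type, outputField type). DATETIME
+    // rides on the long-based suppliers, so, for example, TOP(birth_date, 3, "desc", salary) with an
+    // integer salary resolves to TopLongIntAggregatorFunctionSupplier.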
+    private static final Map<Tuple<DataType, DataType>, BiFunction<Integer, Boolean, AggregatorFunctionSupplier>> SUPPLIERS_WITH_EXTRA = Map
+        .ofEntries(
+            Map.entry(Tuple.tuple(DataType.LONG, DataType.DATETIME), TopLongLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.LONG, DataType.INTEGER), TopLongIntAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.LONG, DataType.LONG), TopLongLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.LONG, DataType.FLOAT), TopLongFloatAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.LONG, DataType.DOUBLE), TopLongDoubleAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DATETIME, DataType.DATETIME), TopLongLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DATETIME, DataType.INTEGER), TopLongIntAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DATETIME, DataType.LONG), TopLongLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DATETIME, DataType.FLOAT), TopLongFloatAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DATETIME, DataType.DOUBLE), TopLongDoubleAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.INTEGER, DataType.DATETIME), TopIntLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.INTEGER, DataType.INTEGER), TopIntIntAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.INTEGER, DataType.LONG), TopIntLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.INTEGER, DataType.FLOAT), TopIntFloatAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.INTEGER, DataType.DOUBLE), TopIntDoubleAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.FLOAT, DataType.DATETIME), TopFloatLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.FLOAT, DataType.INTEGER), TopFloatIntAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.FLOAT, DataType.LONG), TopFloatLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.FLOAT, DataType.FLOAT), TopFloatFloatAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.FLOAT, DataType.DOUBLE), TopFloatDoubleAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DOUBLE, DataType.DATETIME), TopDoubleLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DOUBLE, DataType.INTEGER), TopDoubleIntAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DOUBLE, DataType.LONG), TopDoubleLongAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DOUBLE, DataType.FLOAT), TopDoubleFloatAggregatorFunctionSupplier::new),
+            Map.entry(Tuple.tuple(DataType.DOUBLE, DataType.DOUBLE), TopDoubleDoubleAggregatorFunctionSupplier::new)
+        );
+
     @Override
     public AggregatorFunctionSupplier supplier() {
-        DataType type = field().dataType();
-        if (type == DataType.LONG || type == DataType.DATETIME) {
-            return new TopLongAggregatorFunctionSupplier(limitValue(), orderValue());
-        }
-        if (type == DataType.INTEGER) {
-            return new TopIntAggregatorFunctionSupplier(limitValue(), orderValue());
-        }
-        if (type == DataType.DOUBLE) {
-            return new TopDoubleAggregatorFunctionSupplier(limitValue(), orderValue());
-        }
-        if (type == DataType.BOOLEAN) {
-            return new TopBooleanAggregatorFunctionSupplier(limitValue(), orderValue());
-        }
-        if (type == DataType.IP) {
-            return new TopIpAggregatorFunctionSupplier(limitValue(), orderValue());
-        }
-        if (DataType.isString(type)) {
-            return new TopBytesRefAggregatorFunctionSupplier(limitValue(), orderValue());
+        DataType fieldType = field().dataType();
+        BiFunction<Integer, Boolean, AggregatorFunctionSupplier> supplierCtor;
+        if (outputField() == null) {
+            supplierCtor = SUPPLIERS.get(fieldType);
+            if (supplierCtor == null) {
+                throw EsqlIllegalArgumentException.illegalDataType(fieldType);
+            }
+        } else {
+            DataType outputFieldType = outputField().dataType();
+            supplierCtor = SUPPLIERS_WITH_EXTRA.get(Tuple.tuple(fieldType, outputFieldType));
+            if (supplierCtor == null) {
+                throw EsqlIllegalArgumentException.illegalDataTypeCombination(fieldType, outputFieldType);
+            }
         }
-        throw EsqlIllegalArgumentException.illegalDataType(type);
+        return supplierCtor.apply(limitValue(), orderValue());
     }

     @Override
     public Expression surrogate() {
         var s = source();
+        // If `outputField` is specified but holds the same value as `field`, we do not need to handle `outputField` separately.
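+        // For example, TOP(salary, 3, "asc", salary) collapses to TOP(salary, 3, "asc") here, and the
+        // limit == 1 branch below collapses TOP(x, 1, "asc") to MIN(x) (or MAX(x) for "desc").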
+        if (outputField() != null && field().semanticEquals(outputField())) {
+            return new Top(s, field(), filter(), limitField(), orderField(), null);
+        }
         if (orderField() instanceof Literal && limitField() instanceof Literal && limitValue() == 1) {
             if (orderValue()) {
                 return new Min(s, field(), filter());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
index 82bf57d1a194e..c9badc3d86993 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
@@ -20,7 +20,8 @@ protected Top createTestInstance() {
         Expression field = randomChild();
         Expression limit = randomChild();
         Expression order = randomChild();
-        return new Top(source, field, limit, order);
+        Expression outputField = randomBoolean() ? null : randomChild();
+        return new Top(source, field, limit, order, outputField);
     }

     @Override
@@ -29,11 +30,13 @@ protected Top mutateInstance(Top instance) throws IOException {
         Expression field = instance.field();
         Expression limit = instance.limitField();
         Expression order = instance.orderField();
-        switch (between(0, 2)) {
+        Expression outputField = instance.outputField();
+        switch (between(0, 3)) {
             case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild);
             case 1 -> limit = randomValueOtherThan(limit, AbstractExpressionSerializationTests::randomChild);
             case 2 -> order = randomValueOtherThan(order, AbstractExpressionSerializationTests::randomChild);
+            case 3 -> outputField = randomValueOtherThan(outputField, () -> randomBoolean() ? null : randomChild());
         }
-        return new Top(source, field, limit, order);
+        return new Top(source, field, limit, order, outputField);
     }
 }
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java
index 0d7733a6bcb98..d1a10da22372e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java
@@ -14,6 +14,7 @@
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.network.InetAddresses;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase;
@@ -24,11 +25,14 @@
 import java.util.Arrays;
 import java.util.Comparator;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;

 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;

 public class TopTests extends AbstractAggregationTestCase {
     public TopTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
@@ -52,14 +56,97 @@ public static Iterable<Object[]> parameters() {
                 MultiRowTestCaseSupplier.stringCases(1, 1000, DataType.TEXT)
             )
                 .flatMap(List::stream)
-                .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order))
+                .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order, null))
                 .collect(Collectors.toCollection(() -> suppliers));
             }
         }

+        for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 1000, false)) {
+            for (String order : Arrays.asList("asc", "desc")) {
+                int rows = 100;
+                List<TestCaseSupplier.TypedDataSupplier> fieldCaseSuppliers = Stream.of(
+                    MultiRowTestCaseSupplier.intCases(rows, rows, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
+                    MultiRowTestCaseSupplier.longCases(rows, rows, Long.MIN_VALUE, Long.MAX_VALUE, true),
+                    MultiRowTestCaseSupplier.doubleCases(rows, rows, -Double.MAX_VALUE, Double.MAX_VALUE, true),
+                    MultiRowTestCaseSupplier.dateCases(rows, rows)
+                ).flatMap(List::stream).toList();
+                for (var fieldCaseSupplier : fieldCaseSuppliers) {
+                    List<TestCaseSupplier.TypedDataSupplier> outputFieldCaseSuppliers = Stream.of(
+                        MultiRowTestCaseSupplier.intCases(rows, rows, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
+                        MultiRowTestCaseSupplier.longCases(rows, rows, Long.MIN_VALUE, Long.MAX_VALUE, true),
+                        MultiRowTestCaseSupplier.doubleCases(rows, rows, -Double.MAX_VALUE, Double.MAX_VALUE, true),
+                        MultiRowTestCaseSupplier.dateCases(rows, rows)
+                    ).flatMap(List::stream).toList();
+                    for (var outputFieldCaseSupplier : outputFieldCaseSuppliers) {
+                        if (fieldCaseSupplier.name().equals(outputFieldCaseSupplier.name())) {
+                            continue;
+                        }
+                        suppliers.add(TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order, outputFieldCaseSupplier));
+                    }
+                }
+            }
+        }
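+        // The loops above exercise every (field, outputField) type pairing across int, long, double
+        // and date. Pairings whose case suppliers share a name are skipped: those are the "same field"
+        // cases, which TOP rewrites through its surrogate and which the explicit cases below cover.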
         suppliers.addAll(
             List.of(
-                // Surrogates
+                // Surrogates for cases where field and outputField are effectively the same field
+                new TestCaseSupplier(
+                    List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD, DataType.INTEGER),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"),
+                            new TestCaseSupplier.TypedData(3, DataType.INTEGER, "limit").forceLiteral(),
+                            new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral(),
+                            TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")
+                        ),
+                        "TopInt",
+                        DataType.INTEGER,
+                        equalTo(List.of(200, 8, 5))
+                    )
+                ),
+                new TestCaseSupplier(
+                    List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD, DataType.LONG),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"),
+                            new TestCaseSupplier.TypedData(3, DataType.INTEGER, "limit").forceLiteral(),
+                            new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral(),
+                            TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")
+                        ),
+                        "TopLong",
+                        DataType.LONG,
+                        equalTo(List.of(200L, 8L, 5L))
+                    )
+                ),
+                new TestCaseSupplier(
+                    List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD, DataType.DOUBLE),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"),
+                            new TestCaseSupplier.TypedData(3, DataType.INTEGER, "limit").forceLiteral(),
+                            new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral(),
+                            TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")
+                        ),
+                        "TopDouble",
+                        DataType.DOUBLE,
+                        equalTo(List.of(200., 8., 5.))
+                    )
+                ),
+                new TestCaseSupplier(
+                    List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD, DataType.DATETIME),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field"),
+                            new TestCaseSupplier.TypedData(3, DataType.INTEGER, "limit").forceLiteral(),
+                            new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral(),
+                            TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")
+                        ),
+                        "TopLong",
+                        DataType.DATETIME,
+                        equalTo(List.of(200L, 8L, 5L))
+                    )
+                ),
+                // Surrogates for cases where limit == 1
                 new TestCaseSupplier(
                     List.of(DataType.BOOLEAN, DataType.INTEGER, DataType.KEYWORD),
                     () -> new TestCaseSupplier.TestCase(
@@ -276,6 +363,84 @@ public static Iterable<Object[]> parameters() {
                 ),
                 "Invalid order value in [source], expected [ASC, DESC] but got [null]"
             )
+            ),
+            new TestCaseSupplier(
+                List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD, DataType.BOOLEAN),
+                () -> TestCaseSupplier.TestCase.typeError(
+                    List.of(
+                        TestCaseSupplier.TypedData.multiRow(List.of(1L, 2L), DataType.LONG, "field"),
+                        new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(),
+                        new TestCaseSupplier.TypedData("asc", DataType.KEYWORD, "order").forceLiteral(),
+                        TestCaseSupplier.TypedData.multiRow(List.of(true, false), DataType.BOOLEAN, "outputField")
+                    ),
+                    "fourth argument of [source] must be [date or numeric except unsigned_long or counter types], "
+                        + "found value [outputField] type [boolean]"
+                )
+            ),
+            new TestCaseSupplier(
+                List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD, DataType.KEYWORD),
+                () -> TestCaseSupplier.TestCase.typeError(
+                    List.of(
+                        TestCaseSupplier.TypedData.multiRow(List.of(1L, 2L), DataType.LONG, "field"),
+                        new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(),
TestCaseSupplier.TypedData("asc", DataType.KEYWORD, "order").forceLiteral(), + TestCaseSupplier.TypedData.multiRow(List.of("a", "b"), DataType.KEYWORD, "outputField") + ), + "fourth argument of [source] must be [date or numeric except unsigned_long or counter types], " + + "found value [outputField] type [keyword]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD, DataType.IP), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(1L, 2L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData("asc", DataType.KEYWORD, "order").forceLiteral(), + TestCaseSupplier.TypedData.multiRow(List.of("192.168.0.1", "192.168.0.2"), DataType.IP, "outputField") + ), + "fourth argument of [source] must be [date or numeric except unsigned_long or counter types], " + + "found value [outputField] type [ip]" + ) + ), + new TestCaseSupplier( + List.of(DataType.BOOLEAN, DataType.INTEGER, DataType.KEYWORD, DataType.LONG), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(true, false), DataType.BOOLEAN, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData("asc", DataType.KEYWORD, "order").forceLiteral(), + TestCaseSupplier.TypedData.multiRow(List.of(1L, 2L), DataType.LONG, "outputField") + ), + "when fourth argument is set, first argument of [source] must be " + + "[date or numeric except unsigned_long or counter types], found value [field] type [boolean]" + ) + ), + new TestCaseSupplier( + List.of(DataType.KEYWORD, DataType.INTEGER, DataType.KEYWORD, DataType.LONG), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of("a", "b"), DataType.KEYWORD, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData("asc", DataType.KEYWORD, "order").forceLiteral(), + TestCaseSupplier.TypedData.multiRow(List.of(1L, 2L), DataType.LONG, "outputField") + ), + "when fourth argument is set, first argument of [source] must be " + + "[date or numeric except unsigned_long or counter types], found value [field] type [keyword]" + ) + ), + new TestCaseSupplier( + List.of(DataType.IP, DataType.INTEGER, DataType.KEYWORD, DataType.LONG), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of("192.168.0.1", "192.168.0.2"), DataType.IP, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData("asc", DataType.KEYWORD, "order").forceLiteral(), + TestCaseSupplier.TypedData.multiRow(List.of(1L, 2L), DataType.LONG, "outputField") + ), + "when fourth argument is set, first argument of [source] must be " + + "[date or numeric except unsigned_long or counter types], found value [field] type [ip]" + ) ) ) ); @@ -285,32 +450,75 @@ public static Iterable parameters() { @Override protected Expression build(Source source, List args) { - return new Top(source, args.get(0), args.get(1), args.size() == 3 ? args.get(2) : null); + Expression field = args.get(0); + Expression outputField = args.size() > 3 ? 
     @SuppressWarnings("unchecked")
     private static TestCaseSupplier makeSupplier(
         TestCaseSupplier.TypedDataSupplier fieldSupplier,
         TestCaseSupplier.TypedDataSupplier limitCaseSupplier,
-        String order
+        String order,
+        TestCaseSupplier.TypedDataSupplier outputFieldSupplier
     ) {
         boolean isAscending = order == null || order.equalsIgnoreCase("asc");
-        boolean noOrderSupplied = order == null;
+        boolean orderSupplied = order != null;
+        boolean outputFieldSupplied = outputFieldSupplier != null;
+
+        List<DataType> dataTypes = new ArrayList<>();
+        dataTypes.add(fieldSupplier.type());
+        dataTypes.add(DataType.INTEGER);
+        if (orderSupplied) {
+            dataTypes.add(DataType.KEYWORD);
+        }
+        if (outputFieldSupplied) {
+            dataTypes.add(outputFieldSupplier.type());
+        }

-        List<DataType> dataTypes = noOrderSupplied
-            ? List.of(fieldSupplier.type(), DataType.INTEGER)
-            : List.of(fieldSupplier.type(), DataType.INTEGER, DataType.KEYWORD);
+        DataType expectedType = outputFieldSupplied ? outputFieldSupplier.type() : fieldSupplier.type();

         return new TestCaseSupplier(fieldSupplier.name(), dataTypes, () -> {
             var fieldTypedData = fieldSupplier.get();
             var limitTypedData = limitCaseSupplier.get().forceLiteral();
             var limit = (int) limitTypedData.getValue();
-            var expected = fieldTypedData.multiRowData()
-                .stream()
-                .map(v -> (Comparable<? super Comparable<?>>) v)
-                .sorted(isAscending ? Comparator.naturalOrder() : Comparator.reverseOrder())
-                .limit(limit)
-                .toList();
+            TestCaseSupplier.TypedData outputFieldTypedData;
+            List<Object> expected;
+            if (outputFieldSupplied) {
+                outputFieldTypedData = outputFieldSupplier.get();
+                assertThat(outputFieldTypedData.multiRowData(), hasSize(equalTo(fieldTypedData.multiRowData().size())));
+                Comparator<Map.Entry<Comparable<? super Comparable<?>>, Comparable<? super Comparable<?>>>> comparator = Map.Entry.<
+                    Comparable<? super Comparable<?>>,
+                    Comparable<? super Comparable<?>>>comparingByKey().thenComparing(Map.Entry::getValue);
+                if (isAscending == false) {
+                    comparator = comparator.reversed();
+                }
+                expected = IntStream.range(0, fieldTypedData.multiRowData().size())
+                    .mapToObj(
+                        i -> Map.<Comparable<? super Comparable<?>>, Comparable<? super Comparable<?>>>entry(
+                            (Comparable<? super Comparable<?>>) fieldTypedData.multiRowData().get(i),
+                            (Comparable<? super Comparable<?>>) outputFieldTypedData.multiRowData().get(i)
+                        )
+                    )
+                    .sorted(comparator)
+                    .map(Map.Entry::getValue)
+                    .limit(limit)
+                    .toList();
+            } else {
+                outputFieldTypedData = null;
+                expected = fieldTypedData.multiRowData()
+                    .stream()
+                    .map(v -> (Comparable<? super Comparable<?>>) v)
+                    .sorted(isAscending ? Comparator.naturalOrder() : Comparator.reverseOrder())
+                    .limit(limit)
+                    .toList();
+            }

             String baseName;
             if (limit != 1) {
@@ -324,18 +532,22 @@ private static TestCaseSupplier makeSupplier(
                 }
             }
-            List<TestCaseSupplier.TypedData> typedData = noOrderSupplied
-                ? List.of(fieldTypedData, limitTypedData)
-                : List.of(
-                    fieldTypedData,
-                    limitTypedData,
-                    new TestCaseSupplier.TypedData(new BytesRef(order), DataType.KEYWORD, order + " order").forceLiteral()
-                );
+            List<TestCaseSupplier.TypedData> typedData = new ArrayList<>();
+            typedData.add(fieldTypedData);
+            typedData.add(limitTypedData);
+            if (orderSupplied) {
+                typedData.add(new TestCaseSupplier.TypedData(new BytesRef(order), DataType.KEYWORD, order + " order").forceLiteral());
+            }
+            if (outputFieldSupplied) {
+                typedData.add(outputFieldTypedData);
+            }

             return new TestCaseSupplier.TestCase(
                 typedData,
-                standardAggregatorName(baseName, fieldTypedData.type()),
-                fieldSupplier.type(),
+                outputFieldSupplied && (fieldTypedData.name().equals(outputFieldTypedData.name()) == false)
+                    ? standardAggregatorName(standardAggregatorName(baseName, fieldTypedData.type()), outputFieldTypedData.type())
+                    : standardAggregatorName(baseName, fieldTypedData.type()),
+                expectedType,
                 equalTo(expected.size() == 1 ? expected.get(0) : expected)
             );
         });
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java
index c50c998674704..c7404360c2017 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java
@@ -47,7 +47,8 @@ public static List randomAggregates() {
                 randomSource(),
                 FieldAttributeTests.createFieldAttribute(1, true),
                 new Literal(randomSource(), between(1, 5), DataType.INTEGER),
-                Literal.keyword(randomSource(), randomFrom("ASC", "DESC"))
+                Literal.keyword(randomSource(), randomFrom("ASC", "DESC")),
+                randomBoolean() ? null : FieldAttributeTests.createFieldAttribute(1, true)
             );
             case 4 -> new Values(randomSource(), FieldAttributeTests.createFieldAttribute(1, true));
             case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
index 2f848d07f657f..02ffa191b0832 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
@@ -181,7 +181,7 @@ public void testInfoParameters() throws Exception {
          */
         expectedCount -= 1;

-        assertEquals(expectedCount, info(node).properties().size());
+        assertEquals("Wrong number of info parameters for " + subclass.getSimpleName(), expectedCount, info(node).properties().size());
     }

     /**
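A minimal, self-contained sketch of the pair-sort expectation strategy used by makeSupplier above, with hypothetical column values (class name and data are illustrative only; it assumes nothing beyond the JDK):

    import java.util.Comparator;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.IntStream;

    public class TopPairSortSketch {
        public static void main(String[] args) {
            // Hypothetical columns: `field` is the sort key of TOP(field, limit, order, output),
            // `output` is the value the aggregation actually returns.
            List<Long> field = List.of(5L, 8L, -2L, 0L, 200L);
            List<Integer> output = List.of(50, 80, -20, 0, 2000);
            int limit = 3;
            boolean ascending = false; // "desc"

            // Sort by key, break ties by value, mirroring betterThan()'s extra-value comparison.
            Comparator<Map.Entry<Long, Integer>> comparator = Map.Entry.<Long, Integer>comparingByKey()
                .thenComparing(Map.Entry::getValue);
            if (ascending == false) {
                comparator = comparator.reversed();
            }

            // Pair up the rows, sort, truncate to the limit, and keep only the output side.
            List<Integer> top = IntStream.range(0, field.size())
                .mapToObj(i -> Map.entry(field.get(i), output.get(i)))
                .sorted(comparator)
                .limit(limit)
                .map(Map.Entry::getValue)
                .toList();

            System.out.println(top); // prints [2000, 80, 50]
        }
    }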