diff --git a/docs/changelog/113183.yaml b/docs/changelog/113183.yaml new file mode 100644 index 0000000000000..f30ce9831adb3 --- /dev/null +++ b/docs/changelog/113183.yaml @@ -0,0 +1,6 @@ +pr: 113183 +summary: "ESQL: TOP support for strings" +area: ES|QL +type: feature +issues: + - 109849 diff --git a/docs/reference/esql/functions/kibana/definition/top.json b/docs/reference/esql/functions/kibana/definition/top.json index 2e8e51e726588..62184326994fe 100644 --- a/docs/reference/esql/functions/kibana/definition/top.json +++ b/docs/reference/esql/functions/kibana/definition/top.json @@ -124,6 +124,30 @@ "variadic" : false, "returnType" : "ip" }, + { + "params" : [ + { + "name" : "field", + "type" : "keyword", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : false, + "description" : "The order to calculate the top values. Either `asc` or `desc`." + } + ], + "variadic" : false, + "returnType" : "keyword" + }, { "params" : [ { @@ -147,6 +171,30 @@ ], "variadic" : false, "returnType" : "long" + }, + { + "params" : [ + { + "name" : "field", + "type" : "text", + "optional" : false, + "description" : "The field to collect the top values for." + }, + { + "name" : "limit", + "type" : "integer", + "optional" : false, + "description" : "The maximum number of values to collect." + }, + { + "name" : "order", + "type" : "keyword", + "optional" : false, + "description" : "The order to calculate the top values. Either `asc` or `desc`." + } + ], + "variadic" : false, + "returnType" : "text" } ], "examples" : [ diff --git a/docs/reference/esql/functions/types/top.asciidoc b/docs/reference/esql/functions/types/top.asciidoc index 0eb329c10b9ed..25d7962a27252 100644 --- a/docs/reference/esql/functions/types/top.asciidoc +++ b/docs/reference/esql/functions/types/top.asciidoc @@ -10,5 +10,7 @@ date | integer | keyword | date double | integer | keyword | double integer | integer | keyword | integer ip | integer | keyword | ip +keyword | integer | keyword | keyword long | integer | keyword | long +text | integer | keyword | text |=== diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 81d1a6f5360ca..49e819b7cdc88 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -635,6 +635,11 @@ tasks.named('stringTemplates').configure { it.inputFile = topAggregatorInputFile it.outputFile = "org/elasticsearch/compute/aggregation/TopBooleanAggregator.java" } + template { + it.properties = bytesRefProperties + it.inputFile = topAggregatorInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java" + } template { it.properties = ipProperties it.inputFile = topAggregatorInputFile diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java new file mode 100644 index 0000000000000..c9b0e679b3e64 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.SortOrder; + +/** + * Aggregates the top N field values for BytesRef. + *
<p>
+ * This class is generated. Edit `X-TopAggregator.java.st` to edit this file. + *
</p>
+ */ +@Aggregator({ @IntermediateState(name = "top", type = "BYTES_REF_BLOCK") }) +@GroupingAggregator +class TopBytesRefAggregator { + public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) { + return new SingleState(bigArrays, limit, ascending); + } + + public static void combine(SingleState state, BytesRef v) { + state.add(v); + } + + public static void combineIntermediate(SingleState state, BytesRefBlock values) { + int start = values.getFirstValueIndex(0); + int end = start + values.getValueCount(0); + var scratch = new BytesRef(); + for (int i = start; i < end; i++) { + combine(state, values.getBytesRef(i, scratch)); + } + } + + public static Block evaluateFinal(SingleState state, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory()); + } + + public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) { + return new GroupingState(bigArrays, limit, ascending); + } + + public static void combine(GroupingState state, int groupId, BytesRef v) { + state.add(groupId, v); + } + + public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { + int start = values.getFirstValueIndex(valuesPosition); + int end = start + values.getValueCount(valuesPosition); + var scratch = new BytesRef(); + for (int i = start; i < end; i++) { + combine(state, groupId, values.getBytesRef(i, scratch)); + } + } + + public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) { + current.merge(groupId, state, statePosition); + } + + public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) { + return state.toBlock(driverContext.blockFactory(), selected); + } + + public static class GroupingState implements Releasable { + private final BytesRefBucketedSort sort; + + private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { + // TODO pass the breaker in from the DriverContext + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "top", bigArrays, ascending ? 
SortOrder.ASC : SortOrder.DESC, limit); + } + + public void add(int groupId, BytesRef value) { + sort.collect(value, groupId); + } + + public void merge(int groupId, GroupingState other, int otherGroupId) { + sort.merge(groupId, other.sort, otherGroupId); + } + + void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory(), selected); + } + + Block toBlock(BlockFactory blockFactory, IntVector selected) { + return sort.toBlock(blockFactory, selected); + } + + void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from nulls on the values block + } + + @Override + public void close() { + Releasables.closeExpectNoException(sort); + } + } + + public static class SingleState implements Releasable { + private final GroupingState internalState; + + private SingleState(BigArrays bigArrays, int limit, boolean ascending) { + this.internalState = new GroupingState(bigArrays, limit, ascending); + } + + public void add(BytesRef value) { + internalState.add(0, value); + } + + public void merge(GroupingState other) { + internalState.merge(0, other, 0); + } + + void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = toBlock(driverContext.blockFactory()); + } + + Block toBlock(BlockFactory blockFactory) { + try (var intValues = blockFactory.newConstantIntVector(0, 1)) { + return internalState.toBlock(blockFactory, intValues); + } + } + + @Override + public void close() { + Releasables.closeExpectNoException(internalState); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunction.java new file mode 100644 index 0000000000000..17b3d84ab0028 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunction.java @@ -0,0 +1,174 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link TopBytesRefAggregator}. + * This class is generated. Do not edit it. 
+ */ +public final class TopBytesRefAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.BYTES_REF) ); + + private final DriverContext driverContext; + + private final TopBytesRefAggregator.SingleState state; + + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopBytesRefAggregatorFunction(DriverContext driverContext, List channels, + TopBytesRefAggregator.SingleState state, int limit, boolean ascending) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + this.limit = limit; + this.ascending = ascending; + } + + public static TopBytesRefAggregatorFunction create(DriverContext driverContext, + List channels, int limit, boolean ascending) { + return new TopBytesRefAggregatorFunction(driverContext, channels, TopBytesRefAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.isConstant()) { + if (mask.getBoolean(0) == false) { + // Entire page masked away + return; + } + // No masking + BytesRefBlock block = page.getBlock(channels.get(0)); + BytesRefVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + BytesRefBlock block = page.getBlock(channels.get(0)); + BytesRefVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(BytesRefVector vector) { + BytesRef scratch = new BytesRef(); + for (int i = 0; i < vector.getPositionCount(); i++) { + TopBytesRefAggregator.combine(state, vector.getBytesRef(i, scratch)); + } + } + + private void addRawVector(BytesRefVector vector, BooleanVector mask) { + BytesRef scratch = new BytesRef(); + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + TopBytesRefAggregator.combine(state, vector.getBytesRef(i, scratch)); + } + } + + private void addRawBlock(BytesRefBlock block) { + BytesRef scratch = new BytesRef(); + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + TopBytesRefAggregator.combine(state, block.getBytesRef(i, scratch)); + } + } + } + + private void addRawBlock(BytesRefBlock block, BooleanVector mask) { + BytesRef scratch = new BytesRef(); + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + TopBytesRefAggregator.combine(state, block.getBytesRef(i, scratch)); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + 
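// The single "top" intermediate state arrives as a one-position block holding every collected value. +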
BytesRefBlock top = (BytesRefBlock) topUncast; + assert top.getPositionCount() == 1; + BytesRef scratch = new BytesRef(); + TopBytesRefAggregator.combineIntermediate(state, top); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + blocks[offset] = TopBytesRefAggregator.evaluateFinal(state, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..8c77d2116bf69 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunctionSupplier.java @@ -0,0 +1,45 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link TopBytesRefAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopBytesRefAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + private final int limit; + + private final boolean ascending; + + public TopBytesRefAggregatorFunctionSupplier(List channels, int limit, + boolean ascending) { + this.channels = channels; + this.limit = limit; + this.ascending = ascending; + } + + @Override + public TopBytesRefAggregatorFunction aggregator(DriverContext driverContext) { + return TopBytesRefAggregatorFunction.create(driverContext, channels, limit, ascending); + } + + @Override + public TopBytesRefGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return TopBytesRefGroupingAggregatorFunction.create(channels, driverContext, limit, ascending); + } + + @Override + public String describe() { + return "top of bytes"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..aa2d6094c8c3f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunction.java @@ -0,0 +1,221 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link TopBytesRefAggregator}. + * This class is generated. Do not edit it. + */ +public final class TopBytesRefGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("top", ElementType.BYTES_REF) ); + + private final TopBytesRefAggregator.GroupingState state; + + private final List channels; + + private final DriverContext driverContext; + + private final int limit; + + private final boolean ascending; + + public TopBytesRefGroupingAggregatorFunction(List channels, + TopBytesRefAggregator.GroupingState state, DriverContext driverContext, int limit, + boolean ascending) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + this.limit = limit; + this.ascending = ascending; + } + + public static TopBytesRefGroupingAggregatorFunction create(List channels, + DriverContext driverContext, int limit, boolean ascending) { + return new TopBytesRefGroupingAggregatorFunction(channels, TopBytesRefAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + BytesRefBlock valuesBlock = page.getBlock(channels.get(0)); + BytesRefVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = 
valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BytesRefVector values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, BytesRefBlock values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(v, scratch)); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, BytesRefVector values) { + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + TopBytesRefAggregator.combine(state, groupId, values.getBytesRef(groupPosition + positionOffset, scratch)); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block topUncast = page.getBlock(channels.get(0)); + if (topUncast.areAllValuesNull()) { + return; + } + BytesRefBlock top = (BytesRefBlock) topUncast; + BytesRef scratch = new BytesRef(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + TopBytesRefAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + TopBytesRefAggregator.GroupingState inState = ((TopBytesRefGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + TopBytesRefAggregator.combineStates(state, groupId, inState, position); + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void 
evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = TopBytesRefAggregator.evaluateFinal(state, selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st index b97d26ee6147d..18d573eea4a4c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st @@ -7,9 +7,12 @@ package org.elasticsearch.compute.aggregation; -$if(Ip)$ +$if(BytesRef || Ip)$ import org.apache.lucene.util.BytesRef; $endif$ +$if(BytesRef)$ +import org.elasticsearch.common.breaker.CircuitBreaker; +$endif$ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; @@ -49,7 +52,7 @@ class Top$Name$Aggregator { public static void combineIntermediate(SingleState state, $Type$Block values) { int start = values.getFirstValueIndex(0); int end = start + values.getValueCount(0); -$if(Ip)$ +$if(BytesRef || Ip)$ var scratch = new BytesRef(); for (int i = start; i < end; i++) { combine(state, values.get$Type$(i, scratch)); @@ -76,7 +79,7 @@ $endif$ public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) { int start = values.getFirstValueIndex(valuesPosition); int end = start + values.getValueCount(valuesPosition); -$if(Ip)$ +$if(BytesRef || Ip)$ var scratch = new BytesRef(); for (int i = start; i < end; i++) { combine(state, groupId, values.get$Type$(i, scratch)); @@ -100,7 +103,13 @@ $endif$ private final $Name$BucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { +$if(BytesRef)$ + // TODO pass the breaker in from the DriverContext + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); + this.sort = new BytesRefBucketedSort(breaker, "top", bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit); +$else$ this.sort = new $Name$BucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit); +$endif$ } public void add(int groupId, $type$ value) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BucketedSortCommon.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BucketedSortCommon.java new file mode 100644 index 0000000000000..58306f2140a82 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BucketedSortCommon.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.data.sort; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.search.sort.SortOrder; + +/** + * Components common to BucketedSort implementations. + */ +class BucketedSortCommon implements Releasable { + final BigArrays bigArrays; + final SortOrder order; + final int bucketSize; + + /** + * {@code true} if the bucket is in heap mode, {@code false} if + * it is still gathering. + */ + private final BitArray heapMode; + + BucketedSortCommon(BigArrays bigArrays, SortOrder order, int bucketSize) { + this.bigArrays = bigArrays; + this.order = order; + this.bucketSize = bucketSize; + this.heapMode = new BitArray(0, bigArrays); + } + + /** + * The first index in a bucket. Note that this might not be used. + * See {@link } + */ + long rootIndex(int bucket) { + return (long) bucket * bucketSize; + } + + /** + * The last index in a bucket. + */ + long endIndex(long rootIndex) { + return rootIndex + bucketSize; + } + + boolean inHeapMode(int bucket) { + return heapMode.get(bucket); + } + + void enableHeapMode(int bucket) { + heapMode.set(bucket); + } + + void assertValidNextOffset(int next) { + assert 0 <= next && next < bucketSize + : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + } + + @Override + public void close() { + heapMode.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BytesRefBucketedSort.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BytesRefBucketedSort.java new file mode 100644 index 0000000000000..9198de53b1e04 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BytesRefBucketedSort.java @@ -0,0 +1,386 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.data.sort; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ByteUtils; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; +import org.elasticsearch.core.Assertions; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.sort.BucketedSort; +import org.elasticsearch.search.sort.SortOrder; + +import java.util.Arrays; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +/** + * Aggregates the top N variable length {@link BytesRef} values per bucket. + * See {@link BucketedSort} for more information. + */ +public class BytesRefBucketedSort implements Releasable { + private final BucketedSortCommon common; + private final CircuitBreaker breaker; + private final String label; + + /** + * An array containing all the values on all buckets. The structure is as follows: + *
<p>
+ *     For each bucket, there are {@link BucketedSortCommon#bucketSize} elements, based
+ *     on the bucket id (0, 1, 2...). Then, for each bucket, it can be in 2 states:
+ * </p>
+ * <ul>
+ *     <li>
+ *         Gather mode: All buckets start in gather mode, and remain here while they have
+ *         less than bucketSize elements. In gather mode, the elements are stored in the
+ *         array from the highest index to the lowest index. The lowest index contains
+ *         the offset to the next slot to be filled.
+ *         <p>
+ *             This allows us to insert elements in O(1) time.
+ *         </p>
+ *         <p>
+ *             When the bucketSize-th element is collected, the bucket transitions to heap
+ *             mode, by heapifying its contents.
+ *         </p>
+ *     </li>
+ *     <li>
+ *         Heap mode: The bucket slots are organized as a min heap structure.
+ *         <p>
+ *             The root of the heap is the minimum value in the bucket,
+ *             which allows us to quickly discard new values that are not in the top N.
+ *         </p>
+ *     </li>
+ * </ul>
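+ * <p>
+ *     For example (illustrative), with {@code bucketSize} = 3 a bucket starts in gather
+ *     mode as {@code [offset=2, _, _]}; collecting "b" and then "c" yields
+ *     {@code [offset=0, "c", "b"]}, and the third collected value fills slot 0 and
+ *     triggers the switch to heap mode.
+ * </p>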
+ */ + private ObjectArray values; + + public BytesRefBucketedSort(CircuitBreaker breaker, String label, BigArrays bigArrays, SortOrder order, int bucketSize) { + this.breaker = breaker; + this.label = label; + common = new BucketedSortCommon(bigArrays, order, bucketSize); + boolean success = false; + try { + values = bigArrays.newObjectArray(0); + success = true; + } finally { + if (success == false) { + close(); + } + } + } + + private void checkInvariant(int bucket) { + if (Assertions.ENABLED == false) { + return; + } + long rootIndex = common.rootIndex(bucket); + long requiredSize = common.endIndex(rootIndex); + if (values.size() < requiredSize) { + throw new AssertionError("values too short " + values.size() + " < " + requiredSize); + } + if (values.get(rootIndex) == null) { + throw new AssertionError("new gather offset can't be null"); + } + if (common.inHeapMode(bucket) == false) { + common.assertValidNextOffset(getNextGatherOffset(rootIndex)); + } else { + for (long l = rootIndex; l < common.endIndex(rootIndex); l++) { + if (values.get(rootIndex) == null) { + throw new AssertionError("values missing in heap mode"); + } + } + } + } + + /** + * Collects a {@code value} into a {@code bucket}. + *
<p>
+ * It may or may not be inserted in the heap, depending on if it is better than the current root. + *
</p>
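+ * <p>
+ *     In heap mode, a value that is not better than the current root is discarded in
+ *     O(1); a better value replaces the root and is pushed down in O(log bucketSize).
+ * </p>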
+ */ + public void collect(BytesRef value, int bucket) { + long rootIndex = common.rootIndex(bucket); + if (common.inHeapMode(bucket)) { + if (betterThan(value, values.get(rootIndex).bytesRefView())) { + clearedBytesAt(rootIndex).append(value); + downHeap(rootIndex, 0); + } + checkInvariant(bucket); + return; + } + // Gathering mode + long requiredSize = common.endIndex(rootIndex); + if (values.size() < requiredSize) { + grow(requiredSize); + } + int next = getNextGatherOffset(rootIndex); + common.assertValidNextOffset(next); + long index = next + rootIndex; + clearedBytesAt(index).append(value); + if (next == 0) { + common.enableHeapMode(bucket); + heapify(rootIndex); + } else { + ByteUtils.writeIntLE(next - 1, values.get(rootIndex).bytes(), 0); + } + checkInvariant(bucket); + } + + /** + * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}. + */ + public void merge(int bucket, BytesRefBucketedSort other, int otherBucket) { + long otherRootIndex = other.common.rootIndex(otherBucket); + if (otherRootIndex >= other.values.size()) { + // The value was never collected. + return; + } + other.checkInvariant(bucket); + long otherStart = other.startIndex(otherBucket, otherRootIndex); + long otherEnd = other.common.endIndex(otherRootIndex); + // TODO: This can be improved for heapified buckets by making use of the heap structures + for (long i = otherStart; i < otherEnd; i++) { + collect(other.values.get(i).bytesRefView(), bucket); + } + } + + /** + * Creates a block with the values from the {@code selected} groups. + */ + public Block toBlock(BlockFactory blockFactory, IntVector selected) { + // Check if the selected groups are all empty, to avoid allocating extra memory + if (IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> { + long rootIndex = common.rootIndex(bucket); + if (rootIndex >= values.size()) { + // Never collected + return false; + } + long start = startIndex(bucket, rootIndex); + long end = common.endIndex(rootIndex); + long size = end - start; + return size > 0; + })) { + return blockFactory.newConstantNullBlock(selected.getPositionCount()); + } + + // Used to sort the values in the bucket. 
+ BytesRef[] bucketValues = new BytesRef[common.bucketSize]; + + try (var builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int bucket = selected.getInt(s); + long rootIndex = common.rootIndex(bucket); + if (rootIndex >= values.size()) { + // Never collected + builder.appendNull(); + continue; + } + + long start = startIndex(bucket, rootIndex); + long end = common.endIndex(rootIndex); + long size = end - start; + + if (size == 0) { + builder.appendNull(); + continue; + } + + if (size == 1) { + try (BreakingBytesRefBuilder bytes = values.get(start)) { + builder.appendBytesRef(bytes.bytesRefView()); + } + values.set(start, null); + continue; + } + + for (int i = 0; i < size; i++) { + try (BreakingBytesRefBuilder bytes = values.get(start + i)) { + bucketValues[i] = bytes.bytesRefView(); + } + values.set(start + i, null); + } + + // TODO: Make use of heap structures to faster iterate in order instead of copying and sorting + Arrays.sort(bucketValues, 0, (int) size); + + builder.beginPositionEntry(); + if (common.order == SortOrder.ASC) { + for (int i = 0; i < size; i++) { + builder.appendBytesRef(bucketValues[i]); + } + } else { + for (int i = (int) size - 1; i >= 0; i--) { + builder.appendBytesRef(bucketValues[i]); + } + } + builder.endPositionEntry(); + } + return builder.build(); + } + } + + private long startIndex(int bucket, long rootIndex) { + if (common.inHeapMode(bucket)) { + return rootIndex; + } + return rootIndex + getNextGatherOffset(rootIndex) + 1; + } + + /** + * Get the next index that should be "gathered" for a bucket rooted + * at {@code rootIndex}. + *
<p>
+ * Using the first 4 bytes of the element to store the next gather offset. + *
</p>
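+ * <p>
+ *     The offset is written and read as a little-endian int via {@link ByteUtils}.
+ * </p>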
+ */ + private int getNextGatherOffset(long rootIndex) { + BreakingBytesRefBuilder bytes = values.get(rootIndex); + assert bytes.length() == Integer.BYTES; + return ByteUtils.readIntLE(bytes.bytes(), 0); + } + + /** + * {@code true} if the entry at index {@code lhs} is "better" than + * the entry at {@code rhs}. "Better" in this means "lower" for + * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. + */ + private boolean betterThan(BytesRef lhs, BytesRef rhs) { + return common.order.reverseMul() * lhs.compareTo(rhs) < 0; + } + + /** + * Swap the data at two indices. + */ + private void swap(long lhs, long rhs) { + BreakingBytesRefBuilder tmp = values.get(lhs); + values.set(lhs, values.get(rhs)); + values.set(rhs, tmp); + } + + /** + * Allocate storage for more buckets and store the "next gather offset" + * for those new buckets. + */ + private void grow(long requiredSize) { + long oldMax = values.size(); + values = common.bigArrays.grow(values, requiredSize); + // Set the next gather offsets for all newly allocated buckets. + fillGatherOffsets(oldMax - (oldMax % common.bucketSize)); + } + + /** + * Maintain the "next gather offsets" for newly allocated buckets. + */ + private void fillGatherOffsets(long startingAt) { + assert startingAt % common.bucketSize == 0; + int nextOffset = common.bucketSize - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size(); bucketRoot += common.bucketSize) { + BreakingBytesRefBuilder bytes = values.get(bucketRoot); + if (bytes != null) { + continue; + } + bytes = new BreakingBytesRefBuilder(breaker, label); + values.set(bucketRoot, bytes); + bytes.grow(Integer.BYTES); + bytes.setLength(Integer.BYTES); + ByteUtils.writeIntLE(nextOffset, bytes.bytes(), 0); + } + } + + /** + * Heapify a bucket whose entries are in random order. + *
<p>
+ * This works by validating the heap property on each node, iterating + * "upwards", pushing any out of order parents "down". Check out the + * wikipedia + * entry on binary heaps for more about this. + *
</p>
+ *
<p>
+ * While this *looks* like it could easily be {@code O(n * log n)}, it is + * a fairly well studied algorithm attributed to Floyd. There's + * been a bunch of work that puts this at {@code O(n)}, close to 1.88n worst + * case. + *
</p>
+ * + * @param rootIndex the index the start of the bucket + */ + private void heapify(long rootIndex) { + int maxParent = common.bucketSize / 2 - 1; + for (int parent = maxParent; parent >= 0; parent--) { + downHeap(rootIndex, parent); + } + } + + /** + * Correct the heap invariant of a parent and its children. This + * runs in {@code O(log n)} time. + * @param rootIndex index of the start of the bucket + * @param parent Index within the bucket of the parent to check. + * For example, 0 is the "root". + */ + private void downHeap(long rootIndex, int parent) { + while (true) { + long parentIndex = rootIndex + parent; + int worst = parent; + long worstIndex = parentIndex; + int leftChild = parent * 2 + 1; + long leftIndex = rootIndex + leftChild; + if (leftChild < common.bucketSize) { + if (betterThan(values.get(worstIndex).bytesRefView(), values.get(leftIndex).bytesRefView())) { + worst = leftChild; + worstIndex = leftIndex; + } + int rightChild = leftChild + 1; + long rightIndex = rootIndex + rightChild; + if (rightChild < common.bucketSize + && betterThan(values.get(worstIndex).bytesRefView(), values.get(rightIndex).bytesRefView())) { + + worst = rightChild; + worstIndex = rightIndex; + } + } + if (worst == parent) { + break; + } + swap(worstIndex, parentIndex); + parent = worst; + } + } + + private BreakingBytesRefBuilder clearedBytesAt(long index) { + BreakingBytesRefBuilder bytes = values.get(index); + if (bytes == null) { + bytes = new BreakingBytesRefBuilder(breaker, label); + values.set(index, bytes); + } else { + bytes.clear(); + } + return bytes; + } + + @Override + public final void close() { + Releasable allValues = values == null ? () -> {} : Releasables.wrap(LongStream.range(0, values.size()).mapToObj(i -> { + BreakingBytesRefBuilder bytes = values.get(i); + return bytes == null ? (Releasable) () -> {} : bytes; + }).toList().iterator()); + Releasables.close(allValues, values, common); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/IpBucketedSort.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/IpBucketedSort.java index 0fd38c18d7504..4eb31ea30db22 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/IpBucketedSort.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/IpBucketedSort.java @@ -9,7 +9,6 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.ByteArray; import org.elasticsearch.common.util.ByteUtils; import org.elasticsearch.compute.data.Block; @@ -29,7 +28,7 @@ * See {@link BucketedSort} for more information. */ public class IpBucketedSort implements Releasable { - private static final int IP_LENGTH = 16; + private static final int IP_LENGTH = 16; // Bytes. It's ipv6. // BytesRefs used in internal methods private final BytesRef scratch1 = new BytesRef(); @@ -39,18 +38,11 @@ public class IpBucketedSort implements Releasable { */ private final byte[] scratchBytes = new byte[IP_LENGTH]; - private final BigArrays bigArrays; - private final SortOrder order; - private final int bucketSize; - /** - * {@code true} if the bucket is in heap mode, {@code false} if - * it is still gathering. - */ - private final BitArray heapMode; + private final BucketedSortCommon common; /** * An array containing all the values on all buckets. The structure is as follows: *
<p>
- * For each bucket, there are bucketSize elements, based on the bucket id (0, 1, 2...). + * For each bucket, there are {@link BucketedSortCommon#bucketSize} elements, based on the bucket id (0, 1, 2...). * Then, for each bucket, it can be in 2 states: *
</p>
* <ul>
    @@ -77,10 +69,7 @@ public class IpBucketedSort implements Releasable { private ByteArray values; public IpBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { - this.bigArrays = bigArrays; - this.order = order; - this.bucketSize = bucketSize; - heapMode = new BitArray(0, bigArrays); + this.common = new BucketedSortCommon(bigArrays, order, bucketSize); boolean success = false; try { @@ -101,8 +90,8 @@ public IpBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) { */ public void collect(BytesRef value, int bucket) { assert value.length == IP_LENGTH; - long rootIndex = (long) bucket * bucketSize; - if (inHeapMode(bucket)) { + long rootIndex = common.rootIndex(bucket); + if (common.inHeapMode(bucket)) { if (betterThan(value, get(rootIndex, scratch1))) { set(rootIndex, value); downHeap(rootIndex, 0); @@ -110,49 +99,34 @@ public void collect(BytesRef value, int bucket) { return; } // Gathering mode - long requiredSize = (rootIndex + bucketSize) * IP_LENGTH; + long requiredSize = common.endIndex(rootIndex) * IP_LENGTH; if (values.size() < requiredSize) { grow(requiredSize); } int next = getNextGatherOffset(rootIndex); - assert 0 <= next && next < bucketSize - : "Expected next to be in the range of valid buckets [0 <= " + next + " < " + bucketSize + "]"; + common.assertValidNextOffset(next); long index = next + rootIndex; set(index, value); if (next == 0) { - heapMode.set(bucket); + common.enableHeapMode(bucket); heapify(rootIndex); } else { setNextGatherOffset(rootIndex, next - 1); } } - /** - * The order of the sort. - */ - public SortOrder getOrder() { - return order; - } - - /** - * The number of values to store per bucket. - */ - public int getBucketSize() { - return bucketSize; - } - /** * Get the first and last indexes (inclusive, exclusive) of the values for a bucket. * Returns [0, 0] if the bucket has never been collected. */ private Tuple getBucketValuesIndexes(int bucket) { - long rootIndex = (long) bucket * bucketSize; + long rootIndex = common.rootIndex(bucket); if (rootIndex >= values.size() / IP_LENGTH) { // We've never seen this bucket. return Tuple.tuple(0L, 0L); } - long start = inHeapMode(bucket) ? rootIndex : (rootIndex + getNextGatherOffset(rootIndex) + 1); - long end = rootIndex + bucketSize; + long start = startIndex(bucket, rootIndex); + long end = common.endIndex(rootIndex); return Tuple.tuple(start, end); } @@ -184,7 +158,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } // Used to sort the values in the bucket. - var bucketValues = new BytesRef[bucketSize]; + var bucketValues = new BytesRef[common.bucketSize]; try (var builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { for (int s = 0; s < selected.getPositionCount(); s++) { @@ -211,7 +185,7 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { Arrays.sort(bucketValues, 0, (int) size); builder.beginPositionEntry(); - if (order == SortOrder.ASC) { + if (common.order == SortOrder.ASC) { for (int i = 0; i < size; i++) { builder.appendBytesRef(bucketValues[i]); } @@ -226,11 +200,11 @@ public Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - /** - * Is this bucket a min heap {@code true} or in gathering mode {@code false}? 
- */ - private boolean inHeapMode(int bucket) { - return heapMode.get(bucket); + private long startIndex(int bucket, long rootIndex) { + if (common.inHeapMode(bucket)) { + return rootIndex; + } + return rootIndex + getNextGatherOffset(rootIndex) + 1; } /** @@ -267,7 +241,7 @@ private void setNextGatherOffset(long rootIndex, int offset) { * {@link SortOrder#ASC} and "higher" for {@link SortOrder#DESC}. */ private boolean betterThan(BytesRef lhs, BytesRef rhs) { - return getOrder().reverseMul() * lhs.compareTo(rhs) < 0; + return common.order.reverseMul() * lhs.compareTo(rhs) < 0; } /** @@ -296,17 +270,17 @@ private void swap(long lhs, long rhs) { */ private void grow(long minSize) { long oldMax = values.size() / IP_LENGTH; - values = bigArrays.grow(values, minSize); + values = common.bigArrays.grow(values, minSize); // Set the next gather offsets for all newly allocated buckets. - setNextGatherOffsets(oldMax - (oldMax % bucketSize)); + setNextGatherOffsets(oldMax - (oldMax % common.bucketSize)); } /** * Maintain the "next gather offsets" for newly allocated buckets. */ private void setNextGatherOffsets(long startingAt) { - int nextOffset = bucketSize - 1; - for (long bucketRoot = startingAt; bucketRoot < values.size() / IP_LENGTH; bucketRoot += bucketSize) { + int nextOffset = common.bucketSize - 1; + for (long bucketRoot = startingAt; bucketRoot < values.size() / IP_LENGTH; bucketRoot += common.bucketSize) { setNextGatherOffset(bucketRoot, nextOffset); } } @@ -334,7 +308,7 @@ private void setNextGatherOffsets(long startingAt) { * @param rootIndex the index the start of the bucket */ private void heapify(long rootIndex) { - int maxParent = bucketSize / 2 - 1; + int maxParent = common.bucketSize / 2 - 1; for (int parent = maxParent; parent >= 0; parent--) { downHeap(rootIndex, parent); } @@ -354,14 +328,14 @@ private void downHeap(long rootIndex, int parent) { long worstIndex = parentIndex; int leftChild = parent * 2 + 1; long leftIndex = rootIndex + leftChild; - if (leftChild < bucketSize) { + if (leftChild < common.bucketSize) { if (betterThan(get(worstIndex, scratch1), get(leftIndex, scratch2))) { worst = leftChild; worstIndex = leftIndex; } int rightChild = leftChild + 1; long rightIndex = rootIndex + rightChild; - if (rightChild < bucketSize && betterThan(get(worstIndex, scratch1), get(rightIndex, scratch2))) { + if (rightChild < common.bucketSize && betterThan(get(worstIndex, scratch1), get(rightIndex, scratch2))) { worst = rightChild; worstIndex = rightIndex; } @@ -400,6 +374,6 @@ private void set(long index, BytesRef value) { @Override public final void close() { - Releasables.close(values, heapMode); + Releasables.close(values, common); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AbstractTopBytesRefAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AbstractTopBytesRefAggregatorFunctionTests.java new file mode 100644 index 0000000000000..2815dd70e8124 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AbstractTopBytesRefAggregatorFunctionTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.contains; + +abstract class AbstractTopBytesRefAggregatorFunctionTests extends AggregatorFunctionTestCase { + static final int LIMIT = 100; + + @Override + protected final SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceBytesRefBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToObj(l -> randomValue())); + } + + protected abstract BytesRef randomValue(); + + @Override + public final void assertSimpleOutput(List input, Block result) { + Object[] values = input.stream().flatMap(AggregatorFunctionTestCase::allBytesRefs).sorted().limit(LIMIT).toArray(Object[]::new); + assertThat((List) BlockUtils.toJavaObject(result, 0), contains(values)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AbstractTopBytesRefGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AbstractTopBytesRefGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..45c8a23dfc1c0 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AbstractTopBytesRefGroupingAggregatorFunctionTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.LongBytesRefTupleBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; + +public abstract class AbstractTopBytesRefGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + static final int LIMIT = 100; + + @Override + protected final SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new LongBytesRefTupleBlockSourceOperator( + blockFactory, + IntStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomValue())) + ); + } + + protected abstract BytesRef randomValue(); + + @Override + protected final void assertSimpleGroup(List input, Block result, int position, Long group) { + Object[] values = input.stream().flatMap(b -> allBytesRefs(b, group)).sorted().limit(LIMIT).toArray(Object[]::new); + if (values.length == 0) { + assertThat(result.isNull(position), equalTo(true)); + } else if (values.length == 1) { + assertThat(BlockUtils.toJavaObject(result, position), equalTo(values[0])); + } else { + assertThat((List) BlockUtils.toJavaObject(result, position), contains(values)); + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunctionTests.java new file mode 100644 index 0000000000000..732229c98f9c7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBytesRefAggregatorFunctionTests.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; + +import java.util.List; + +public class TopBytesRefAggregatorFunctionTests extends AbstractTopBytesRefAggregatorFunctionTests { + @Override + protected BytesRef randomValue() { + return new BytesRef(randomAlphaOfLength(10)); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new TopBytesRefAggregatorFunctionSupplier(inputChannels, LIMIT, true); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "top of bytes"; + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..4932e1abef46d --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBytesRefGroupingAggregatorFunctionTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.esql.core.type.DataType; + +import java.util.List; + +public class TopBytesRefGroupingAggregatorFunctionTests extends AbstractTopBytesRefGroupingAggregatorFunctionTests { + @Override + protected BytesRef randomValue() { + return new BytesRef(randomAlphaOfLength(6)); + } + + @Override + protected final AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new TopBytesRefAggregatorFunctionSupplier(inputChannels, LIMIT, true); + } + + @Override + protected DataType acceptedDataType() { + return DataType.KEYWORD; + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "top of bytes"; + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java index 1594f66ed9fe2..840e4cf9af961 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java @@ -9,26 +9,13 @@ import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.BlockUtils; -import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator; -import org.elasticsearch.compute.operator.SourceOperator; import java.util.List; -import java.util.stream.IntStream; - -import static org.hamcrest.Matchers.contains; - -public class TopIpAggregatorFunctionTests extends AggregatorFunctionTestCase { - private static final int LIMIT = 100; +public class TopIpAggregatorFunctionTests extends AbstractTopBytesRefAggregatorFunctionTests { @Override - protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { - return new SequenceBytesRefBlockSourceOperator( - blockFactory, - IntStream.range(0, size).mapToObj(l -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean())))) - ); + protected BytesRef randomValue() { + return new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))); } @Override @@ -40,10 +27,4 @@ protected AggregatorFunctionSupplier aggregatorFunction(List inputChann protected String expectedDescriptionOfAggregator() { return "top of ips"; } - - @Override - public void assertSimpleOutput(List input, Block result) { - Object[] values = input.stream().flatMap(b -> allBytesRefs(b)).sorted().limit(LIMIT).toArray(Object[]::new); - assertThat((List) BlockUtils.toJavaObject(result, 0), contains(values)); - } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java index da55ff2d7aab3..02bf6b667192b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java +++ 
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java
index 1594f66ed9fe2..840e4cf9af961 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpAggregatorFunctionTests.java
@@ -9,26 +9,13 @@
 import org.apache.lucene.document.InetAddressPoint;
 import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.compute.data.Block;
-import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BlockUtils;
-import org.elasticsearch.compute.operator.SequenceBytesRefBlockSourceOperator;
-import org.elasticsearch.compute.operator.SourceOperator;
 
 import java.util.List;
-import java.util.stream.IntStream;
-
-import static org.hamcrest.Matchers.contains;
-
-public class TopIpAggregatorFunctionTests extends AggregatorFunctionTestCase {
-    private static final int LIMIT = 100;
-
+public class TopIpAggregatorFunctionTests extends AbstractTopBytesRefAggregatorFunctionTests {
     @Override
-    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
-        return new SequenceBytesRefBlockSourceOperator(
-            blockFactory,
-            IntStream.range(0, size).mapToObj(l -> new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean()))))
-        );
+    protected BytesRef randomValue() {
+        return new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean())));
     }
 
     @Override
@@ -40,10 +27,4 @@ protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChann
     protected String expectedDescriptionOfAggregator() {
         return "top of ips";
     }
-
-    @Override
-    public void assertSimpleOutput(List<Block> input, Block result) {
-        Object[] values = input.stream().flatMap(b -> allBytesRefs(b)).sorted().limit(LIMIT).toArray(Object[]::new);
-        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
-    }
 }
diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java
index da55ff2d7aab3..02bf6b667192b 100644
--- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java
+++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopIpGroupingAggregatorFunctionTests.java
@@ -9,36 +9,14 @@
 import org.apache.lucene.document.InetAddressPoint;
 import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.compute.data.Block;
-import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BlockUtils;
-import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.operator.LongBytesRefTupleBlockSourceOperator;
-import org.elasticsearch.compute.operator.SourceOperator;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xpack.esql.core.type.DataType;
 
 import java.util.List;
-import java.util.stream.IntStream;
-
-import static org.hamcrest.Matchers.contains;
-import static org.hamcrest.Matchers.equalTo;
-
-public class TopIpGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase {
-    private static final int LIMIT = 100;
-
+public class TopIpGroupingAggregatorFunctionTests extends AbstractTopBytesRefGroupingAggregatorFunctionTests {
     @Override
-    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
-        return new LongBytesRefTupleBlockSourceOperator(
-            blockFactory,
-            IntStream.range(0, size)
-                .mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean())))))
-        );
-    }
-
-    @Override
-    protected DataType acceptedDataType() {
-        return DataType.IP;
+    protected BytesRef randomValue() {
+        return new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean())));
     }
 
     @Override
@@ -47,19 +25,12 @@ protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChann
     }
 
     @Override
-    protected String expectedDescriptionOfAggregator() {
-        return "top of ips";
+    protected DataType acceptedDataType() {
+        return DataType.IP;
     }
 
     @Override
-    protected void assertSimpleGroup(List<Page> input, Block result, int position, Long group) {
-        Object[] values = input.stream().flatMap(b -> allBytesRefs(b, group)).sorted().limit(LIMIT).toArray(Object[]::new);
-        if (values.length == 0) {
-            assertThat(result.isNull(position), equalTo(true));
-        } else if (values.length == 1) {
-            assertThat(BlockUtils.toJavaObject(result, position), equalTo(values[0]));
-        } else {
-            assertThat((List<?>) BlockUtils.toJavaObject(result, position), contains(values));
-        }
+    protected String expectedDescriptionOfAggregator() {
+        return "top of ips";
     }
 }
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.apache.lucene.document.InetAddressPoint;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class BytesRefBucketedSortTests extends BucketedSortTestCase<BytesRefBucketedSort, BytesRef> {
+    @Override
+    protected BytesRefBucketedSort build(SortOrder sortOrder, int bucketSize) {
+        BigArrays bigArrays = bigArrays();
+        return new BytesRefBucketedSort(
+            bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST),
+            "test",
+            bigArrays,
+            sortOrder,
+            bucketSize
+        );
+    }
+
+    @Override
+    protected BytesRef randomValue() {
+        return new BytesRef(InetAddressPoint.encode(randomIp(randomBoolean())));
+    }
+
+    @Override
+    protected List<BytesRef> threeSortedValues() {
+        List<BytesRef> values = new ArrayList<>();
+        values.add(new BytesRef(randomAlphaOfLength(10)));
+        values.add(new BytesRef(randomAlphaOfLength(11)));
+        values.add(new BytesRef(randomAlphaOfLength(1)));
+        Collections.sort(values);
+        return values;
+    }
+
+    @Override
+    protected void collect(BytesRefBucketedSort sort, BytesRef value, int bucket) {
+        sort.collect(value, bucket);
+    }
+
+    @Override
+    protected void merge(BytesRefBucketedSort sort, int groupId, BytesRefBucketedSort other, int otherGroupId) {
+        sort.merge(groupId, other, otherGroupId);
+    }
+
+    @Override
+    protected Block toBlock(BytesRefBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
+        return sort.toBlock(blockFactory, selected);
+    }
+
+    @Override
+    protected void assertBlockTypeAndValues(Block block, List<BytesRef> values) {
+        assertThat(block.elementType(), equalTo(ElementType.BYTES_REF));
+        var typedBlock = (BytesRefBlock) block;
+        var scratch = new BytesRef();
+        for (int i = 0; i < values.size(); i++) {
+            assertThat("expected value on block position " + i, typedBlock.getBytesRef(i, scratch), equalTo(values.get(i)));
+        }
+    }
+}
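This test pins the whole surface that the new aggregator relies on: the `(breaker, label, bigArrays, order, bucketSize)` constructor plus `collect`, `merge`, and `toBlock`. A minimal sketch of that lifecycle, under two stated assumptions that are not from this PR: a no-op `NoneCircuitBreaker` stands in for a real breaker, and the sort is `Releasable` so try-with-resources applies (block conversion stays a comment because it needs a `BlockFactory` and an `IntVector` of selected buckets):

    import org.apache.lucene.util.BytesRef;
    import org.elasticsearch.common.breaker.CircuitBreaker;
    import org.elasticsearch.common.breaker.NoneCircuitBreaker;
    import org.elasticsearch.common.util.BigArrays;
    import org.elasticsearch.compute.data.sort.BytesRefBucketedSort;
    import org.elasticsearch.search.sort.SortOrder;

    CircuitBreaker breaker = new NoneCircuitBreaker("example"); // assumption: no-op breaker for the sketch
    try (var sort = new BytesRefBucketedSort(breaker, "example", BigArrays.NON_RECYCLING_INSTANCE, SortOrder.ASC, 2)) {
        sort.collect(new BytesRef("b"), 0); // bucket 0
        sort.collect(new BytesRef("a"), 0);
        sort.collect(new BytesRef("c"), 0); // bucket size is 2, so only [a, b] survive
        // sort.toBlock(blockFactory, selected) would then emit [a, b] for bucket 0.
    }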
diff --git a/x-pack/plugin/esql/qa/server/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/mixed/MixedClusterEsqlSpecIT.java b/x-pack/plugin/esql/qa/server/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/mixed/MixedClusterEsqlSpecIT.java
index d0d6d5fa49c42..08b4794b740d6 100644
--- a/x-pack/plugin/esql/qa/server/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/mixed/MixedClusterEsqlSpecIT.java
+++ b/x-pack/plugin/esql/qa/server/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/mixed/MixedClusterEsqlSpecIT.java
@@ -72,6 +72,10 @@ public MixedClusterEsqlSpecIT(
     protected void shouldSkipTest(String testName) throws IOException {
         super.shouldSkipTest(testName);
         assumeTrue("Test " + testName + " is skipped on " + bwcVersion, isEnabled(testName, instructions, bwcVersion));
+        assumeFalse(
+            "Skip META tests on mixed version clusters because we change it too quickly",
+            testCase.requiredCapabilities.contains("meta")
+        );
         if (mode == ASYNC) {
             assumeTrue("Async is not supported on " + bwcVersion, supportsAsync());
         }
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
index 3e799730f7269..8d54dc63598f0 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
@@ -112,6 +112,10 @@ protected void shouldSkipTest(String testName) throws IOException {
         );
         assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains("inlinestats"));
         assumeFalse("INLINESTATS not yet supported in CCS", testCase.requiredCapabilities.contains("inlinestats_v2"));
+        assumeFalse(
+            "Skip META tests on mixed version clusters because we change it too quickly",
+            testCase.requiredCapabilities.contains("meta")
+        );
     }
 
     private TestFeatureService remoteFeaturesService() throws IOException {
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
index 6909f0aeb42f5..2b3fa9dec797d 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
@@ -1,5 +1,7 @@
-metaFunctionsSynopsis#[skip:-8.15.99]
+metaFunctionsSynopsis
 required_capability: date_nanos_type
+required_capability: meta
+
 meta functions | keep synopsis;
 
 synopsis:keyword
@@ -118,14 +120,16 @@ double tau()
 "keyword|text to_upper(str:keyword|text)"
 "version to_ver(field:keyword|text|version)"
 "version to_version(field:keyword|text|version)"
-"boolean|double|integer|long|date|ip top(field:boolean|double|integer|long|date|ip, limit:integer, order:keyword)"
+"boolean|double|integer|long|date|ip|keyword|text top(field:boolean|double|integer|long|date|ip|keyword|text, limit:integer, order:keyword)"
 "keyword|text trim(string:keyword|text)"
 "boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)"
 "double weighted_avg(number:double|integer|long, weight:double|integer|long)"
 ;
 
-metaFunctionsArgs#[skip:-8.15.99]
+metaFunctionsArgs
+required_capability: meta
 required_capability: date_nanos_type
+
 META functions
 | EVAL name = SUBSTRING(name, 0, 14)
 | KEEP name, argNames, argTypes, argDescriptions;
@@ -246,13 +250,15 @@ to_unsigned_lo|field |"boolean|date|keyword|text|d
 to_upper |str |"keyword|text" |String expression. If `null`, the function returns `null`.
 to_ver |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
 to_version |field |"keyword|text|version" |Input value. The input can be a single- or multi-valued column or an expression.
-top |[field, limit, order] |["boolean|double|integer|long|date|ip", integer, keyword] |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
+top |[field, limit, order] |["boolean|double|integer|long|date|ip|keyword|text", integer, keyword] |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
 trim |string |"keyword|text" |String expression. If `null`, the function returns `null`.
 values |field |"boolean|date|double|integer|ip|keyword|long|text|version" |[""]
 weighted_avg |[number, weight] |["double|integer|long", "double|integer|long"] |[A numeric value., A numeric weight.]
 ;
 
-metaFunctionsDescription#[skip:-8.15.99]
+metaFunctionsDescription
+required_capability: meta
+
 META functions
 | EVAL name = SUBSTRING(name, 0, 14)
 | KEEP name, description
@@ -380,8 +386,10 @@ values |Returns all values in a group as a multivalued field. The order o
 weighted_avg |The weighted average of a numeric expression.
 ;
 
-metaFunctionsRemaining#[skip:-8.15.99]
+metaFunctionsRemaining
+required_capability: meta
 required_capability: date_nanos_type
+
 META functions
 | EVAL name = SUBSTRING(name, 0, 14)
 | KEEP name, *
@@ -504,13 +512,15 @@ to_unsigned_lo|unsigned_long
 to_upper |"keyword|text" |false |false |false
 to_ver |version |false |false |false
 to_version |version |false |false |false
-top |"boolean|double|integer|long|date|ip" |[false, false, false] |false |true
+top |"boolean|double|integer|long|date|ip|keyword|text" |[false, false, false] |false |true
 trim |"keyword|text" |false |false |false
 values |"boolean|date|double|integer|ip|keyword|long|text|version" |false |false |true
 weighted_avg |"double" |[false, false] |false |true
 ;
 
-metaFunctionsFiltered#[skip:-8.15.99]
+metaFunctionsFiltered
+required_capability: meta
+
 META FUNCTIONS
 | WHERE STARTS_WITH(name, "sin")
 ;
@@ -520,7 +530,9 @@ sin |"double sin(angle:double|integer|long|unsigned_long)" |angle
 sinh |"double sinh(number:double|integer|long|unsigned_long)" |number |"double|integer|long|unsigned_long" | "Numeric expression. If `null`, the function returns `null`." | double | "Returns the {wikipedia}/Hyperbolic_functions[hyperbolic sine] of a number." | false | false | false
 ;
 
-countFunctions#[skip:-8.15.99]
+countFunctions
+required_capability: meta
+
 meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c;
 
 a:long | b:long | c:long
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec
index 86f91adf506d1..80d11425c5bb6 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec
@@ -224,3 +224,77 @@ a:ip | b:ip | c:ip | host:keyword
 [fe82::cae2:65ff:fece:fec0, fe81::cae2:65ff:fece:feb9] | [fe82::cae2:65ff:fece:fec0, fe81::cae2:65ff:fece:feb9] | [fe82::cae2:65ff:fece:fec0, fe81::cae2:65ff:fece:feb9] | epsilon
 [fe80::cae2:65ff:fece:feb9, fe80::cae2:65ff:fece:feb9] | [fe80::cae2:65ff:fece:feb9, fe80::cae2:65ff:fece:feb9] | [fe81::cae2:65ff:fece:feb9, 127.0.0.3] | gamma
 ;
+
+topKeywords
+required_capability: agg_top
+required_capability: agg_top_string_support
+
+FROM employees
+| EVAL calc = SUBSTRING(last_name, 2)
+| STATS
+    first_name = TOP(first_name, 3, "asc"),
+    last_name = TOP(calc, 3, "asc"),
+    evil = TOP(CASE(languages <= 2, first_name, last_name), 3, "desc");
+
+ first_name:keyword | last_name:keyword | evil:keyword
+[Alejandro, Amabile, Anneke] | [acello, addadi, aek] | [Zschoche, Zielinski, Zhongwei]
+;
+
+topKeywordsGrouping
+required_capability: agg_top
+required_capability: agg_top_string_support
+
+FROM employees
+| EVAL calc = SUBSTRING(last_name, 2)
+| STATS
+    first_name = TOP(first_name, 3, "asc"),
+    last_name = TOP(calc, 3, "asc"),
+    evil = TOP(CASE(languages <= 2, first_name, last_name), 3, "desc")
+  BY job_positions
+| SORT job_positions
+| LIMIT 3;
+
+ first_name:keyword | last_name:keyword | evil:keyword | job_positions:keyword
+ [Arumugam, Bojan, Domenick] | [acello, aine, akrucki] | [Zhongwei, Yinghua, Valdiodio] | Accountant
+[Alejandro, Charlene, Danel] | [andell, cAlpine, eistad] | [Stamatiou, Sluis, Sidou] | Architect
+ [Basil, Breannda, Hidefumi] | [aine, alabarba, ierman] | [Tramer, Syrzycki, Stamatiou] | Business Analyst
+;
+
+topText
+required_capability: agg_top
+required_capability: agg_top_string_support
+# we don't need MATCH, but the loader for books.csv is busted in CsvTests
+required_capability: match_operator
+
+FROM books
+| EVAL calc = TRIM(SUBSTRING(title, 2, 5))
+| STATS
+    title = TOP(title, 3, "desc"),
+    calc = TOP(calc, 3, "asc"),
+    evil = TOP(CASE(year < 1980, title, author), 3, "desc");
+
+title:text | calc:keyword | evil:text
+[Worlds of Exile and Illusion: Three Complete Novels of the Hainish Series in One Volume--Rocannon's World, Planet of Exile, City of Illusions, Woman-The Full Story: A Dynamic Celebration of Freedoms, Winter notes on summer impressions] | ["'Bria", "Gent", "HE UN"] | [William Faulkner, William Faulkner, William Faulkner]
+;
+
+topTextGrouping
+required_capability: agg_top
+required_capability: agg_top_string_support
+# we don't need MATCH, but the loader for books.csv is busted in CsvTests
+required_capability: match_operator
+
+FROM books
+| EVAL calc = TRIM(SUBSTRING(title, 2, 5))
+| STATS
+    title = TOP(title, 3, "desc"),
+    calc = TOP(calc, 3, "asc"),
+    evil = TOP(CASE(year < 1980, title, author), 3, "desc")
+  BY author
+| SORT author
+| LIMIT 3;
+
+ title:text | calc:keyword | evil:text | author:text
+ A Tolkien Compass: Including J. R. R. Tolkien's Guide to the Names in The Lord of the Rings | Tolk | A Tolkien Compass: Including J. R. R. Tolkien's Guide to the Names in The Lord of the Rings | Agnes Perkins
+ The Lord of the Rings Poster Collection: Six Paintings by Alan Lee (No. 1) | he Lo | [J. R. R. Tolkien, Alan Lee] | Alan Lee
+A Gentle Creature and Other Stories: White Nights, A Gentle Creature, and The Dream of a Ridiculous Man (The World's Classics) | Gent | [W. J. Leatherbarrow, Fyodor Dostoevsky, Alan Myers] | Alan Myers
+;
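A note on the expected rows above: TOP on `keyword`/`text` orders by the underlying `BytesRef`, that is, by unsigned UTF-8 bytes, which for the ASCII values in these fixtures coincides with alphabetical order (hence `Zschoche` leading the descending column). A quick check in plain Lucene, illustrative only:

    import org.apache.lucene.util.BytesRef;

    // 'l' (0x6C) < 'm' (0x6D): Alejandro sorts before Amabile ascending.
    assert new BytesRef("Alejandro").compareTo(new BytesRef("Amabile")) < 0;
    // 's' (0x73) > 'i' (0x69): Zschoche outranks Zielinski descending.
    assert new BytesRef("Zschoche").compareTo(new BytesRef("Zielinski")) > 0;

Byte order is case sensitive, though: every uppercase ASCII letter sorts before any lowercase one, which is worth remembering when reading mixed-case results.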
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 597c349273eb2..31a3096c13cd2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -97,6 +97,11 @@ public enum Cap {
      */
     AGG_TOP_IP_SUPPORT,
 
+    /**
+     * Support for {@code keyword} and {@code text} fields in {@code TOP} aggregation.
+     */
+    AGG_TOP_STRING_SUPPORT,
+
     /**
      * {@code CASE} properly handling multivalue conditions.
      */
@@ -251,6 +256,13 @@
      */
     MATCH_OPERATOR(true),
 
+    /**
+     * Support for the {@code META} keyword. Tests with this tag are
+     * intentionally excluded from mixed version clusters because we
+     * continually add functions, so they constantly fail if we don't.
+     */
+    META,
+
     /**
      * Add CombineBinaryComparisons rule.
      */
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
index 4927acc3e1cd9..cb1b0f0cad895 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
@@ -13,6 +13,7 @@
 import org.elasticsearch.common.lucene.BytesRefs;
 import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopBooleanAggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.TopBytesRefAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopIntAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopIpAggregatorFunctionSupplier;
@@ -48,7 +49,7 @@ public class Top extends AggregateFunction implements ToAggregator, SurrogateExp
     private static final String ORDER_DESC = "DESC";
 
     @FunctionInfo(
-        returnType = { "boolean", "double", "integer", "long", "date", "ip" },
+        returnType = { "boolean", "double", "integer", "long", "date", "ip", "keyword", "text" },
         description = "Collects the top values for a field. Includes repeated values.",
         isAggregation = true,
         examples = @Example(file = "stats_top", tag = "top")
@@ -57,7 +58,7 @@ public Top(
         Source source,
         @Param(
             name = "field",
-            type = { "boolean", "double", "integer", "long", "date", "ip" },
+            type = { "boolean", "double", "integer", "long", "date", "ip", "keyword", "text" },
             description = "The field to collect the top values for."
         ) Expression field,
         @Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit,
@@ -125,12 +126,14 @@ protected TypeResolution resolveType() {
             dt -> dt == DataType.BOOLEAN
                 || dt == DataType.DATETIME
                 || dt == DataType.IP
+                || DataType.isString(dt)
                 || (dt.isNumeric() && dt != DataType.UNSIGNED_LONG),
             sourceText(),
             FIRST,
             "boolean",
             "date",
             "ip",
+            "string",
             "numeric except unsigned_long or counter types"
         ).and(isNotNullAndFoldable(limitField(), sourceText(), SECOND))
             .and(isType(limitField(), dt -> dt == DataType.INTEGER, sourceText(), SECOND, "integer"))
@@ -190,6 +193,9 @@ public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
         if (type == DataType.IP) {
             return new TopIpAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
         }
+        if (DataType.isString(type)) {
+            return new TopBytesRefAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
+        }
         throw EsqlIllegalArgumentException.illegalDataType(type);
     }
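The generated suppliers take the order as a boolean whose meaning the tests earlier fix as `true == ascending`. `Top.orderValue()` itself is not shown in this diff, so the folding of the `"asc"`/`"desc"` literal presumably amounts to something like this hypothetical helper:

    // Hypothetical sketch, not the PR's code: fold the order literal to the
    // boolean the generated suppliers expect (true means ascending).
    static boolean ascending(String order) {
        return "desc".equalsIgnoreCase(order) == false;
    }

The `ORDER_DESC = "DESC"` constant visible above suggests the real comparison is case-insensitive as well.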
types = List.of("Int", "Long", "Double", "Boolean", "BytesRef"); } else if (Top.class.isAssignableFrom(clazz)) { - types = List.of("Boolean", "Int", "Long", "Double", "Ip"); + types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef"); } else if (Rate.class.isAssignableFrom(clazz)) { types = List.of("Int", "Long", "Double"); } else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java index f64d6a200a031..f7bf338caa099 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java @@ -46,7 +46,9 @@ public static Iterable parameters() { MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), MultiRowTestCaseSupplier.dateCases(1, 1000), MultiRowTestCaseSupplier.booleanCases(1, 1000), - MultiRowTestCaseSupplier.ipCases(1, 1000) + MultiRowTestCaseSupplier.ipCases(1, 1000), + MultiRowTestCaseSupplier.stringCases(1, 1000, DataType.KEYWORD), + MultiRowTestCaseSupplier.stringCases(1, 1000, DataType.TEXT) ) .flatMap(List::stream) .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order))