diff --git a/docs/changelog/130510.yaml b/docs/changelog/130510.yaml new file mode 100644 index 0000000000000..01426b6b8e4e3 --- /dev/null +++ b/docs/changelog/130510.yaml @@ -0,0 +1,5 @@ +pr: 130510 +summary: Add fast path for single value in VALUES aggregator +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 79077b6628105..120a77fda92c6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -7,11 +7,15 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; @@ -19,13 +23,16 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for BytesRef. @@ -129,47 +136,146 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. 
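// Illustrative sketch, not part of this patch: how a (groupId, value) pair can be packed into the
// single long key that NextValues stores in its LongHash. PackedPair is a hypothetical helper; only
// the bit layout mirrors addValue/getValue below - top 32 bits for the group, bottom 32 for the value.
final class PackedPair {
    static long pack(int groupId, int value) {
        return (((long) groupId) << Integer.SIZE) | (value & 0xFFFFFFFFL);
    }

    static int group(long packed) {
        return (int) (packed >>> Integer.SIZE);
    }

    static int value(long packed) {
        return (int) (packed & 0xFFFFFFFFL);
    }

    public static void main(String[] args) {
        long packed = pack(7, -42);
        System.out.println(group(packed) + " / " + value(packed)); // prints "7 / -42"
    }
}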
*/ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + } + + int getValue(int index) { + long both = hashes.get(ids[index]); + return (int) (both & 0xFFFFFFFFL); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. 
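// Standalone sketch of the remap described above, not part of this change: plain arrays stand in for
// the hash, groupOf[id] plays the role of "the group of hash entry id", and the selected groups and
// per-group value counts match the worked example (group 2 holds values but is not selected).
final class SelectedCountsDemo {
    public static void main(String[] args) {
        int[] groupOf = { 0, 0, 0, 1, 2, 2, 3, 4, 4, 4, 4 }; // 11 entries, counts 3, 1, 2, 1, 4
        int[] selected = { 0, 1, 3, 4 };

        // Phase 1: count each group's entries *downwards* so unselected groups stay negative.
        int[] selectedCounts = new int[5];
        for (int id = 0; id < groupOf.length; id++) {
            selectedCounts[groupOf[id]]--;
        }

        // Phase 2: flip the sign for selected groups, replacing each count with its running start offset.
        int total = 0;
        for (int group : selected) {
            int count = -selectedCounts[group];
            selectedCounts[group] = total;
            total += count;
        }
        System.out.println(java.util.Arrays.toString(selectedCounts)); // [0, 3, -2, 4, 5]

        // Phase 3: scatter the ids; afterwards selectedCounts[group] is the exclusive end offset.
        // The guard skips the still-negative, unselected group 2 in this toy example.
        int[] ids = new int[total];
        for (int id = 0; id < groupOf.length; id++) {
            int group = groupOf[id];
            if (selectedCounts[group] >= 0) {
                ids[selectedCounts[group]++] = id;
            }
        }
        System.out.println(java.util.Arrays.toString(selectedCounts)); // [3, 4, -2, 5, 9]
        System.out.println(java.util.Arrays.toString(ids));            // [0, 1, 2, 3, 6, 7, 8, 9, 10]
    }
}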
+ */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; - private final LongLongHash values; BytesRefHash bytes; + private IntArray firstValues; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); - LongLongHash _values = null; - BytesRefHash _bytes = null; + boolean success = false; try { - _values = new LongLongHash(1, driverContext.bigArrays()); - _bytes = new BytesRefHash(1, driverContext.bigArrays()); - - values = _values; - bytes = _bytes; - - _values = null; - _bytes = null; + this.bytes = new BytesRefHash(1, driverContext.bigArrays()); + this.firstValues = driverContext.bigArrays().newIntArray(1, true); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_values, _bytes); + if (success == false) { + this.close(); + } } } @@ -178,14 +284,28 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } - void addValueOrdinal(int groupId, long valueOrdinal) { - values.add(groupId, valueOrdinal); - maxGroupId = Math.max(maxGroupId, groupId); + void addValueOrdinal(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + int current = firstValues.get(groupId) - 1; + if (current < 0) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (current != valueOrdinal) { + nextValues.addValue(groupId, valueOrdinal); + } + } else { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } } void addValue(int groupId, BytesRef v) { - values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); - maxGroupId = Math.max(maxGroupId, groupId); + int valueOrdinal = Math.toIntExact(BlockHash.hashOrdToGroup(bytes.add(v))); + addValueOrdinal(groupId, valueOrdinal); + } + + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from firstValues since ordinals are non-negative } /** @@ -193,159 +313,81 @@ void addValue(int groupId, BytesRef v) { * groups. This is the implementation of the final and intermediate results of the agg. 
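// Hypothetical sketch, not from the patch, of the "+1" sentinel that addValueOrdinal above relies on:
// a cleared int slot reads as 0, so storing ordinal + 1 lets 0 mean "no first value yet" without a
// separate seen bitset. FirstValueSlots and its names are illustrative only.
final class FirstValueSlots {
    private int[] slots = new int[4]; // zero-filled, like a cleared BigArrays IntArray

    void record(int groupId, int ordinal) {
        if (groupId >= slots.length) {
            slots = java.util.Arrays.copyOf(slots, groupId + 1); // new slots stay 0, i.e. "empty"
        }
        if (slots[groupId] == 0) {
            slots[groupId] = ordinal + 1; // shift by one so ordinal 0 is distinguishable from "empty"
        }
    }

    /** Returns the first ordinal recorded for the group, or -1 if the group has no value. */
    int first(int groupId) {
        return groupId < slots.length ? slots[groupId] - 1 : -1;
    }

    public static void main(String[] args) {
        FirstValueSlots slots = new FirstValueSlots();
        slots.record(2, 0);                 // ordinal 0 is a real value...
        System.out.println(slots.first(2)); // 0
        System.out.println(slots.first(1)); // -1, group 1 never saw a value
    }
}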
*/ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { - if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { - return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } else { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } + nextValues.prepareForEmitting(selected); + if (OrdinalBytesRefBlock.isDense(firstValues.size() + nextValues.hashes.size(), bytes.size())) { + return buildOrdinalOutputBlock(blockFactory, selected); + } else { + return buildOutputBlock(blockFactory, selected); } } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. 
- */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } - } - - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ BytesRef scratch = new BytesRef(); + final int[] nextValueCounts = nextValues.selectedCounts; try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendBytesRef(getValue(ids[start], scratch)); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendBytesRef(getValue(ids[i], scratch)); - } - builder.endPositionEntry(); + int firstValue = group >= firstValues.size() ? -1 : firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendBytesRef(bytes.get(firstValue, scratch)); + } else { + builder.beginPositionEntry(); + builder.appendBytesRef(bytes.get(firstValue, scratch)); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendBytesRef(bytes.get(nextValue, scratch)); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected) { BytesRefVector dict = null; IntBlock ordinals = null; BytesRefBlock result = null; var dictArray = bytes.takeBytesRefsOwnership(); bytes = null; // transfer ownership to dictArray - try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; + int estimateSize = Math.toIntExact(firstValues.size() + nextValues.hashes.size()); + final int[] nextValueCounts = nextValues.selectedCounts; + try (var builder = blockFactory.newIntBlockBuilder(estimateSize)) { + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); - } - builder.endPositionEntry(); + if (firstValues.size() < group) { + builder.appendNull(); + continue; + } + int firstValue = firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendInt(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendInt(firstValue); + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + builder.appendInt(nextValues.getValue(i)); } + builder.endPositionEntry(); } - start = end; + nextValuesStart = nextValuesEnd; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -359,18 +401,9 @@ Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int } } - BytesRef getValue(int valueId, BytesRef scratch) { - return bytes.get(values.getKey2(valueId), scratch); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(bytes, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 5f01ad586976f..e9c217ff172b7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -7,20 +7,32 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for double. @@ -106,34 +118,140 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. 
+ * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongLongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongLongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, double v) { + hashes.add(groupId, Double.doubleToLongBits(v)); + } + + double getValue(int index) { + return Double.longBitsToDouble(hashes.getKey2(ids[index])); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. 
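// Minimal sketch, not part of the change: doubles go into the hash as their raw IEEE-754 bit pattern
// and are decoded on the way out, the Double.doubleToLongBits / longBitsToDouble round trip used by
// addValue/getValue above. DoubleBitsDemo is an illustrative name only.
final class DoubleBitsDemo {
    public static void main(String[] args) {
        double v = 2.5;
        long key = Double.doubleToLongBits(v);      // stable 64-bit key, suitable for a LongLongHash
        double back = Double.longBitsToDouble(key); // lossless round trip
        System.out.println(Long.toHexString(key));  // 4004000000000000
        System.out.println(back == v);              // true
    }
}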
+ */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; - private final LongLongHash values; + DoubleArray firstValues; + private BitArray seen; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); - values = new LongLongHash(1, driverContext.bigArrays()); + boolean success = false; + try { + this.firstValues = driverContext.bigArrays().newDoubleArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; + } finally { + if (success == false) { + this.close(); + } + } } @Override @@ -142,151 +260,90 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, double v) { - values.add(groupId, Double.doubleToLongBits(v)); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + seen.fill(0, maxGroupId + 1, true); + } + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } - /** - * Builds a {@link Block} with the unique values collected for the {@code #selected} - * groups. This is the implementation of the final and intermediate results of the agg. 
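// Sketch of the lazy "seen" tracking described above, illustrative only: java.util.BitSet stands in
// for BitArray. While group ids arrive densely no bitset exists and every slot up to maxGroupId is
// assumed occupied; the first gap triggers allocation plus a backfill of the groups seen so far.
final class LazySeenDemo {
    private java.util.BitSet seen = null; // null means "every group up to maxGroupId has a value"
    private int maxGroupId = -1;

    void onValue(int groupId) {
        if (groupId > maxGroupId) {
            if (seen == null && groupId > maxGroupId + 1) {
                seen = new java.util.BitSet();
                seen.set(0, maxGroupId + 1); // backfill: everything seen so far really had a value
            }
            if (seen != null) {
                seen.set(groupId);
            }
            maxGroupId = groupId;
        } else if (seen != null) {
            seen.set(groupId);
        }
    }

    boolean hasValue(int groupId) {
        return groupId <= maxGroupId && (seen == null || seen.get(groupId));
    }

    public static void main(String[] args) {
        LazySeenDemo demo = new LazySeenDemo();
        demo.onValue(0);
        demo.onValue(1);                      // still dense, no bitset allocated
        demo.onValue(5);                      // gap: groups 2-4 have no value, bitset kicks in
        System.out.println(demo.hasValue(1)); // true
        System.out.println(demo.hasValue(3)); // false
        System.out.println(demo.hasValue(5)); // true
    }
}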
- */ - Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); } } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. 
- */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ + Block toBlock(BlockFactory blockFactory, IntVector selected) { + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ + final int[] nextValueCounts = nextValues.selectedCounts; try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendDouble(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendDouble(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + double firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendDouble(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendDouble(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendDouble(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - double getValue(int valueId) { - return Double.longBitsToDouble(values.getKey2(valueId)); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 9acaaccd80a85..cd8ed8a854763 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -7,19 +7,32 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; +import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for float. @@ -105,34 +118,147 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. 
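// Illustrative sketch, not the aggregator itself: once the counts hold an exclusive end offset per
// group and the extra values have been remapped into one flat array, every selected group's values
// form a contiguous slice, so emission is a single forward pass. All names here are hypothetical.
final class EmitDemo {
    public static void main(String[] args) {
        int[] selected = { 0, 1, 3, 4 };
        int[] endOffsets = { 3, 4, -2, 5, 9 };                   // exclusive ends; -2 = group 2 unselected
        float[] values = { 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f }; // remapped "next" values, grouped

        int start = 0;
        for (int group : selected) {
            int end = endOffsets[group];
            StringBuilder row = new StringBuilder("group " + group + ":");
            for (int i = start; i < end; i++) {
                row.append(' ').append(values[i]);
            }
            System.out.println(row);
            start = end; // the cursor only moves forward, so the whole emission is O(N)
        }
    }
}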
*/ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, float v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + } + + float getValue(int index) { + long both = hashes.get(ids[index]); + return Float.intBitsToFloat((int) (both & 0xFFFFFFFFL)); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. 
+ */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; - private final LongHash values; + FloatArray firstValues; + private BitArray seen; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); - values = new LongHash(1, driverContext.bigArrays()); + boolean success = false; + try { + this.firstValues = driverContext.bigArrays().newFloatArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; + } finally { + if (success == false) { + this.close(); + } + } } @Override @@ -141,158 +267,90 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, float v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + seen.fill(0, maxGroupId + 1, true); + } + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } - /** - * Builds a {@link Block} with the unique values collected for the {@code #selected} - * groups. This is the implementation of the final and intermediate results of the agg. 
- */ - Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); } } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. 
- */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ + Block toBlock(BlockFactory blockFactory, IntVector selected) { + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ + final int[] nextValueCounts = nextValues.selectedCounts; try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendFloat(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendFloat(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + float firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendFloat(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendFloat(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendFloat(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - float getValue(int valueId) { - long both = values.get(valueId); - return Float.intBitsToFloat((int) both); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 3690df739552b..f6c05104f749b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -7,19 +7,32 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; +import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for int. @@ -105,34 +118,147 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. 
*/ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + } + + int getValue(int index) { + long both = hashes.get(ids[index]); + return (int) (both & 0xFFFFFFFFL); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. 
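// Standalone sketch of the accounting pattern in reserveBytesForIntArray above: estimate the on-heap
// size of a temporary int[] with Lucene's RamUsageEstimator, charge it before allocating, and refund
// the accumulated total on close. TrackedScratchArrays is hypothetical; the adjustBreaker comments
// only indicate where the real blockFactory calls from the patch would sit.
import org.apache.lucene.util.RamUsageEstimator;

final class TrackedScratchArrays implements AutoCloseable {
    private long reserved = 0; // bytes charged so far

    int[] newIntArray(int length) {
        long bytes = RamUsageEstimator.alignObjectSize(
            RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long) length * Integer.BYTES
        );
        // blockFactory.adjustBreaker(bytes) would go here and may throw if the breaker trips
        reserved += bytes;
        return new int[length];
    }

    @Override
    public void close() {
        // blockFactory.adjustBreaker(-reserved) would go here, refunding everything in one shot
        reserved = 0;
    }

    public static void main(String[] args) {
        try (TrackedScratchArrays scratch = new TrackedScratchArrays()) {
            int[] counts = scratch.newIntArray(1024);
            System.out.println("charged " + scratch.reserved + " bytes for " + counts.length + " ints");
        }
    }
}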
+ */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; - private final LongHash values; + IntArray firstValues; + private BitArray seen; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); - values = new LongHash(1, driverContext.bigArrays()); + boolean success = false; + try { + this.firstValues = driverContext.bigArrays().newIntArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; + } finally { + if (success == false) { + this.close(); + } + } } @Override @@ -141,158 +267,90 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, int v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + seen.fill(0, maxGroupId + 1, true); + } + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } - /** - * Builds a {@link Block} with the unique values collected for the {@code #selected} - * groups. This is the implementation of the final and intermediate results of the agg. 
- */ - Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); } } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. 
- */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ + Block toBlock(BlockFactory blockFactory, IntVector selected) { + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ + final int[] nextValueCounts = nextValues.selectedCounts; try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendInt(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + int firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendInt(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendInt(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendInt(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - int getValue(int valueId) { - long both = values.get(valueId); - return (int) both; - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 9514e9147e05d..93bfdce654d1e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -7,20 +7,32 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for long. @@ -106,34 +118,140 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. 
*/ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongLongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongLongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, long v) { + hashes.add(groupId, v); + } + + long getValue(int index) { + return hashes.getKey2(ids[index]); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. 
That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; - private final LongLongHash values; + LongArray firstValues; + private BitArray seen; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); - values = new LongLongHash(1, driverContext.bigArrays()); + boolean success = false; + try { + this.firstValues = driverContext.bigArrays().newLongArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; + } finally { + if (success == false) { + this.close(); + } + } } @Override @@ -142,151 +260,90 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, long v) { - values.add(groupId, v); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + seen.fill(0, maxGroupId + 1, true); + } + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } - /** - * Builds a {@link Block} with the unique values collected for the {@code #selected} - * groups. This is the implementation of the final and intermediate results of the agg. - */ - Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); } } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. 
- */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + /** + * Builds a {@link Block} with the unique values collected for the {@code #selected} + * groups. This is the implementation of the final and intermediate results of the agg. + */ + Block toBlock(BlockFactory blockFactory, IntVector selected) { + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. 
*/ + final int[] nextValueCounts = nextValues.selectedCounts; try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendLong(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendLong(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + long firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendLong(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendLong(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendLong(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - long getValue(int valueId) { - return values.getKey2(valueId); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index d92ac5fa0afce..017c5540c98f0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -7,44 +7,32 @@ package org.elasticsearch.compute.aggregation; -$if(BytesRef)$ +// begin generated imports import org.apache.lucene.util.BytesRef; -$endif$ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; -$if(BytesRef)$ +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; -$else$ import org.elasticsearch.common.util.LongHash; -$endif$ -$if(long||double||BytesRef)$ import org.elasticsearch.common.util.LongLongHash; -$endif$ -$if(BytesRef)$ +import org.elasticsearch.common.util.$if(BytesRef)$Int$else$$Type$$endif$Array; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; -$endif$ import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -$if(int||double||float)$ import org.elasticsearch.compute.data.$Type$Block; -$elseif(BytesRef)$ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.IntBlock; -$endif$ import org.elasticsearch.compute.data.IntVector; -$if(long)$ import org.elasticsearch.compute.data.LongBlock; -$endif$ -$if(BytesRef)$ import 
org.elasticsearch.compute.data.OrdinalBytesRefBlock; -$endif$ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for $type$. @@ -204,62 +192,190 @@ $endif$ } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; +$if(long||double)$ + private final LongLongHash hashes; +$else$ + private final LongHash hashes; +$endif$ + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new Long$if(long||double)$Long$endif$Hash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, $if(BytesRef)$int$else$$type$$endif$ v) { +$if(long)$ + hashes.add(groupId, v); +$elseif(double)$ + hashes.add(groupId, Double.doubleToLongBits(v)); +$elseif(int||BytesRef)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); +$elseif(float)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$endif$ + } + + $if(BytesRef)$int$else$$type$$endif$ getValue(int index) { +$if(long)$ + return hashes.getKey2(ids[index]); +$elseif(double)$ + return Double.longBitsToDouble(hashes.getKey2(ids[index])); +$elseif(float)$ + long both = hashes.get(ids[index]); + return Float.intBitsToFloat((int) (both & 0xFFFFFFFFL)); +$elseif(BytesRef||int)$ + long both = hashes.get(ids[index]); + return (int) (both & 0xFFFFFFFFL); +$endif$ + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. 
+ */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { +$if(long||double)$ + int group = (int) hashes.getKey1(id); +$elseif(float||int||BytesRef)$ + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); +$endif$ + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { +$if(long||double)$ + int group = (int) hashes.getKey1(id); +$elseif(float||int||BytesRef)$ + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); +$endif$ + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. 
*/ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; -$if(long||double)$ - private final LongLongHash values; - -$elseif(BytesRef)$ - private final LongLongHash values; +$if(BytesRef)$ BytesRefHash bytes; - -$elseif(int||float)$ - private final LongHash values; - + private IntArray firstValues; +$else$ + $Type$Array firstValues; + private BitArray seen; + private int maxGroupId = -1; $endif$ + private final NextValues nextValues; + private GroupingState(DriverContext driverContext) { this.blockFactory = driverContext.blockFactory(); -$if(long||double)$ - values = new LongLongHash(1, driverContext.bigArrays()); -$elseif(BytesRef)$ - LongLongHash _values = null; - BytesRefHash _bytes = null; + boolean success = false; try { - _values = new LongLongHash(1, driverContext.bigArrays()); - _bytes = new BytesRefHash(1, driverContext.bigArrays()); - - values = _values; - bytes = _bytes; - - _values = null; - _bytes = null; +$if(BytesRef)$ + this.bytes = new BytesRefHash(1, driverContext.bigArrays()); + this.firstValues = driverContext.bigArrays().newIntArray(1, true); +$else$ + this.firstValues = driverContext.bigArrays().new$Type$Array(1, false); +$endif$ + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_values, _bytes); + if (success == false) { + this.close(); + } } -$elseif(int||float)$ - values = new LongHash(1, driverContext.bigArrays()); -$endif$ } @Override @@ -268,210 +384,169 @@ $endif$ } $if(BytesRef)$ - void addValueOrdinal(int groupId, long valueOrdinal) { - values.add(groupId, valueOrdinal); - maxGroupId = Math.max(maxGroupId, groupId); + void addValueOrdinal(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + int current = firstValues.get(groupId) - 1; + if (current < 0) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (current != valueOrdinal) { + nextValues.addValue(groupId, valueOrdinal); + } + } else { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } } $endif$ void addValue(int groupId, $type$ v) { -$if(long)$ - values.add(groupId, v); -$elseif(double)$ - values.add(groupId, Double.doubleToLongBits(v)); -$elseif(BytesRef)$ - values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); -$elseif(int)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); -$elseif(float)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$if(BytesRef)$ + int valueOrdinal = Math.toIntExact(BlockHash.hashOrdToGroup(bytes.add(v))); + addValueOrdinal(groupId, valueOrdinal); +$else$ + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. 
+ if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + seen.fill(0, maxGroupId + 1, true); + } + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } $endif$ - maxGroupId = Math.max(maxGroupId, groupId); } +$if(BytesRef)$ + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from firstValues since ordinals are non-negative + } + +$else$ + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } + + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); + } + } + + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + +$endif$ /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { + nextValues.prepareForEmitting(selected); $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { - return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } else { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } + if (OrdinalBytesRefBlock.isDense(firstValues.size() + nextValues.hashes.size(), bytes.size())) { + return buildOrdinalOutputBlock(blockFactory, selected); + } else { + return buildOutputBlock(blockFactory, selected); + } $else$ - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); + return buildOutputBlock(blockFactory, selected); $endif$ - } } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef||double)$ - int group = (int) values.getKey1(id); -$elseif(float||int)$ - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); -$endif$ - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. 
- * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - $if(long||BytesRef||double)$ - int group = (int) values.getKey1(id); - $elseif(float||int)$ - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - $endif$ - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } - } - - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ $if(BytesRef)$ BytesRef scratch = new BytesRef(); $endif$ + final int[] nextValueCounts = nextValues.selectedCounts; try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$)); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$)); - } - builder.endPositionEntry(); +$if(BytesRef)$ + int firstValue = group >= firstValues.size() ? -1 : firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } +$else$ + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + $type$ firstValue = firstValues.get(group); +$endif$ + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.append$Type$($if(BytesRef)$bytes.get(firstValue, scratch)$else$firstValue$endif$); + } else { + builder.beginPositionEntry(); + builder.append$Type$($if(BytesRef)$bytes.get(firstValue, scratch)$else$firstValue$endif$); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.append$Type$($if(BytesRef)$bytes.get(nextValue, scratch)$else$nextValue$endif$); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } $if(BytesRef)$ - Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected) { BytesRefVector dict = null; IntBlock ordinals = null; BytesRefBlock result = null; var dictArray = bytes.takeBytesRefsOwnership(); bytes = null; // transfer ownership to dictArray - try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; + int estimateSize = Math.toIntExact(firstValues.size() + nextValues.hashes.size()); + final int[] nextValueCounts = nextValues.selectedCounts; + try (var builder = blockFactory.newIntBlockBuilder(estimateSize)) { + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); - } - builder.endPositionEntry(); + if (firstValues.size() < group) { + builder.appendNull(); + continue; + } + int firstValue = firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } + final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendInt(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendInt(firstValue); + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + builder.appendInt(nextValues.getValue(i)); } + builder.endPositionEntry(); } - start = end; + nextValuesStart = nextValuesEnd; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -486,34 +561,9 @@ $if(BytesRef)$ } $endif$ - $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) { -$if(BytesRef)$ - return bytes.get(values.getKey2(valueId), scratch); -$elseif(long)$ - return values.getKey2(valueId); -$elseif(double)$ - return Double.longBitsToDouble(values.getKey2(valueId)); -$elseif(float)$ - long both = values.get(valueId); - return Float.intBitsToFloat((int) both); -$elseif(int)$ - long both = values.get(valueId); - return (int) both; -$endif$ - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { -$if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); -$else$ - Releasables.closeExpectNoException(values); -$endif$ + Releasables.closeExpectNoException($if(BytesRef)$bytes$else$seen$endif$, firstValues, nextValues); } } }
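
The int, float and BytesRef variants of NextValues encode each (groupId, value) pair into a single long before adding it to the hash: the group id goes in the top 32 bits, the value in the bottom 32. A minimal, self-contained sketch of that packing in plain Java (the class and method names here are illustrative, not part of the generated aggregator code):

    public final class GroupValuePackingSketch {

        // Pack a non-negative group id and a 32-bit value into one long:
        // masking the value avoids sign extension leaking into the group bits.
        static long pack(int groupId, int value) {
            return (((long) groupId) << Integer.SIZE) | (value & 0xFFFFFFFFL);
        }

        // Recover the group id; the unsigned shift keeps it non-negative.
        static int group(long packed) {
            return (int) (packed >>> Integer.SIZE);
        }

        // Recover the value, including negative ints, by taking the low 32 bits.
        static int value(long packed) {
            return (int) (packed & 0xFFFFFFFFL);
        }

        public static void main(String[] args) {
            long packed = pack(7, -42);
            System.out.println(group(packed) + " -> " + value(packed)); // 7 -> -42
        }
    }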
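
prepareForEmitting turns the flat hash of (group, value) entries into per-group slices with three linear passes: count downwards per group, convert the selected groups' counts into running start offsets, then scatter the hash ids. The following standalone sketch models the hash as a plain long[] and keeps the defensive bounds checks from the pre-existing buildSorted code; it reproduces the worked example from the comments above (counts start as 0, 3, -2, 4, 5 and end as 3, 4, -2, 5, 9):

    import java.util.Arrays;

    public final class PrepareForEmittingSketch {

        static long pack(int groupId, int value) {
            return (((long) groupId) << Integer.SIZE) | (value & 0xFFFFFFFFL);
        }

        public static void main(String[] args) {
            // Packed entries in hash insertion order (group in the upper 32 bits).
            long[] hashEntries = {
                pack(0, 10), pack(0, 11), pack(0, 12),                // group 0: 3 values
                pack(1, 20),                                          // group 1: 1 value
                pack(2, 30), pack(2, 31),                             // group 2: 2 values, not selected
                pack(3, 40),                                          // group 3: 1 value
                pack(4, 50), pack(4, 51), pack(4, 52), pack(4, 53)    // group 4: 4 values
            };
            int[] selected = {0, 1, 3, 4};

            // Pass 1: count *downwards*, so unselected groups keep a negative count.
            int[] selectedCounts = new int[Arrays.stream(selected).max().orElse(-1) + 1];
            for (long entry : hashEntries) {
                int group = (int) (entry >>> Integer.SIZE);
                if (group < selectedCounts.length) {
                    selectedCounts[group]--;
                }
            }

            // Pass 2: flip the selected groups' counts into running start offsets.
            int total = 0;
            for (int group : selected) {
                int count = -selectedCounts[group];
                selectedCounts[group] = total;
                total += count;
            }
            // Here selectedCounts == [0, 3, -2, 4, 5] and total == 9.

            // Pass 3: scatter the entry ids; afterwards selectedCounts[group] is the
            // end offset (exclusive) of that group's slice in ids.
            int[] ids = new int[total];
            for (int id = 0; id < hashEntries.length; id++) {
                int group = (int) (hashEntries[id] >>> Integer.SIZE);
                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
                    ids[selectedCounts[group]++] = id;
                }
            }
            System.out.println(Arrays.toString(selectedCounts)); // [3, 4, -2, 5, 9]
            System.out.println(Arrays.toString(ids));            // [0, 1, 2, 3, 6, 7, 8, 9, 10]
        }
    }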
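
The fast path itself lives in addValue: the first value of every group is written to the firstValues big-array, the seen bitset is only allocated once a gap in group ids appears, and only additional distinct values fall through to NextValues. A simplified stand-alone model of that flow, using plain Java arrays in place of BigArrays and BitArray and omitting the circuit-breaker accounting:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public final class FirstValuesFastPathSketch {
        private long[] firstValues = new long[1];
        private boolean[] seen = null;       // null means "every group <= maxGroupId has a value"
        private int maxGroupId = -1;
        private final List<long[]> nextValues = new ArrayList<>(); // stand-in for the (group, value) hash

        void addValue(int groupId, long v) {
            if (groupId > maxGroupId) {
                if (groupId >= firstValues.length) {
                    firstValues = Arrays.copyOf(firstValues, groupId + 1);
                }
                firstValues[groupId] = v;
                // Switch to tracked mode the first time a gap appears: groups up to
                // maxGroupId are known to have values, the skipped groups do not.
                if (seen == null && groupId > maxGroupId + 1) {
                    seen = new boolean[firstValues.length];
                    Arrays.fill(seen, 0, maxGroupId + 1, true);
                }
                trackGroupId(groupId);
                maxGroupId = groupId;
            } else if (hasValue(groupId) == false) {
                firstValues[groupId] = v;
                trackGroupId(groupId);
            } else if (firstValues[groupId] != v) {
                nextValues.add(new long[] { groupId, v });
            }
        }

        private void trackGroupId(int groupId) {
            if (seen != null) {
                if (groupId >= seen.length) {
                    seen = Arrays.copyOf(seen, firstValues.length);
                }
                seen[groupId] = true;
            }
        }

        private boolean hasValue(int groupId) {
            return seen == null || seen[groupId];
        }

        public static void main(String[] args) {
            FirstValuesFastPathSketch state = new FirstValuesFastPathSketch();
            state.addValue(0, 10);
            state.addValue(0, 10);   // duplicate of the first value: ignored
            state.addValue(0, 11);   // second distinct value: spills into nextValues
            state.addValue(3, 30);   // gap (groups 1 and 2 skipped): seen is allocated
            System.out.println(state.hasValue(1));        // false
            System.out.println(state.nextValues.size());  // 1
        }
    }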
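
For the BytesRef variant, firstValues stores valueOrdinal + 1 so that the zero-initialized big-array can use 0 as the "no value yet" marker without a separate seen bitset. A tiny sketch of that encoding, again with plain arrays and illustrative names:

    import java.util.ArrayList;
    import java.util.List;

    public final class OrdinalPlusOneSketch {
        private final int[] firstValues = new int[4];          // zero-filled: 0 means empty
        private final List<int[]> nextValues = new ArrayList<>();

        void addValueOrdinal(int groupId, int valueOrdinal) {
            int current = firstValues[groupId] - 1;            // -1 means the slot is empty
            if (current < 0) {
                firstValues[groupId] = valueOrdinal + 1;
            } else if (current != valueOrdinal) {
                nextValues.add(new int[] { groupId, valueOrdinal });
            }
        }

        Integer firstOrdinal(int groupId) {
            int stored = firstValues[groupId];
            return stored == 0 ? null : stored - 1;
        }

        public static void main(String[] args) {
            OrdinalPlusOneSketch s = new OrdinalPlusOneSketch();
            s.addValueOrdinal(2, 0);                 // ordinal 0 is a legal value
            System.out.println(s.firstOrdinal(2));   // 0
            System.out.println(s.firstOrdinal(1));   // null - group 1 never saw a value
        }
    }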