From 5860c6ea21f03dc68cc47d4a0f9e1499e6e39a1f Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 2 Jul 2025 19:21:46 -0700 Subject: [PATCH 1/7] Add fast path for single value in VALUES --- .../aggregation/ValuesBytesRefAggregator.java | 94 +++++++++++----- .../aggregation/ValuesDoubleAggregator.java | 12 +- .../aggregation/ValuesFloatAggregator.java | 12 +- .../aggregation/ValuesIntAggregator.java | 12 +- .../aggregation/ValuesLongAggregator.java | 12 +- .../ValuesBytesRefAggregators.java | 12 +- .../aggregation/X-ValuesAggregator.java.st | 106 ++++++++++++++---- 7 files changed, 188 insertions(+), 72 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 51195578ac363..63166fb5bdf88 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongLongHash; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; @@ -76,7 +77,7 @@ public static GroupingAggregatorFunction.AddInput wrapAddInput( } public static void combine(GroupingState state, int groupId, BytesRef v) { - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(v)))); } public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { @@ -90,6 +91,14 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { BytesRef scratch = new BytesRef(); + if (statePosition >= state.firstValues.size()) { + return; + } + int valueOrd = state.firstValues.get(statePosition) - 1; + if (valueOrd < 0) { + return; + } + combine(current, currentGroupId, state.bytes.get(valueOrd, scratch)); for (int id = 0; id < state.values.size(); id++) { if (state.values.getKey1(id) == statePosition) { long value = state.values.getKey2(id); @@ -146,23 +155,29 @@ public void close() { * collector operation. But at least it's fairly simple. 
*/ public static class GroupingState implements GroupingAggregatorState { - final LongLongHash values; + private final BigArrays bigArrays; + private final LongLongHash values; + private IntArray firstValues; // the first value ordinal+1 collected in each group, 0 means no value BytesRefHash bytes; private GroupingState(BigArrays bigArrays) { + this.bigArrays = bigArrays; LongLongHash _values = null; BytesRefHash _bytes = null; + IntArray _firstValues = null; try { _values = new LongLongHash(1, bigArrays); _bytes = new BytesRefHash(1, bigArrays); - + _firstValues = bigArrays.newIntArray(1); values = _values; bytes = _bytes; + firstValues = _firstValues; _values = null; _bytes = null; + _firstValues = null; } finally { - Releasables.closeExpectNoException(_values, _bytes); + Releasables.closeExpectNoException(_values, _bytes, _firstValues); } } @@ -176,7 +191,7 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { + if (bytes.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -220,11 +235,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* @@ -256,7 +273,7 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(firstValues.size() + values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); } else { return buildOutputBlock(blockFactory, selected, selectedCounts, ids); @@ -266,6 +283,20 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } + void addValue(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + final int curr = firstValues.get(groupId) - 1; + if (curr == -1) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (curr != valueOrdinal) { + values.add(groupId, valueOrdinal); + } + } else { + firstValues = bigArrays.grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } + } + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. @@ -275,20 +306,26 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start], scratch); - default -> { + int firstValue = group < firstValues.size() ? 
firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendBytesRef(bytes.get(firstValue, scratch)); + } else { builder.beginPositionEntry(); + builder.appendBytesRef(bytes.get(firstValue, scratch)); for (int i = start; i < end; i++) { append(builder, ids[i], scratch); } builder.endPositionEntry(); } + start = end; } - start = end; + } return builder.build(); } @@ -304,20 +341,25 @@ Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendInt(firstValue); + } else { builder.beginPositionEntry(); + builder.appendInt(firstValue); for (int i = start; i < end; i++) { builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); } builder.endPositionEntry(); } + start = end; } - start = end; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -343,7 +385,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, firstValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index f5b0d519dd890..62f729b16e656 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -180,11 +180,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 4cfbf329a895d..7bb7444e000e7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ 
-186,11 +186,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 38e5ad99cf581..1428b05a53ac8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -186,11 +186,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 4bfc230d7e1f7..eb9f7beb19908 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -180,11 +180,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java index 78a083b8daac7..c8094b348ca8f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -55,7 +55,7 @@ public void add(int positionOffset, IntArrayBlock groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < 
valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -77,7 +77,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -93,7 +93,7 @@ public void add(int positionOffset, IntVector groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -135,7 +135,7 @@ public void add(int positionOffset, IntArrayBlock groupIds) { int groupEnd = groupStart + groupIds.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groupIds.getInt(g); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } } @@ -150,7 +150,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { int groupEnd = groupStart + groupIds.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groupIds.getInt(g); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } } @@ -159,7 +159,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { public void add(int positionOffset, IntVector groupIds) { for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { int groupId = groupIds.getInt(groupPosition); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 67f32fc4a4d4e..8b1a7ecd2f2d2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -14,6 +14,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; $if(BytesRef)$ import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.IntArray; $else$ import org.elasticsearch.common.util.LongHash; $endif$ @@ -118,7 +119,7 @@ $if(long)$ $elseif(double)$ state.values.add(groupId, Double.doubleToLongBits(v)); $elseif(BytesRef)$ - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(v)))); $elseif(int)$ /* * Encode the groupId and value into a single long - @@ -152,6 +153,14 @@ $endif$ public static void 
combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { $if(BytesRef)$ BytesRef scratch = new BytesRef(); + if (statePosition >= state.firstValues.size()) { + return; + } + int valueOrd = state.firstValues.get(statePosition) - 1; + if (valueOrd < 0) { + return; + } + combine(current, currentGroupId, state.bytes.get(valueOrd, scratch)); $endif$ for (int id = 0; id < state.values.size(); id++) { $if(long||BytesRef)$ @@ -259,7 +268,9 @@ $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ - final LongLongHash values; + private final BigArrays bigArrays; + private final LongLongHash values; + private IntArray firstValues; // the first value ordinal+1 collected in each group, 0 means no value BytesRefHash bytes; $elseif(int||float)$ @@ -270,19 +281,23 @@ $endif$ $if(long||double)$ values = new LongLongHash(1, bigArrays); $elseif(BytesRef)$ + this.bigArrays = bigArrays; LongLongHash _values = null; BytesRefHash _bytes = null; + IntArray _firstValues = null; try { _values = new LongLongHash(1, bigArrays); _bytes = new BytesRefHash(1, bigArrays); - + _firstValues = bigArrays.newIntArray(1); values = _values; bytes = _bytes; + firstValues = _firstValues; _values = null; _bytes = null; + _firstValues = null; } finally { - Releasables.closeExpectNoException(_values, _bytes); + Releasables.closeExpectNoException(_values, _bytes, _firstValues); } $elseif(int||float)$ values = new LongHash(1, bigArrays); @@ -299,7 +314,7 @@ $endif$ * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { + if ($if(BytesRef)$bytes$else$values$endif$.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -348,11 +363,13 @@ $endif$ * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* @@ -390,7 +407,7 @@ $endif$ } } $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(firstValues.size() + values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); } else { return buildOutputBlock(blockFactory, selected, selectedCounts, ids); @@ -403,13 +420,56 @@ $endif$ } } +$if(BytesRef)$ + void addValue(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + final int curr = firstValues.get(groupId) - 1; + if (curr == -1) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (curr != valueOrdinal) { + values.add(groupId, valueOrdinal); + } + } else { + firstValues = bigArrays.grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } + } +$endif$ + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. 
*/ $if(BytesRef)$ BytesRef scratch = new BytesRef(); -$endif$ + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendBytesRef(bytes.get(firstValue, scratch)); + } else { + builder.beginPositionEntry(); + builder.appendBytesRef(bytes.get(firstValue, scratch)); + for (int i = start; i < end; i++) { + append(builder, ids[i], scratch); + } + builder.endPositionEntry(); + } + start = end; + } + + } + return builder.build(); + } +$else$ try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { @@ -431,6 +491,7 @@ $endif$ } return builder.build(); } +$endif$ } $if(BytesRef)$ @@ -444,20 +505,25 @@ $if(BytesRef)$ int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendInt(firstValue); + } else { builder.beginPositionEntry(); + builder.appendInt(firstValue); for (int i = start; i < end; i++) { builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); } builder.endPositionEntry(); } + start = end; } - start = end; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -501,7 +567,7 @@ $endif$ @Override public void close() { $if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, firstValues); $else$ values.close(); $endif$ From 90c3463dfa6e2d3c43e55e82ca38970d860741e0 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 2 Jul 2025 21:51:34 -0700 Subject: [PATCH 2/7] Update docs/changelog/130510.yaml --- docs/changelog/130510.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/130510.yaml diff --git a/docs/changelog/130510.yaml b/docs/changelog/130510.yaml new file mode 100644 index 0000000000000..01426b6b8e4e3 --- /dev/null +++ b/docs/changelog/130510.yaml @@ -0,0 +1,5 @@ +pr: 130510 +summary: Add fast path for single value in VALUES aggregator +area: ES|QL +type: enhancement +issues: [] From 5a7a5c7301522eb3390068da512b9ef65fe10a13 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 2 Jul 2025 21:53:20 -0700 Subject: [PATCH 3/7] naming --- .../compute/aggregation/ValuesBytesRefAggregator.java | 6 +++--- .../compute/aggregation/X-ValuesAggregator.java.st | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java 
b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 63166fb5bdf88..8a52433fc0629 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -94,11 +94,11 @@ public static void combineStates(GroupingState current, int currentGroupId, Grou if (statePosition >= state.firstValues.size()) { return; } - int valueOrd = state.firstValues.get(statePosition) - 1; - if (valueOrd < 0) { + int firstValue = state.firstValues.get(statePosition) - 1; + if (firstValue < 0) { return; } - combine(current, currentGroupId, state.bytes.get(valueOrd, scratch)); + combine(current, currentGroupId, state.bytes.get(firstValue, scratch)); for (int id = 0; id < state.values.size(); id++) { if (state.values.getKey1(id) == statePosition) { long value = state.values.getKey2(id); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 8b1a7ecd2f2d2..2c899105516b9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -156,11 +156,11 @@ $if(BytesRef)$ if (statePosition >= state.firstValues.size()) { return; } - int valueOrd = state.firstValues.get(statePosition) - 1; - if (valueOrd < 0) { + int firstValue = state.firstValues.get(statePosition) - 1; + if (firstValue < 0) { return; } - combine(current, currentGroupId, state.bytes.get(valueOrd, scratch)); + combine(current, currentGroupId, state.bytes.get(firstValue, scratch)); $endif$ for (int id = 0; id < state.values.size(); id++) { $if(long||BytesRef)$ From d3f9027634bc1b959b091f9750a71533b53171e5 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 21 Jul 2025 13:35:01 -0700 Subject: [PATCH 4/7] single-value --- .../aggregation/ValuesBytesRefAggregator.java | 364 +++++++------ .../aggregation/ValuesDoubleAggregator.java | 309 ++++++----- .../aggregation/ValuesFloatAggregator.java | 324 ++++++----- .../aggregation/ValuesIntAggregator.java | 324 ++++++----- .../aggregation/ValuesLongAggregator.java | 309 ++++++----- .../aggregation/X-ValuesAggregator.java.st | 510 ++++++++++-------- 6 files changed, 1168 insertions(+), 972 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 79077b6628105..2f6692f0359dc 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -7,11 +7,14 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; +import 
org.elasticsearch.common.util.IntArray; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; @@ -19,13 +22,16 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for BytesRef. @@ -129,47 +135,152 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + } + + int getValue(int index) { + long both = hashes.get(ids[index]); + return (int) (both & 0xFFFFFFFFL); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. 
+ */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. 
*/ public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; private final BlockFactory blockFactory; - private final LongLongHash values; BytesRefHash bytes; + private IntArray firstValues; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - this.blockFactory = driverContext.blockFactory(); - LongLongHash _values = null; BytesRefHash _bytes = null; + IntArray _firstValues = null; + NextValues _nextValues = null; try { - _values = new LongLongHash(1, driverContext.bigArrays()); _bytes = new BytesRefHash(1, driverContext.bigArrays()); + _firstValues = driverContext.bigArrays().newIntArray(1, true); + _nextValues = new NextValues(driverContext.blockFactory()); - values = _values; - bytes = _bytes; - - _values = null; + this.bytes = _bytes; _bytes = null; + this.firstValues = _firstValues; + _firstValues = null; + this.nextValues = _nextValues; + _nextValues = null; + this.blockFactory = driverContext.blockFactory(); } finally { - Releasables.closeExpectNoException(_values, _bytes); + Releasables.closeExpectNoException(_bytes, _firstValues, _nextValues); } } @@ -178,14 +289,28 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } - void addValueOrdinal(int groupId, long valueOrdinal) { - values.add(groupId, valueOrdinal); - maxGroupId = Math.max(maxGroupId, groupId); + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from firstValues since ordinals are non-negative + } + + void addValueOrdinal(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + int current = firstValues.get(groupId) - 1; + if (current < 0) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (current != valueOrdinal) { + nextValues.addValue(groupId, valueOrdinal); + } + } else { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } } void addValue(int groupId, BytesRef v) { - values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); - maxGroupId = Math.max(maxGroupId, groupId); + int valueOrdinal = Math.toIntExact(BlockHash.hashOrdToGroup(bytes.add(v))); + addValueOrdinal(groupId, valueOrdinal); } /** @@ -193,159 +318,81 @@ void addValue(int groupId, BytesRef v) { * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { - if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { - return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } else { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } - } - } - - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. 
- */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } + nextValues.prepareForEmitting(selected); + if (OrdinalBytesRefBlock.isDense(firstValues.size() + nextValues.hashes.size(), bytes.size())) { + return buildOrdinalOutputBlock(blockFactory, selected); + } else { + return buildOutputBlock(blockFactory, selected); } } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. 
*/ BytesRef scratch = new BytesRef(); + final int[] nextValueCounts = nextValues.selectedCounts; try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendBytesRef(getValue(ids[start], scratch)); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendBytesRef(getValue(ids[i], scratch)); - } - builder.endPositionEntry(); + int firstValue = group >= firstValues.size() ? -1 : firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } + final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendBytesRef(bytes.get(firstValue, scratch)); + } else { + builder.beginPositionEntry(); + builder.appendBytesRef(bytes.get(firstValue, scratch)); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendBytesRef(bytes.get(nextValue, scratch)); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected) { BytesRefVector dict = null; IntBlock ordinals = null; BytesRefBlock result = null; var dictArray = bytes.takeBytesRefsOwnership(); bytes = null; // transfer ownership to dictArray - try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; + int estimateSize = Math.toIntExact(firstValues.size() + nextValues.hashes.size()); + final int[] nextValueCounts = nextValues.selectedCounts; + try (var builder = blockFactory.newIntBlockBuilder(estimateSize)) { + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); - } - builder.endPositionEntry(); + if (firstValues.size() < group) { + builder.appendNull(); + continue; + } + int firstValue = firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendInt(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendInt(firstValue); + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + builder.appendInt(nextValues.getValue(i)); } + builder.endPositionEntry(); } - start = end; + nextValuesStart = nextValuesEnd; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -359,18 +406,9 @@ Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int } } - BytesRef getValue(int valueId, BytesRef scratch) { - return bytes.get(values.getKey2(valueId), scratch); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(bytes, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index aa5a16e499fb6..d720c56822b99 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -7,20 +7,31 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for double. @@ -106,34 +117,143 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which can require O(N^2). 
To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongLongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongLongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, double v) { + hashes.add(groupId, Double.doubleToLongBits(v)); + } + + double getValue(int index) { + return Double.longBitsToDouble(hashes.getKey2(ids[index])); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. 
This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; + public static class GroupingState extends AbstractArrayState { private final BlockFactory blockFactory; - private final LongLongHash values; + DoubleArray firstValues; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - this.blockFactory = driverContext.blockFactory(); - values = new LongLongHash(1, driverContext.bigArrays()); + super(driverContext.bigArrays()); + DoubleArray _firstValues = null; + NextValues _nextValues = null; + try { + _firstValues = driverContext.bigArrays().newDoubleArray(1, true); + _nextValues = new NextValues(driverContext.blockFactory()); + + this.firstValues = _firstValues; + _firstValues = null; + this.nextValues = _nextValues; + _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + } finally { + Releasables.closeExpectNoException(_firstValues, _nextValues); + } } @Override @@ -142,8 +262,17 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, double v) { - values.add(groupId, Double.doubleToLongBits(v)); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } /** @@ -151,144 +280,46 @@ void addValue(int groupId, double v) { * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } - } - - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. 
- */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - if (values.size() > 0) { - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. 
*/ + final int[] nextValueCounts = nextValues.selectedCounts; try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendDouble(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendDouble(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + double firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendDouble(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendDouble(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendDouble(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - double getValue(int valueId) { - return Double.longBitsToDouble(values.getKey2(valueId)); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(super::close, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index fa536b6b7cf66..a25a2dee2776a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -7,19 +7,31 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; +import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.FloatArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field 
values for float. @@ -105,34 +117,150 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, float v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); + } + + float getValue(int index) { + long both = hashes.get(ids[index]); + return Float.intBitsToFloat((int) (both & 0xFFFFFFFFL)); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. 
+ * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; + public static class GroupingState extends AbstractArrayState { private final BlockFactory blockFactory; - private final LongHash values; + FloatArray firstValues; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - this.blockFactory = driverContext.blockFactory(); - values = new LongHash(1, driverContext.bigArrays()); + super(driverContext.bigArrays()); + FloatArray _firstValues = null; + NextValues _nextValues = null; + try { + _firstValues = driverContext.bigArrays().newFloatArray(1, true); + _nextValues = new NextValues(driverContext.blockFactory()); + + this.firstValues = _firstValues; + _firstValues = null; + this.nextValues = _nextValues; + _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + } finally { + Releasables.closeExpectNoException(_firstValues, _nextValues); + } } @Override @@ -141,12 +269,17 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, float v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } /** @@ -154,147 +287,46 @@ void addValue(int groupId, float v) { * groups. This is the implementation of the final and intermediate results of the agg. 
*/ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } - } - - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - if (values.size() > 0) { - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. 
- */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ + final int[] nextValueCounts = nextValues.selectedCounts; try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendFloat(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendFloat(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + float firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendFloat(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendFloat(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendFloat(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - float getValue(int valueId) { - long both = values.get(valueId); - return Float.intBitsToFloat((int) both); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(super::close, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index be72c024431f4..abfbf01e75b45 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -7,19 +7,31 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; +import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.IntArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for int. @@ -105,34 +117,150 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. 
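+     * Preparing the counts and the remapped ids costs O(N + selected groups) once per emit,
+     * after which each group's extra values can be read as one contiguous slice.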
*/ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, int v) { + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); + } + + int getValue(int index) { + long both = hashes.get(ids[index]); + return (int) (both & 0xFFFFFFFFL); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. 
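+             * With the example above, ids ends up with length 9: group 0's ids occupy
+             * positions 0-2, group 1's position 3, group 3's position 4, and group 4's
+             * positions 5-8.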
+ */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; + public static class GroupingState extends AbstractArrayState { private final BlockFactory blockFactory; - private final LongHash values; + IntArray firstValues; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - this.blockFactory = driverContext.blockFactory(); - values = new LongHash(1, driverContext.bigArrays()); + super(driverContext.bigArrays()); + IntArray _firstValues = null; + NextValues _nextValues = null; + try { + _firstValues = driverContext.bigArrays().newIntArray(1, true); + _nextValues = new NextValues(driverContext.blockFactory()); + + this.firstValues = _firstValues; + _firstValues = null; + this.nextValues = _nextValues; + _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + } finally { + Releasables.closeExpectNoException(_firstValues, _nextValues); + } } @Override @@ -141,12 +269,17 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, int v) { - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } /** @@ -154,147 +287,46 @@ void addValue(int groupId, int v) { * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. 
Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - if (values.size() > 0) { - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } - } - - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. 
*/ + final int[] nextValueCounts = nextValues.selectedCounts; try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendInt(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + int firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendInt(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendInt(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendInt(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - int getValue(int valueId) { - long both = values.get(valueId); - return (int) both; - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(super::close, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 3731be99afd3f..6e51881e2df28 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -7,20 +7,31 @@ package org.elasticsearch.compute.aggregation; +// begin generated imports +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; +import org.elasticsearch.common.util.LongArray; +import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.OrdinalBytesRefBlock; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for long. 
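A minimal sketch of the collection scheme the new GroupingState variants above and below implement,
for reading alongside the generated code: plain arrays and a per-group LinkedHashSet stand in for
BigArrays, AbstractArrayState, and the (groupId, value) hash, and circuit-breaker accounting,
ordinals, and block building are omitted. It only illustrates the fast path, where a group with a
single distinct value never allocates anything beyond its slot in firstValues.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.LinkedHashSet;
    import java.util.List;
    import java.util.Map;

    class ValuesFastPathSketch {
        private long[] firstValues = new long[1];
        private boolean[] seen = new boolean[1];
        private final Map<Integer, LinkedHashSet<Long>> nextValues = new HashMap<>();

        void addValue(int groupId, long v) {
            if (groupId >= firstValues.length) {
                firstValues = Arrays.copyOf(firstValues, groupId + 1);
                seen = Arrays.copyOf(seen, groupId + 1);
            }
            if (seen[groupId] == false) {
                // fast path: the first value of each group lives in a flat array
                firstValues[groupId] = v;
                seen[groupId] = true;
            } else if (firstValues[groupId] != v) {
                // only additional distinct values spill into the per-group set
                nextValues.computeIfAbsent(groupId, g -> new LinkedHashSet<>()).add(v);
            }
        }

        List<Long> valuesOf(int groupId) {
            if (groupId >= seen.length || seen[groupId] == false) {
                return List.of(); // a null position in the real output block
            }
            List<Long> out = new ArrayList<>(); // first value, then any spilled values
            out.add(firstValues[groupId]);
            out.addAll(nextValues.getOrDefault(groupId, new LinkedHashSet<>()));
            return out;
        }
    }

The real aggregators keep the same split but dedup subsequent values in a shared hash keyed by
(groupId, value) and defer all per-group iteration to NextValues#prepareForEmitting, so emission
stays linear in the number of collected values.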
@@ -106,34 +117,143 @@ public void close() { } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; + private final LongLongHash hashes; + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new LongLongHash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, long v) { + hashes.add(groupId, v); + } + + long getValue(int index) { + return hashes.getKey2(ids[index]); + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. + */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. 
+ * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { + int group = (int) hashes.getKey1(id); + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; + public static class GroupingState extends AbstractArrayState { private final BlockFactory blockFactory; - private final LongLongHash values; + LongArray firstValues; + private int maxGroupId = -1; + private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - this.blockFactory = driverContext.blockFactory(); - values = new LongLongHash(1, driverContext.bigArrays()); + super(driverContext.bigArrays()); + LongArray _firstValues = null; + NextValues _nextValues = null; + try { + _firstValues = driverContext.bigArrays().newLongArray(1, true); + _nextValues = new NextValues(driverContext.blockFactory()); + + this.firstValues = _firstValues; + _firstValues = null; + this.nextValues = _nextValues; + _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + } finally { + Releasables.closeExpectNoException(_firstValues, _nextValues); + } } @Override @@ -142,8 +262,17 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive } void addValue(int groupId, long v) { - values.add(groupId, v); - maxGroupId = Math.max(maxGroupId, groupId); + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } } /** @@ -151,144 +280,46 @@ void addValue(int groupId, long v) { * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } - } - - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. 
Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - if (values.size() > 0) { - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - int group = (int) values.getKey1(id); - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } + nextValues.prepareForEmitting(selected); + return buildOutputBlock(blockFactory, selected); } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. 
*/ + final int[] nextValueCounts = nextValues.selectedCounts; try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendLong(getValue(ids[start])); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendLong(getValue(ids[i])); - } - builder.endPositionEntry(); + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + long firstValue = firstValues.get(group); + final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendLong(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendLong(firstValue); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.appendLong(nextValue); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } - long getValue(int valueId) { - return values.getKey2(valueId); - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { - Releasables.closeExpectNoException(values); + Releasables.closeExpectNoException(super::close, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index d92ac5fa0afce..72e9b6fab893f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -7,44 +7,31 @@ package org.elasticsearch.compute.aggregation; -$if(BytesRef)$ +// begin generated imports import org.apache.lucene.util.BytesRef; -$endif$ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; -$if(BytesRef)$ import org.elasticsearch.common.util.BytesRefHash; -$else$ import org.elasticsearch.common.util.LongHash; -$endif$ -$if(long||double||BytesRef)$ import org.elasticsearch.common.util.LongLongHash; -$endif$ -$if(BytesRef)$ +import org.elasticsearch.common.util.$if(BytesRef)$Int$else$$Type$$endif$Array; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; -$endif$ import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -$if(int||double||float)$ import org.elasticsearch.compute.data.$Type$Block; -$elseif(BytesRef)$ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.BytesRefVector; import org.elasticsearch.compute.data.IntBlock; -$endif$ import org.elasticsearch.compute.data.IntVector; -$if(long)$ import org.elasticsearch.compute.data.LongBlock; -$endif$ -$if(BytesRef)$ import org.elasticsearch.compute.data.OrdinalBytesRefBlock; -$endif$ import 
org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +// end generated imports /** * Aggregates field values for $type$. @@ -204,62 +191,202 @@ $endif$ } /** - * Values are collected in a hash. Iterating over them in order (row by row) to build the output, - * or merging with other state, can be expensive. To optimize this, we build a sorted structure once, - * and then use it to iterate over the values in order. - * - * @param ids positions of the {@link GroupingState#values} to read. + * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. + * When emitting the output, we need to iterate the hash one group at a time to build the output block, + * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * to an array, allowing us to build the output in O(N) instead. */ - private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable { + private static class NextValues implements Releasable { + private final BlockFactory blockFactory; +$if(long||double)$ + private final LongLongHash hashes; +$else$ + private final LongHash hashes; +$endif$ + private int[] selectedCounts = null; + private int[] ids = null; + private long extraMemoryUsed = 0; + + private NextValues(BlockFactory blockFactory) { + this.blockFactory = blockFactory; + this.hashes = new Long$if(long||double)$Long$endif$Hash(1, blockFactory.bigArrays()); + } + + void addValue(int groupId, $if(BytesRef)$int$else$$type$$endif$ v) { +$if(long)$ + hashes.add(groupId, v); +$elseif(double)$ + hashes.add(groupId, Double.doubleToLongBits(v)); +$elseif(int||BytesRef)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); +$elseif(float)$ + /* + * Encode the groupId and value into a single long - + * the top 32 bits for the group, the bottom 32 for the value. + */ + hashes.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$endif$ + } + + $if(BytesRef)$int$else$$type$$endif$ getValue(int index) { +$if(long)$ + return hashes.getKey2(ids[index]); +$elseif(double)$ + return Double.longBitsToDouble(hashes.getKey2(ids[index])); +$elseif(float)$ + long both = hashes.get(ids[index]); + return Float.intBitsToFloat((int) (both & 0xFFFFFFFFL)); +$elseif(BytesRef||int)$ + long both = hashes.get(ids[index]); + return (int) (both & 0xFFFFFFFFL); +$endif$ + } + + private void reserveBytesForIntArray(long numElements) { + long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES); + blockFactory.adjustBreaker(adjust); + extraMemoryUsed += adjust; + } + + private void prepareForEmitting(IntVector selected) { + if (hashes.size() == 0) { + return; + } + /* + * Get a count of all groups less than the maximum selected group. Count + * *downwards* so that we can flip the sign on all of the actually selected + * groups. Negative values in this array are always unselected groups. 
+ */ + int selectedCountsLen = selected.max() + 1; + reserveBytesForIntArray(selectedCountsLen); + this.selectedCounts = new int[selectedCountsLen]; + for (int id = 0; id < hashes.size(); id++) { +$if(long||double)$ + int group = (int) hashes.getKey1(id); +$elseif(float||int||BytesRef)$ + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); +$endif$ + if (group < selectedCounts.length) { + selectedCounts[group]--; + } + } + + /* + * Total the selected groups and turn the counts into the start index into a sort-of + * off-by-one running count. It's really the number of values that have been inserted + * into the results before starting on this group. Unselected groups will still + * have negative counts. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 + */ + int total = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } + + /* + * Build a list of ids to insert in order *and* convert the running + * count in selectedCounts[group] into the end index (exclusive) in + * ids for each group. + * Here we use the negative counts to signal that a group hasn't been + * selected and the id containing values for that group is ignored. + * + * For example, if + * | Group | Value Count | Selected | + * |-------|-------------|----------| + * | 0 | 3 | <- | + * | 1 | 1 | <- | + * | 2 | 2 | | + * | 3 | 1 | <- | + * | 4 | 4 | <- | + * + * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. + * The counts will end with 3, 4, -2, 5, 9. + */ + reserveBytesForIntArray(total); + + this.ids = new int[total]; + for (int id = 0; id < hashes.size(); id++) { +$if(long||double)$ + int group = (int) hashes.getKey1(id); +$elseif(float||int||BytesRef)$ + long both = hashes.get(id); + int group = (int) (both >>> Float.SIZE); +$endif$ + ids[selectedCounts[group]++] = id; + } + } + @Override public void close() { - releasable.close(); + Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed)); } } /** * State for a grouped {@code VALUES} aggregation. This implementation - * emphasizes collect-time performance over the performance of rendering - * results. That's good, but it's a pretty intensive emphasis, requiring - * an {@code O(n^2)} operation for collection to support a {@code O(1)} - * collector operation. But at least it's fairly simple. + * emphasizes collect-time performance over result rendering performance. + * The first value in each group is collected in the {@code firstValues} + * array, and subsequent values for each group are collected in {@code nextValues}. 
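+     * Values equal to the group's first value are skipped, so a group with a single distinct
+     * value never touches {@code nextValues}.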
*/ - public static class GroupingState implements GroupingAggregatorState { - private int maxGroupId = -1; + public static class GroupingState $if(BytesRef)$implements GroupingAggregatorState$else$extends AbstractArrayState$endif$ { private final BlockFactory blockFactory; -$if(long||double)$ - private final LongLongHash values; - -$elseif(BytesRef)$ - private final LongLongHash values; +$if(BytesRef)$ BytesRefHash bytes; - -$elseif(int||float)$ - private final LongHash values; - + private IntArray firstValues; +$else$ + $Type$Array firstValues; + private int maxGroupId = -1; $endif$ + private final NextValues nextValues; + private GroupingState(DriverContext driverContext) { - this.blockFactory = driverContext.blockFactory(); -$if(long||double)$ - values = new LongLongHash(1, driverContext.bigArrays()); -$elseif(BytesRef)$ - LongLongHash _values = null; +$if(BytesRef)$ BytesRefHash _bytes = null; + IntArray _firstValues = null; +$else$ + super(driverContext.bigArrays()); + $Type$Array _firstValues = null; +$endif$ + NextValues _nextValues = null; try { - _values = new LongLongHash(1, driverContext.bigArrays()); +$if(BytesRef)$ _bytes = new BytesRefHash(1, driverContext.bigArrays()); + _firstValues = driverContext.bigArrays().newIntArray(1, true); +$else$ + _firstValues = driverContext.bigArrays().new$Type$Array(1, true); +$endif$ + _nextValues = new NextValues(driverContext.blockFactory()); - values = _values; - bytes = _bytes; - - _values = null; +$if(BytesRef)$ + this.bytes = _bytes; _bytes = null; +$endif$ + this.firstValues = _firstValues; + _firstValues = null; + this.nextValues = _nextValues; + _nextValues = null; + this.blockFactory = driverContext.blockFactory(); } finally { - Releasables.closeExpectNoException(_values, _bytes); + Releasables.closeExpectNoException($if(BytesRef)$_bytes, $endif$_firstValues, _nextValues); } -$elseif(int||float)$ - values = new LongHash(1, driverContext.bigArrays()); -$endif$ } @Override @@ -268,33 +395,43 @@ $endif$ } $if(BytesRef)$ - void addValueOrdinal(int groupId, long valueOrdinal) { - values.add(groupId, valueOrdinal); - maxGroupId = Math.max(maxGroupId, groupId); + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from firstValues since ordinals are non-negative + } + + void addValueOrdinal(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + int current = firstValues.get(groupId) - 1; + if (current < 0) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (current != valueOrdinal) { + nextValues.addValue(groupId, valueOrdinal); + } + } else { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } } $endif$ void addValue(int groupId, $type$ v) { -$if(long)$ - values.add(groupId, v); -$elseif(double)$ - values.add(groupId, Double.doubleToLongBits(v)); -$elseif(BytesRef)$ - values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v))); -$elseif(int)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. - */ - values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL)); -$elseif(float)$ - /* - * Encode the groupId and value into a single long - - * the top 32 bits for the group, the bottom 32 for the value. 
- */ - values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL)); +$if(BytesRef)$ + int valueOrdinal = Math.toIntExact(BlockHash.hashOrdToGroup(bytes.add(v))); + addValueOrdinal(groupId, valueOrdinal); +$else$ + if (groupId > maxGroupId) { + firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); + firstValues.set(groupId, v); + trackGroupId(groupId); + maxGroupId = groupId; + } else if (hasValue(groupId) == false) { + firstValues.set(groupId, v); + trackGroupId(groupId); + } else if (firstValues.get(groupId) != v) { + nextValues.addValue(groupId, v); + } $endif$ - maxGroupId = Math.max(maxGroupId, groupId); } /** @@ -302,176 +439,96 @@ $endif$ * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { - return blockFactory.newConstantNullBlock(selected.getPositionCount()); - } - - try (var sorted = buildSorted(selected)) { + nextValues.prepareForEmitting(selected); $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) { - return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } else { - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); - } -$else$ - return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids); -$endif$ + if (OrdinalBytesRefBlock.isDense(firstValues.size() + nextValues.hashes.size(), bytes.size())) { + return buildOrdinalOutputBlock(blockFactory, selected); + } else { + return buildOutputBlock(blockFactory, selected); } - } - - private Sorted buildSorted(IntVector selected) { - long selectedCountsSize = 0; - long idsSize = 0; - Sorted sorted = null; - try { - /* - * Get a count of all groups less than the maximum selected group. Count - * *downwards* so that we can flip the sign on all of the actually selected - * groups. Negative values in this array are always unselected groups. - */ - int selectedCountsLen = selected.max() + 1; - long adjust = RamUsageEstimator.alignObjectSize( - RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES - ); - blockFactory.adjustBreaker(adjust); - selectedCountsSize = adjust; - int[] selectedCounts = new int[selectedCountsLen]; - for (int id = 0; id < values.size(); id++) { -$if(long||BytesRef||double)$ - int group = (int) values.getKey1(id); -$elseif(float||int)$ - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); +$else$ + return buildOutputBlock(blockFactory, selected); $endif$ - if (group < selectedCounts.length) { - selectedCounts[group]--; - } - } - - /* - * Total the selected groups and turn the counts into the start index into a sort-of - * off-by-one running count. It's really the number of values that have been inserted - * into the results before starting on this group. Unselected groups will still - * have negative counts. 
- * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 - */ - int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; - } - - /* - * Build a list of ids to insert in order *and* convert the running - * count in selectedCounts[group] into the end index (exclusive) in - * ids for each group. - * Here we use the negative counts to signal that a group hasn't been - * selected and the id containing values for that group is ignored. - * - * For example, if - * | Group | Value Count | Selected | - * |-------|-------------|----------| - * | 0 | 3 | <- | - * | 1 | 1 | <- | - * | 2 | 2 | | - * | 3 | 1 | <- | - * | 4 | 4 | <- | - * - * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5. - * The counts will end with 3, 4, -2, 5, 9. - */ - adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES); - blockFactory.adjustBreaker(adjust); - idsSize = adjust; - int[] ids = new int[total]; - for (int id = 0; id < values.size(); id++) { - $if(long||BytesRef||double)$ - int group = (int) values.getKey1(id); - $elseif(float||int)$ - long both = values.get(id); - int group = (int) (both >>> Float.SIZE); - $endif$ - if (group < selectedCounts.length && selectedCounts[group] >= 0) { - ids[selectedCounts[group]++] = id; - } - } - final long totalMemoryUsed = selectedCountsSize + idsSize; - sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids); - return sorted; - } finally { - if (sorted == null) { - blockFactory.adjustBreaker(-selectedCountsSize - idsSize); - } - } } - Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { /* * Insert the ids in order. */ $if(BytesRef)$ BytesRef scratch = new BytesRef(); $endif$ + final int[] nextValueCounts = nextValues.selectedCounts; try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { - int start = 0; + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$)); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$)); - } - builder.endPositionEntry(); +$if(BytesRef)$ + int firstValue = group >= firstValues.size() ? -1 : firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } +$else$ + if (group > maxGroupId || hasValue(group) == false) { + builder.appendNull(); + continue; + } + $type$ firstValue = firstValues.get(group); +$endif$ + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.append$Type$($if(BytesRef)$bytes.get(firstValue, scratch)$else$firstValue$endif$); + } else { + builder.beginPositionEntry(); + builder.append$Type$($if(BytesRef)$bytes.get(firstValue, scratch)$else$firstValue$endif$); + // append values from the nextValues + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + var nextValue = nextValues.getValue(i); + builder.append$Type$($if(BytesRef)$bytes.get(nextValue, scratch)$else$nextValue$endif$); } + builder.endPositionEntry(); + nextValuesStart = nextValuesEnd; } - start = end; } return builder.build(); } } $if(BytesRef)$ - Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { + Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected) { BytesRefVector dict = null; IntBlock ordinals = null; BytesRefBlock result = null; var dictArray = bytes.takeBytesRefsOwnership(); bytes = null; // transfer ownership to dictArray - try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) { - int start = 0; + int estimateSize = Math.toIntExact(firstValues.size() + nextValues.hashes.size()); + final int[] nextValueCounts = nextValues.selectedCounts; + try (var builder = blockFactory.newIntBlockBuilder(estimateSize)) { + int nextValuesStart = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { - builder.beginPositionEntry(); - for (int i = start; i < end; i++) { - builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); - } - builder.endPositionEntry(); + if (firstValues.size() < group) { + builder.appendNull(); + continue; + } + int firstValue = firstValues.get(group) - 1; + if (firstValue < 0) { + builder.appendNull(); + continue; + } + final int nextValuesEnd = nextValueCounts != null ? 
nextValueCounts[group] : nextValuesStart; + if (nextValuesEnd == nextValuesStart) { + builder.appendInt(firstValue); + } else { + builder.beginPositionEntry(); + builder.appendInt(firstValue); + for (int i = nextValuesStart; i < nextValuesEnd; i++) { + builder.appendInt(nextValues.getValue(i)); } + builder.endPositionEntry(); } - start = end; + nextValuesStart = nextValuesEnd; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -486,34 +543,9 @@ $if(BytesRef)$ } $endif$ - $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) { -$if(BytesRef)$ - return bytes.get(values.getKey2(valueId), scratch); -$elseif(long)$ - return values.getKey2(valueId); -$elseif(double)$ - return Double.longBitsToDouble(values.getKey2(valueId)); -$elseif(float)$ - long both = values.get(valueId); - return Float.intBitsToFloat((int) both); -$elseif(int)$ - long both = values.get(valueId); - return (int) both; -$endif$ - } - - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from nulls on the values block - } - @Override public void close() { -$if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); -$else$ - Releasables.closeExpectNoException(values); -$endif$ + Releasables.closeExpectNoException($if(BytesRef)$bytes$else$super::close$endif$, firstValues, nextValues); } } } From bab40942dd84e0cd0ab24111901d02c1b2308585 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 23 Jul 2025 14:39:11 -0700 Subject: [PATCH 5/7] fix seen --- .../aggregation/ValuesBytesRefAggregator.java | 11 ++--- .../aggregation/ValuesDoubleAggregator.java | 30 +++++++++++-- .../aggregation/ValuesFloatAggregator.java | 30 +++++++++++-- .../aggregation/ValuesIntAggregator.java | 30 +++++++++++-- .../aggregation/ValuesLongAggregator.java | 30 +++++++++++-- .../aggregation/X-ValuesAggregator.java.st | 43 +++++++++++++++---- 6 files changed, 144 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 2f6692f0359dc..c187a4218a674 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -289,11 +290,6 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive blocks[offset] = toBlock(driverContext.blockFactory(), selected); } - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from firstValues since ordinals are non-negative - } - void addValueOrdinal(int groupId, int valueOrdinal) { if (groupId < firstValues.size()) { int current = firstValues.get(groupId) - 1; @@ -313,6 +309,11 @@ void addValue(int groupId, BytesRef v) { addValueOrdinal(groupId, valueOrdinal); } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values 
from firstValues since ordinals are non-negative + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index d720c56822b99..40542e51b3639 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -232,18 +233,18 @@ public void close() { * The first value in each group is collected in the {@code firstValues} * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState extends AbstractArrayState { + public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; DoubleArray firstValues; + private BitArray seen; private int maxGroupId = -1; private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - super(driverContext.bigArrays()); DoubleArray _firstValues = null; NextValues _nextValues = null; try { - _firstValues = driverContext.bigArrays().newDoubleArray(1, true); + _firstValues = driverContext.bigArrays().newDoubleArray(1, false); _nextValues = new NextValues(driverContext.blockFactory()); this.firstValues = _firstValues; @@ -265,6 +266,12 @@ void addValue(int groupId, double v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + if (maxGroupId >= 0) { + seen.fill(0, maxGroupId, true); + } + } trackGroupId(groupId); maxGroupId = groupId; } else if (hasValue(groupId) == false) { @@ -275,6 +282,21 @@ void addValue(int groupId, double v) { } } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } + + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); + } + } + + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. 
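To make the new null handling concrete: the seen bitset is allocated lazily, since as long as group ids arrive without gaps every group in [0, maxGroupId] is known to have a value; the first gap triggers allocation, the dense prefix is backfilled, and from then on incoming groups are tracked explicitly. A simplified, standalone sketch of that scheme, using java.util.BitSet in place of BitArray/BigArrays and the maxGroupId + 1 backfill bound that a later commit in this series settles on (class and method names here are illustrative, not the real ESQL types):

    import java.util.BitSet;

    class LazySeenGroups {
        private BitSet seen;          // null while every group in [0, maxGroupId] has a value
        private int maxGroupId = -1;

        void markHasValue(int groupId) {
            if (groupId > maxGroupId) {
                // A gap means groups maxGroupId+1 .. groupId-1 have no value yet,
                // so the "everything is seen" assumption no longer holds.
                if (seen == null && groupId > maxGroupId + 1) {
                    seen = new BitSet(groupId + 1);
                    seen.set(0, maxGroupId + 1);    // backfill the dense prefix; end index is exclusive
                }
                track(groupId);
                maxGroupId = groupId;
            } else if (hasValue(groupId) == false) {
                track(groupId);
            }
        }

        boolean hasValue(int groupId) {
            return seen == null || seen.get(groupId);
        }

        private void track(int groupId) {
            if (seen != null) {
                seen.set(groupId);
            }
        }
    }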
@@ -319,7 +341,7 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { @Override public void close() { - Releasables.closeExpectNoException(super::close, firstValues, nextValues); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index a25a2dee2776a..a7c23bcf01a78 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -239,18 +240,18 @@ public void close() { * The first value in each group is collected in the {@code firstValues} * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState extends AbstractArrayState { + public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; FloatArray firstValues; + private BitArray seen; private int maxGroupId = -1; private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - super(driverContext.bigArrays()); FloatArray _firstValues = null; NextValues _nextValues = null; try { - _firstValues = driverContext.bigArrays().newFloatArray(1, true); + _firstValues = driverContext.bigArrays().newFloatArray(1, false); _nextValues = new NextValues(driverContext.blockFactory()); this.firstValues = _firstValues; @@ -272,6 +273,12 @@ void addValue(int groupId, float v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + if (maxGroupId >= 0) { + seen.fill(0, maxGroupId, true); + } + } trackGroupId(groupId); maxGroupId = groupId; } else if (hasValue(groupId) == false) { @@ -282,6 +289,21 @@ void addValue(int groupId, float v) { } } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } + + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); + } + } + + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. 
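As the GroupingState javadoc above describes, each group keeps its first value inline in firstValues and only spills additional distinct values into the nextValues hash, so the common single-value-per-group case never touches the hash. A minimal sketch of that layout, with plain arrays and a Map standing in for BigArrays and LongLongHash and a boolean[] standing in for the seen tracking (all names are illustrative, not the real classes):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.LinkedHashSet;
    import java.util.Map;
    import java.util.Set;

    class FirstValueThenHash {
        private long[] firstValues = new long[0];
        private boolean[] hasFirst = new boolean[0];
        // stand-in for the LongLongHash keyed by (groupId, value) pairs
        private final Map<Integer, Set<Long>> nextValues = new HashMap<>();

        void addValue(int groupId, long v) {
            if (groupId >= firstValues.length) {
                firstValues = Arrays.copyOf(firstValues, groupId + 1);
                hasFirst = Arrays.copyOf(hasFirst, groupId + 1);
            }
            if (hasFirst[groupId] == false) {
                hasFirst[groupId] = true;             // fast path: single-value groups stop here
                firstValues[groupId] = v;
            } else if (firstValues[groupId] != v) {
                nextValues.computeIfAbsent(groupId, g -> new LinkedHashSet<>()).add(v);
            }                                         // duplicates of the first value are dropped
        }

        long[] valuesFor(int groupId) {
            if (groupId >= hasFirst.length || hasFirst[groupId] == false) {
                return new long[0];                   // no values: emitted as a null position
            }
            Set<Long> rest = nextValues.getOrDefault(groupId, Set.of());
            long[] out = new long[1 + rest.size()];
            out[0] = firstValues[groupId];
            int i = 1;
            for (long r : rest) {
                out[i++] = r;
            }
            return out;
        }
    }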
@@ -326,7 +348,7 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { @Override public void close() { - Releasables.closeExpectNoException(super::close, firstValues, nextValues); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index abfbf01e75b45..2d1a559662dea 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -239,18 +240,18 @@ public void close() { * The first value in each group is collected in the {@code firstValues} * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState extends AbstractArrayState { + public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; IntArray firstValues; + private BitArray seen; private int maxGroupId = -1; private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - super(driverContext.bigArrays()); IntArray _firstValues = null; NextValues _nextValues = null; try { - _firstValues = driverContext.bigArrays().newIntArray(1, true); + _firstValues = driverContext.bigArrays().newIntArray(1, false); _nextValues = new NextValues(driverContext.blockFactory()); this.firstValues = _firstValues; @@ -272,6 +273,12 @@ void addValue(int groupId, int v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + if (maxGroupId >= 0) { + seen.fill(0, maxGroupId, true); + } + } trackGroupId(groupId); maxGroupId = groupId; } else if (hasValue(groupId) == false) { @@ -282,6 +289,21 @@ void addValue(int groupId, int v) { } } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } + + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); + } + } + + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. 
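The nextValues side of the emit path avoids a per-group scan of the hash by what is effectively a counting sort keyed by group id, as the comments around the selectedCounts/prefix-sum logic earlier in this series explain: one pass counts each group's extra values, a prefix sum turns the counts into start offsets, and a second pass scatters every entry into its group's slice. A compact sketch of that remapping over a plain int[] of group ids (illustrative only; the real code works on LongLongHash entry ids, restricts itself to the selected groups, and accounts its arrays against the circuit breaker):

    /** Buckets entries by group id in O(N): count, prefix-sum, scatter. */
    final class GroupedIds {
        final int[] starts;   // starts[g] .. starts[g + 1] bounds group g's slice in ids
        final int[] ids;      // entry indexes, grouped contiguously by group id

        private GroupedIds(int[] starts, int[] ids) {
            this.starts = starts;
            this.ids = ids;
        }

        /** groups[i] is the group id of entry i; maxGroup is the largest group id. */
        static GroupedIds build(int[] groups, int maxGroup) {
            int[] starts = new int[maxGroup + 2];
            for (int group : groups) {
                starts[group + 1]++;                  // count per group, shifted by one slot
            }
            for (int g = 1; g < starts.length; g++) {
                starts[g] += starts[g - 1];           // prefix sum: counts become start offsets
            }
            int[] cursor = starts.clone();
            int[] ids = new int[groups.length];
            for (int i = 0; i < groups.length; i++) {
                ids[cursor[groups[i]]++] = i;         // scatter entry i into its group's slice
            }
            return new GroupedIds(starts, ids);
        }
    }

Group g's extra values then occupy ids[starts[g]] .. ids[starts[g + 1] - 1], the contiguous range the output builder walks when it appends a multivalue position.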
@@ -326,7 +348,7 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { @Override public void close() { - Releasables.closeExpectNoException(super::close, firstValues, nextValues); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 6e51881e2df28..bbb21280c02d5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -232,18 +233,18 @@ public void close() { * The first value in each group is collected in the {@code firstValues} * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState extends AbstractArrayState { + public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; LongArray firstValues; + private BitArray seen; private int maxGroupId = -1; private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - super(driverContext.bigArrays()); LongArray _firstValues = null; NextValues _nextValues = null; try { - _firstValues = driverContext.bigArrays().newLongArray(1, true); + _firstValues = driverContext.bigArrays().newLongArray(1, false); _nextValues = new NextValues(driverContext.blockFactory()); this.firstValues = _firstValues; @@ -265,6 +266,12 @@ void addValue(int groupId, long v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + if (maxGroupId >= 0) { + seen.fill(0, maxGroupId, true); + } + } trackGroupId(groupId); maxGroupId = groupId; } else if (hasValue(groupId) == false) { @@ -275,6 +282,21 @@ void addValue(int groupId, long v) { } } + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } + + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); + } + } + + private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. 
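Pulling those pieces together, the buildOutputBlock loop adjusted in the hunks above decides per selected group between three shapes: a null position when the group was never seen, the inline first value when the group is single-valued, and a multivalue entry that starts with the first value and continues with the group's slice of next values. A schematic sketch of that decision, with Lists standing in for the block builder (a null list entry plays the role of appendNull, and the real builder additionally distinguishes a single append from a beginPositionEntry/endPositionEntry pair):

    import java.util.ArrayList;
    import java.util.List;

    final class EmitSketch {
        /**
         * firstValue[g] / hasValue[g] describe the inline first value per group;
         * extraStarts[g] .. extraStarts[g + 1] bounds group g's slice of extraValues,
         * in the layout produced by the counting-sort remap sketched above.
         */
        static List<List<Long>> emit(int[] selected, long[] firstValue, boolean[] hasValue,
                                     int[] extraStarts, long[] extraValues) {
            List<List<Long>> positions = new ArrayList<>(selected.length);
            for (int group : selected) {
                if (group >= hasValue.length || hasValue[group] == false) {
                    positions.add(null);              // group never saw a value: null position
                    continue;
                }
                List<Long> values = new ArrayList<>();
                values.add(firstValue[group]);        // the inline fast-path value comes first
                for (int i = extraStarts[group]; i < extraStarts[group + 1]; i++) {
                    values.add(extraValues[i]);       // then the values spilled to the hash, if any
                }
                positions.add(values);
            }
            return positions;
        }
    }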
@@ -319,7 +341,7 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) { @Override public void close() { - Releasables.closeExpectNoException(super::close, firstValues, nextValues); + Releasables.closeExpectNoException(seen, firstValues, nextValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 72e9b6fab893f..b958b7bf7f187 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -11,6 +11,7 @@ package org.elasticsearch.compute.aggregation; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.LongLongHash; @@ -346,13 +347,14 @@ $endif$ * The first value in each group is collected in the {@code firstValues} * array, and subsequent values for each group are collected in {@code nextValues}. */ - public static class GroupingState $if(BytesRef)$implements GroupingAggregatorState$else$extends AbstractArrayState$endif$ { + public static class GroupingState implements GroupingAggregatorState { private final BlockFactory blockFactory; $if(BytesRef)$ BytesRefHash bytes; private IntArray firstValues; $else$ $Type$Array firstValues; + private BitArray seen; private int maxGroupId = -1; $endif$ private final NextValues nextValues; @@ -362,7 +364,6 @@ $if(BytesRef)$ BytesRefHash _bytes = null; IntArray _firstValues = null; $else$ - super(driverContext.bigArrays()); $Type$Array _firstValues = null; $endif$ NextValues _nextValues = null; @@ -371,7 +372,7 @@ $if(BytesRef)$ _bytes = new BytesRefHash(1, driverContext.bigArrays()); _firstValues = driverContext.bigArrays().newIntArray(1, true); $else$ - _firstValues = driverContext.bigArrays().new$Type$Array(1, true); + _firstValues = driverContext.bigArrays().new$Type$Array(1, false); $endif$ _nextValues = new NextValues(driverContext.blockFactory()); @@ -395,11 +396,6 @@ $endif$ } $if(BytesRef)$ - @Override - public void enableGroupIdTracking(SeenGroupIds seen) { - // we figure out seen values from firstValues since ordinals are non-negative - } - void addValueOrdinal(int groupId, int valueOrdinal) { if (groupId < firstValues.size()) { int current = firstValues.get(groupId) - 1; @@ -423,6 +419,12 @@ $else$ if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + if (seen == null && groupId > maxGroupId + 1) { + seen = new BitArray(groupId + 1, blockFactory.bigArrays()); + if (maxGroupId >= 0) { + seen.fill(0, maxGroupId, true); + } + } trackGroupId(groupId); maxGroupId = groupId; } else if (hasValue(groupId) == false) { @@ -434,6 +436,29 @@ $else$ $endif$ } +$if(BytesRef)$ + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we figure out seen values from firstValues since ordinals are non-negative + } + +$else$ + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { + // we track the seen values manually + } + + private void trackGroupId(int groupId) { + if (seen != null) { + seen.set(groupId); + } + } + + 
private boolean hasValue(int groupId) { + return seen == null || seen.get(groupId); + } + +$endif$ /** * Builds a {@link Block} with the unique values collected for the {@code #selected} * groups. This is the implementation of the final and intermediate results of the agg. @@ -545,7 +570,7 @@ $endif$ @Override public void close() { - Releasables.closeExpectNoException($if(BytesRef)$bytes$else$super::close$endif$, firstValues, nextValues); + Releasables.closeExpectNoException($if(BytesRef)$bytes$else$seen$endif$, firstValues, nextValues); } } } From 239c90d3c3c322b02b946422d17e586126b8f7b4 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 24 Jul 2025 10:11:26 -0700 Subject: [PATCH 6/7] javadocs --- .../compute/aggregation/ValuesDoubleAggregator.java | 7 ++++--- .../compute/aggregation/ValuesFloatAggregator.java | 7 ++++--- .../compute/aggregation/ValuesIntAggregator.java | 13 +++++++++---- .../compute/aggregation/ValuesLongAggregator.java | 7 ++++--- .../compute/aggregation/X-ValuesAggregator.java.st | 13 +++++++++---- 5 files changed, 30 insertions(+), 17 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 40542e51b3639..1ed1fc48281fa 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -266,11 +266,12 @@ void addValue(int groupId, double v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. if (seen == null && groupId > maxGroupId + 1) { seen = new BitArray(groupId + 1, blockFactory.bigArrays()); - if (maxGroupId >= 0) { - seen.fill(0, maxGroupId, true); - } + seen.fill(0, maxGroupId + 1, true); } trackGroupId(groupId); maxGroupId = groupId; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index a7c23bcf01a78..9ee895ebb7150 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -273,11 +273,12 @@ void addValue(int groupId, float v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. 
if (seen == null && groupId > maxGroupId + 1) { seen = new BitArray(groupId + 1, blockFactory.bigArrays()); - if (maxGroupId >= 0) { - seen.fill(0, maxGroupId, true); - } + seen.fill(0, maxGroupId + 1, true); } trackGroupId(groupId); maxGroupId = groupId; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 2d1a559662dea..be64fd0d0bbad 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -120,7 +120,7 @@ public void close() { /** * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. * When emitting the output, we need to iterate the hash one group at a time to build the output block, - * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id * to an array, allowing us to build the output in O(N) instead. */ private static class NextValues implements Releasable { @@ -273,11 +273,12 @@ void addValue(int groupId, int v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. if (seen == null && groupId > maxGroupId + 1) { seen = new BitArray(groupId + 1, blockFactory.bigArrays()); - if (maxGroupId >= 0) { - seen.fill(0, maxGroupId, true); - } + seen.fill(0, maxGroupId + 1, true); } trackGroupId(groupId); maxGroupId = groupId; @@ -300,6 +301,10 @@ private void trackGroupId(int groupId) { } } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ private boolean hasValue(int groupId) { return seen == null || seen.get(groupId); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index bbb21280c02d5..bea612338d532 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -266,11 +266,12 @@ void addValue(int groupId, long v) { if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. 
if (seen == null && groupId > maxGroupId + 1) { seen = new BitArray(groupId + 1, blockFactory.bigArrays()); - if (maxGroupId >= 0) { - seen.fill(0, maxGroupId, true); - } + seen.fill(0, maxGroupId + 1, true); } trackGroupId(groupId); maxGroupId = groupId; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index b958b7bf7f187..4935182065df6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -194,7 +194,7 @@ $endif$ /** * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. * When emitting the output, we need to iterate the hash one group at a time to build the output block, - * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id * to an array, allowing us to build the output in O(N) instead. */ private static class NextValues implements Releasable { @@ -419,11 +419,12 @@ $else$ if (groupId > maxGroupId) { firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1); firstValues.set(groupId, v); + // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating + // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset, + // fill the groups that have values, and begin tracking incoming groups. if (seen == null && groupId > maxGroupId + 1) { seen = new BitArray(groupId + 1, blockFactory.bigArrays()); - if (maxGroupId >= 0) { - seen.fill(0, maxGroupId, true); - } + seen.fill(0, maxGroupId + 1, true); } trackGroupId(groupId); maxGroupId = groupId; @@ -454,6 +455,10 @@ $else$ } } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. 
+ */ private boolean hasValue(int groupId) { return seen == null || seen.get(groupId); } From e390d42e2fd7f935e8f8f46c61e7b28ad0149c7b Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 24 Jul 2025 10:26:06 -0700 Subject: [PATCH 7/7] ctor --- .../aggregation/ValuesBytesRefAggregator.java | 26 ++++++--------- .../aggregation/ValuesDoubleAggregator.java | 25 ++++++++------- .../aggregation/ValuesFloatAggregator.java | 25 ++++++++------- .../aggregation/ValuesIntAggregator.java | 19 +++++------ .../aggregation/ValuesLongAggregator.java | 25 ++++++++------- .../aggregation/X-ValuesAggregator.java.st | 32 ++++++------------- 6 files changed, 67 insertions(+), 85 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index c187a4218a674..120a77fda92c6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -138,7 +138,7 @@ public void close() { /** * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. * When emitting the output, we need to iterate the hash one group at a time to build the output block, - * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id * to an array, allowing us to build the output in O(N) instead. */ private static class NextValues implements Releasable { @@ -265,23 +265,17 @@ public static class GroupingState implements GroupingAggregatorState { private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - BytesRefHash _bytes = null; - IntArray _firstValues = null; - NextValues _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + boolean success = false; try { - _bytes = new BytesRefHash(1, driverContext.bigArrays()); - _firstValues = driverContext.bigArrays().newIntArray(1, true); - _nextValues = new NextValues(driverContext.blockFactory()); - - this.bytes = _bytes; - _bytes = null; - this.firstValues = _firstValues; - _firstValues = null; - this.nextValues = _nextValues; - _nextValues = null; - this.blockFactory = driverContext.blockFactory(); + this.bytes = new BytesRefHash(1, driverContext.bigArrays()); + this.firstValues = driverContext.bigArrays().newIntArray(1, true); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_bytes, _firstValues, _nextValues); + if (success == false) { + this.close(); + } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index 1ed1fc48281fa..e9c217ff172b7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -120,7 +120,7 @@ public void close() { /** * Values after the first in each group are collected in a 
hash, keyed by the pair of groupId and value. * When emitting the output, we need to iterate the hash one group at a time to build the output block, - * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id * to an array, allowing us to build the output in O(N) instead. */ private static class NextValues implements Releasable { @@ -241,19 +241,16 @@ public static class GroupingState implements GroupingAggregatorState { private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - DoubleArray _firstValues = null; - NextValues _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + boolean success = false; try { - _firstValues = driverContext.bigArrays().newDoubleArray(1, false); - _nextValues = new NextValues(driverContext.blockFactory()); - - this.firstValues = _firstValues; - _firstValues = null; - this.nextValues = _nextValues; - _nextValues = null; - this.blockFactory = driverContext.blockFactory(); + this.firstValues = driverContext.bigArrays().newDoubleArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_firstValues, _nextValues); + if (success == false) { + this.close(); + } } } @@ -294,6 +291,10 @@ private void trackGroupId(int groupId) { } } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ private boolean hasValue(int groupId) { return seen == null || seen.get(groupId); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 9ee895ebb7150..cd8ed8a854763 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -120,7 +120,7 @@ public void close() { /** * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. * When emitting the output, we need to iterate the hash one group at a time to build the output block, - * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id * to an array, allowing us to build the output in O(N) instead. 
*/ private static class NextValues implements Releasable { @@ -248,19 +248,16 @@ public static class GroupingState implements GroupingAggregatorState { private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - FloatArray _firstValues = null; - NextValues _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + boolean success = false; try { - _firstValues = driverContext.bigArrays().newFloatArray(1, false); - _nextValues = new NextValues(driverContext.blockFactory()); - - this.firstValues = _firstValues; - _firstValues = null; - this.nextValues = _nextValues; - _nextValues = null; - this.blockFactory = driverContext.blockFactory(); + this.firstValues = driverContext.bigArrays().newFloatArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_firstValues, _nextValues); + if (success == false) { + this.close(); + } } } @@ -301,6 +298,10 @@ private void trackGroupId(int groupId) { } } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ private boolean hasValue(int groupId) { return seen == null || seen.get(groupId); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index be64fd0d0bbad..f6c05104f749b 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -248,19 +248,16 @@ public static class GroupingState implements GroupingAggregatorState { private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - IntArray _firstValues = null; - NextValues _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + boolean success = false; try { - _firstValues = driverContext.bigArrays().newIntArray(1, false); - _nextValues = new NextValues(driverContext.blockFactory()); - - this.firstValues = _firstValues; - _firstValues = null; - this.nextValues = _nextValues; - _nextValues = null; - this.blockFactory = driverContext.blockFactory(); + this.firstValues = driverContext.bigArrays().newIntArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_firstValues, _nextValues); + if (success == false) { + this.close(); + } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index bea612338d532..93bfdce654d1e 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -120,7 +120,7 @@ public void close() { /** * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value. 
* When emitting the output, we need to iterate the hash one group at a time to build the output block, - * which can require O(N^2). To avoid this, we compute the counts for each group and remap the hash id + * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id * to an array, allowing us to build the output in O(N) instead. */ private static class NextValues implements Releasable { @@ -241,19 +241,16 @@ public static class GroupingState implements GroupingAggregatorState { private final NextValues nextValues; private GroupingState(DriverContext driverContext) { - LongArray _firstValues = null; - NextValues _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + boolean success = false; try { - _firstValues = driverContext.bigArrays().newLongArray(1, false); - _nextValues = new NextValues(driverContext.blockFactory()); - - this.firstValues = _firstValues; - _firstValues = null; - this.nextValues = _nextValues; - _nextValues = null; - this.blockFactory = driverContext.blockFactory(); + this.firstValues = driverContext.bigArrays().newLongArray(1, false); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException(_firstValues, _nextValues); + if (success == false) { + this.close(); + } } } @@ -294,6 +291,10 @@ private void trackGroupId(int groupId) { } } + /** + * Returns true if the group has a value in firstValues; having a value in nextValues is optional. + * Returns false if the group does not have values in either firstValues or nextValues. + */ private boolean hasValue(int groupId) { return seen == null || seen.get(groupId); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 4935182065df6..017c5540c98f0 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -360,33 +360,21 @@ $endif$ private final NextValues nextValues; private GroupingState(DriverContext driverContext) { -$if(BytesRef)$ - BytesRefHash _bytes = null; - IntArray _firstValues = null; -$else$ - $Type$Array _firstValues = null; -$endif$ - NextValues _nextValues = null; + this.blockFactory = driverContext.blockFactory(); + boolean success = false; try { $if(BytesRef)$ - _bytes = new BytesRefHash(1, driverContext.bigArrays()); - _firstValues = driverContext.bigArrays().newIntArray(1, true); + this.bytes = new BytesRefHash(1, driverContext.bigArrays()); + this.firstValues = driverContext.bigArrays().newIntArray(1, true); $else$ - _firstValues = driverContext.bigArrays().new$Type$Array(1, false); + this.firstValues = driverContext.bigArrays().new$Type$Array(1, false); $endif$ - _nextValues = new NextValues(driverContext.blockFactory()); - -$if(BytesRef)$ - this.bytes = _bytes; - _bytes = null; -$endif$ - this.firstValues = _firstValues; - _firstValues = null; - this.nextValues = _nextValues; - _nextValues = null; - this.blockFactory = driverContext.blockFactory(); + this.nextValues = new NextValues(driverContext.blockFactory()); + success = true; } finally { - Releasables.closeExpectNoException($if(BytesRef)$_bytes, $endif$_firstValues, _nextValues); + if (success == false) { + this.close(); + } } }
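The constructor rewrite in this last commit drops the temporary-variable-and-null-out dance in favor of a success flag: fields are assigned directly, and if any allocation throws, the finally block calls close(), which therefore has to tolerate fields that are still null. A generic sketch of the same pattern, using AutoCloseable and Supplier placeholders rather than the real ESQL resources:

    import java.util.function.Supplier;

    final class SafeCtor implements AutoCloseable {
        private final AutoCloseable first;
        private final AutoCloseable second;

        SafeCtor(Supplier<AutoCloseable> newFirst, Supplier<AutoCloseable> newSecond) {
            boolean success = false;
            try {
                this.first = newFirst.get();
                this.second = newSecond.get();   // if this throws, the finally block cleans up
                success = true;
            } finally {
                if (success == false) {
                    close();                     // releases whatever was already allocated
                }
            }
        }

        @Override
        public void close() {
            // Tolerate partial initialization: fields that were never assigned are null here.
            closeUnchecked(second);
            closeUnchecked(first);
        }

        private static void closeUnchecked(AutoCloseable resource) {
            if (resource == null) {
                return;
            }
            try {
                resource.close();
            } catch (Exception e) {
                throw new RuntimeException(e);   // the real code delegates this to Releasables.closeExpectNoException
            }
        }
    }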