Skip to content

Commit bf0a02c

Browse files
committed
ES|QL: Add PRESENT ES|QL function
- Optimize AggregatorFunctions. Part of #131069.
1 parent ec176ee commit bf0a02c

File tree

2 files changed

+42
-50
lines changed

2 files changed

+42
-50
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PresentAggregatorFunction.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.compute.operator.DriverContext;
1616

1717
import java.util.List;
18+
import java.util.concurrent.atomic.AtomicBoolean;
1819

1920
public class PresentAggregatorFunction implements AggregatorFunction {
2021
public static AggregatorFunctionSupplier supplier() {
@@ -47,22 +48,21 @@ public String describe() {
4748
}
4849

4950
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
50-
new IntermediateStateDesc("present", ElementType.BOOLEAN),
51-
new IntermediateStateDesc("seen", ElementType.BOOLEAN)
51+
new IntermediateStateDesc("present", ElementType.BOOLEAN)
5252
);
5353

5454
public static List<IntermediateStateDesc> intermediateStateDesc() {
5555
return INTERMEDIATE_STATE_DESC;
5656
}
5757

58-
private final BooleanState state;
58+
private final AtomicBoolean state;
5959
private final List<Integer> channels;
6060

6161
public static PresentAggregatorFunction create(List<Integer> inputChannels) {
62-
return new PresentAggregatorFunction(inputChannels, new BooleanState(false));
62+
return new PresentAggregatorFunction(inputChannels, new AtomicBoolean(false));
6363
}
6464

65-
private PresentAggregatorFunction(List<Integer> channels, BooleanState state) {
65+
private PresentAggregatorFunction(List<Integer> channels, AtomicBoolean state) {
6666
this.channels = channels;
6767
this.state = state;
6868
}
@@ -79,7 +79,6 @@ private int blockIndex() {
7979
@Override
8080
public void addRawInput(Page page, BooleanVector mask) {
8181
Block block = page.getBlock(blockIndex());
82-
BooleanState state = this.state;
8382
boolean present;
8483
if (mask.isConstant()) {
8584
if (mask.getBoolean(0) == false) {
@@ -89,7 +88,7 @@ public void addRawInput(Page page, BooleanVector mask) {
8988
} else {
9089
present = presentMasked(block, mask);
9190
}
92-
state.booleanValue(present);
91+
this.state.set(present);
9392
}
9493

9594
private boolean presentMasked(Block block, BooleanVector mask) {
@@ -111,20 +110,20 @@ public void addIntermediateInput(Page page) {
111110
return;
112111
}
113112
BooleanVector present = page.<BooleanBlock>getBlock(channels.get(0)).asVector();
114-
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
115113
assert present.getPositionCount() == 1;
116-
assert present.getPositionCount() == seen.getPositionCount();
117-
state.booleanValue(state.booleanValue() || present.getBoolean(0));
114+
if (present.getBoolean(0)) {
115+
state.set(true);
116+
}
118117
}
119118

120119
@Override
121120
public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
122-
state.toIntermediate(blocks, offset, driverContext);
121+
evaluateFinal(blocks, offset, driverContext);
123122
}
124123

125124
@Override
126125
public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
127-
blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1);
126+
blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.get(), 1);
128127
}
129128

130129
@Override
@@ -137,7 +136,5 @@ public String toString() {
137136
}
138137

139138
@Override
140-
public void close() {
141-
state.close();
142-
}
139+
public void close() {}
143140
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/PresentGroupingAggregatorFunction.java

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.compute.aggregation;
99

10+
import org.elasticsearch.common.util.BitArray;
1011
import org.elasticsearch.compute.data.Block;
1112
import org.elasticsearch.compute.data.BooleanBlock;
1213
import org.elasticsearch.compute.data.BooleanVector;
@@ -22,23 +23,22 @@
2223
public class PresentGroupingAggregatorFunction implements GroupingAggregatorFunction {
2324

2425
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
25-
new IntermediateStateDesc("present", ElementType.BOOLEAN),
26-
new IntermediateStateDesc("seen", ElementType.BOOLEAN)
26+
new IntermediateStateDesc("present", ElementType.BOOLEAN)
2727
);
2828

29-
private final BooleanArrayState state;
29+
private final BitArray state;
3030
private final List<Integer> channels;
3131
private final DriverContext driverContext;
3232

3333
public static PresentGroupingAggregatorFunction create(DriverContext driverContext, List<Integer> inputChannels) {
34-
return new PresentGroupingAggregatorFunction(inputChannels, new BooleanArrayState(driverContext.bigArrays(), false), driverContext);
34+
return new PresentGroupingAggregatorFunction(inputChannels, new BitArray(1, driverContext.bigArrays()), driverContext);
3535
}
3636

3737
public static List<IntermediateStateDesc> intermediateStateDesc() {
3838
return INTERMEDIATE_STATE_DESC;
3939
}
4040

41-
private PresentGroupingAggregatorFunction(List<Integer> channels, BooleanArrayState state, DriverContext driverContext) {
41+
private PresentGroupingAggregatorFunction(List<Integer> channels, BitArray state, DriverContext driverContext) {
4242
this.channels = channels;
4343
this.state = state;
4444
this.driverContext = driverContext;
@@ -57,10 +57,6 @@ public int intermediateBlockCount() {
5757
public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) {
5858
Block valuesBlock = page.getBlock(blockIndex());
5959

60-
if (valuesBlock.mayHaveNulls()) {
61-
state.enableGroupIdTracking(seenGroupIds);
62-
}
63-
6460
return new AddInput() {
6561
@Override
6662
public void add(int positionOffset, IntArrayBlock groupIds) {
@@ -88,8 +84,7 @@ private void addRawInput(int positionOffset, IntVector groups, Block values) {
8884
if (values.isNull(position)) {
8985
continue;
9086
}
91-
int groupId = groups.getInt(groupPosition);
92-
state.set(groupId, state.getOrDefault(groupId) || values.getValueCount(position) > 0);
87+
state.set(groups.getInt(groupPosition), true);
9388
}
9489
}
9590

@@ -102,8 +97,7 @@ private void addRawInput(int positionOffset, IntArrayBlock groups, Block values)
10297
int groupStart = groups.getFirstValueIndex(groupPosition);
10398
int groupEnd = groupStart + groups.getValueCount(groupPosition);
10499
for (int g = groupStart; g < groupEnd; g++) {
105-
int groupId = groups.getInt(g);
106-
state.set(groupId, state.getOrDefault(groupId) || values.getValueCount(position) > 0);
100+
state.set(groups.getInt(g), true);
107101
}
108102
}
109103
}
@@ -117,34 +111,29 @@ private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block valu
117111
int groupStart = groups.getFirstValueIndex(groupPosition);
118112
int groupEnd = groupStart + groups.getValueCount(groupPosition);
119113
for (int g = groupStart; g < groupEnd; g++) {
120-
int groupId = groups.getInt(g);
121-
state.set(groupId, state.getOrDefault(groupId) || values.getValueCount(position) > 0);
114+
state.set(groups.getInt(g), true);
122115
}
123116
}
124117
}
125118

126119
@Override
127-
public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
128-
state.enableGroupIdTracking(seenGroupIds);
129-
}
120+
public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {}
130121

131122
@Override
132123
public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
133124
assert channels.size() == intermediateBlockCount();
134125
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
135-
state.enableGroupIdTracking(new SeenGroupIds.Empty());
136126
BooleanVector present = page.<BooleanBlock>getBlock(channels.get(0)).asVector();
137-
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
138-
assert present.getPositionCount() == seen.getPositionCount();
139127
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
140128
if (groups.isNull(groupPosition)) {
141129
continue;
142130
}
143131
int groupStart = groups.getFirstValueIndex(groupPosition);
144132
int groupEnd = groupStart + groups.getValueCount(groupPosition);
145133
for (int g = groupStart; g < groupEnd; g++) {
146-
int groupId = groups.getInt(g);
147-
state.set(groupId, state.getOrDefault(groupId) || present.getBoolean(groupPosition + positionOffset));
134+
if (present.getBoolean(groupPosition + positionOffset)) {
135+
state.set(groups.getInt(g), true);
136+
}
148137
}
149138
}
150139
}
@@ -153,19 +142,17 @@ public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page
153142
public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
154143
assert channels.size() == intermediateBlockCount();
155144
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
156-
state.enableGroupIdTracking(new SeenGroupIds.Empty());
157145
BooleanVector present = page.<BooleanBlock>getBlock(channels.get(0)).asVector();
158-
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
159-
assert present.getPositionCount() == seen.getPositionCount();
160146
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
161147
if (groups.isNull(groupPosition)) {
162148
continue;
163149
}
164150
int groupStart = groups.getFirstValueIndex(groupPosition);
165151
int groupEnd = groupStart + groups.getValueCount(groupPosition);
166152
for (int g = groupStart; g < groupEnd; g++) {
167-
int groupId = groups.getInt(g);
168-
state.set(groupId, state.getOrDefault(groupId) || present.getBoolean(groupPosition + positionOffset));
153+
if (present.getBoolean(groupPosition + positionOffset)) {
154+
state.set(groups.getInt(g), true);
155+
}
169156
}
170157
}
171158
}
@@ -174,27 +161,35 @@ public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Pa
174161
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
175162
assert channels.size() == intermediateBlockCount();
176163
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
177-
state.enableGroupIdTracking(new SeenGroupIds.Empty());
178164
BooleanVector present = page.<BooleanBlock>getBlock(channels.get(0)).asVector();
179-
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
180-
assert present.getPositionCount() == seen.getPositionCount();
181165
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
182-
int groupId = groups.getInt(groupPosition);
183-
state.set(groupId, state.getOrDefault(groupId) || present.getBoolean(groupPosition + positionOffset));
166+
if (present.getBoolean(groupPosition + positionOffset)) {
167+
state.set(groups.getInt(groupPosition), true);
168+
}
184169
}
185170
}
186171

187172
@Override
188173
public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
189-
state.toIntermediate(blocks, offset, selected, driverContext);
174+
try (var valuesBuilder = driverContext.blockFactory().newBooleanBlockBuilder(selected.getPositionCount())) {
175+
for (int i = 0; i < selected.getPositionCount(); i++) {
176+
int group = selected.getInt(i);
177+
if (group < state.size()) {
178+
valuesBuilder.appendBoolean(state.get(group));
179+
} else {
180+
valuesBuilder.appendBoolean(false);
181+
}
182+
}
183+
blocks[offset] = valuesBuilder.build();
184+
}
190185
}
191186

192187
@Override
193188
public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) {
194189
try (BooleanVector.Builder builder = evaluationContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount())) {
195190
for (int i = 0; i < selected.getPositionCount(); i++) {
196191
int si = selected.getInt(i);
197-
builder.appendBoolean(state.hasValue(si) && state.getOrDefault(si));
192+
builder.appendBoolean(state.get(si));
198193
}
199194
blocks[offset] = builder.build().asBlock();
200195
}

0 commit comments

Comments
 (0)