Skip to content

Commit 0ba57ea

Browse files
committed
ES|QL: Add PRESENT ES|QL function
Add a new ES|QL function that checks for the presence of a field in the output result. Presence means that the input expression yields any non-null value. Part of #131069
1 parent 54a2472 commit 0ba57ea

File tree

8 files changed

+643
-3
lines changed

8 files changed

+643
-3
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.aggregation;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.BooleanBlock;
12+
import org.elasticsearch.compute.data.BooleanVector;
13+
import org.elasticsearch.compute.data.ElementType;
14+
import org.elasticsearch.compute.data.IntBlock;
15+
import org.elasticsearch.compute.data.IntVector;
16+
import org.elasticsearch.compute.data.Page;
17+
import org.elasticsearch.compute.operator.DriverContext;
18+
19+
import java.util.List;
20+
21+
public class PresentAggregatorFunction implements AggregatorFunction {
22+
public static AggregatorFunctionSupplier supplier() {
23+
return new AggregatorFunctionSupplier() {
24+
@Override
25+
public List<IntermediateStateDesc> nonGroupingIntermediateStateDesc() {
26+
return PresentAggregatorFunction.intermediateStateDesc();
27+
}
28+
29+
@Override
30+
public List<IntermediateStateDesc> groupingIntermediateStateDesc() {
31+
return PresentGroupingAggregatorFunction.intermediateStateDesc();
32+
}
33+
34+
@Override
35+
public AggregatorFunction aggregator(DriverContext driverContext, List<Integer> channels) {
36+
return PresentAggregatorFunction.create(channels);
37+
}
38+
39+
@Override
40+
public GroupingAggregatorFunction groupingAggregator(DriverContext driverContext, List<Integer> channels) {
41+
return PresentGroupingAggregatorFunction.create(driverContext, channels);
42+
}
43+
44+
@Override
45+
public String describe() {
46+
return "present";
47+
}
48+
};
49+
}
50+
51+
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
52+
new IntermediateStateDesc("count", ElementType.INT),
53+
new IntermediateStateDesc("seen", ElementType.BOOLEAN)
54+
);
55+
56+
public static List<IntermediateStateDesc> intermediateStateDesc() {
57+
return INTERMEDIATE_STATE_DESC;
58+
}
59+
60+
private final IntState state;
61+
private final List<Integer> channels;
62+
63+
public static PresentAggregatorFunction create(List<Integer> inputChannels) {
64+
return new PresentAggregatorFunction(inputChannels, new IntState(0));
65+
}
66+
67+
private PresentAggregatorFunction(List<Integer> channels, IntState state) {
68+
this.channels = channels;
69+
this.state = state;
70+
}
71+
72+
@Override
73+
public int intermediateBlockCount() {
74+
return intermediateStateDesc().size();
75+
}
76+
77+
private int blockIndex() {
78+
return channels.get(0);
79+
}
80+
81+
@Override
82+
public void addRawInput(Page page, BooleanVector mask) {
83+
Block block = page.getBlock(blockIndex());
84+
IntState state = this.state;
85+
int count;
86+
if (mask.isConstant()) {
87+
if (mask.getBoolean(0) == false) {
88+
return;
89+
}
90+
count = block.getTotalValueCount();
91+
} else {
92+
count = countMasked(block, mask);
93+
}
94+
state.intValue(Math.min(count, 1));
95+
}
96+
97+
private int countMasked(Block block, BooleanVector mask) {
98+
for (int p = 0; p < block.getPositionCount(); p++) {
99+
if (mask.getBoolean(p)) {
100+
return 1;
101+
}
102+
}
103+
return 0;
104+
}
105+
106+
@Override
107+
public void addIntermediateInput(Page page) {
108+
assert channels.size() == intermediateBlockCount();
109+
var blockIndex = blockIndex();
110+
assert page.getBlockCount() >= blockIndex + intermediateStateDesc().size();
111+
Block uncastBlock = page.getBlock(channels.get(0));
112+
if (uncastBlock.areAllValuesNull()) {
113+
return;
114+
}
115+
IntVector count = page.<IntBlock>getBlock(channels.get(0)).asVector();
116+
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
117+
assert count.getPositionCount() == 1;
118+
assert count.getPositionCount() == seen.getPositionCount();
119+
state.intValue(Math.min(state.intValue() + count.getInt(0), 1));
120+
}
121+
122+
@Override
123+
public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
124+
state.toIntermediate(blocks, offset, driverContext);
125+
}
126+
127+
@Override
128+
public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
129+
blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1);
130+
}
131+
132+
@Override
133+
public String toString() {
134+
StringBuilder sb = new StringBuilder();
135+
sb.append(this.getClass().getSimpleName()).append("[");
136+
sb.append("channels=").append(channels);
137+
sb.append("]");
138+
return sb.toString();
139+
}
140+
141+
@Override
142+
public void close() {
143+
state.close();
144+
}
145+
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.aggregation;
9+
10+
import org.elasticsearch.compute.data.Block;
11+
import org.elasticsearch.compute.data.BooleanBlock;
12+
import org.elasticsearch.compute.data.BooleanVector;
13+
import org.elasticsearch.compute.data.ElementType;
14+
import org.elasticsearch.compute.data.IntArrayBlock;
15+
import org.elasticsearch.compute.data.IntBigArrayBlock;
16+
import org.elasticsearch.compute.data.IntBlock;
17+
import org.elasticsearch.compute.data.IntVector;
18+
import org.elasticsearch.compute.data.Page;
19+
import org.elasticsearch.compute.operator.DriverContext;
20+
21+
import java.util.List;
22+
23+
public class PresentGroupingAggregatorFunction implements GroupingAggregatorFunction {
24+
25+
private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
26+
new IntermediateStateDesc("count", ElementType.INT),
27+
new IntermediateStateDesc("seen", ElementType.BOOLEAN)
28+
);
29+
30+
private final IntArrayState state;
31+
private final List<Integer> channels;
32+
private final DriverContext driverContext;
33+
34+
public static PresentGroupingAggregatorFunction create(DriverContext driverContext, List<Integer> inputChannels) {
35+
return new PresentGroupingAggregatorFunction(inputChannels, new IntArrayState(driverContext.bigArrays(), 0), driverContext);
36+
}
37+
38+
public static List<IntermediateStateDesc> intermediateStateDesc() {
39+
return INTERMEDIATE_STATE_DESC;
40+
}
41+
42+
private PresentGroupingAggregatorFunction(List<Integer> channels, IntArrayState state, DriverContext driverContext) {
43+
this.channels = channels;
44+
this.state = state;
45+
this.driverContext = driverContext;
46+
}
47+
48+
private int blockIndex() {
49+
return channels.get(0);
50+
}
51+
52+
@Override
53+
public int intermediateBlockCount() {
54+
return intermediateStateDesc().size();
55+
}
56+
57+
@Override
58+
public AddInput prepareProcessRawInputPage(SeenGroupIds seenGroupIds, Page page) {
59+
Block valuesBlock = page.getBlock(blockIndex());
60+
61+
if (valuesBlock.mayHaveNulls()) {
62+
state.enableGroupIdTracking(seenGroupIds);
63+
}
64+
65+
return new AddInput() {
66+
@Override
67+
public void add(int positionOffset, IntArrayBlock groupIds) {
68+
addRawInput(positionOffset, groupIds, valuesBlock);
69+
}
70+
71+
@Override
72+
public void add(int positionOffset, IntBigArrayBlock groupIds) {
73+
addRawInput(positionOffset, groupIds, valuesBlock);
74+
}
75+
76+
@Override
77+
public void add(int positionOffset, IntVector groupIds) {
78+
addRawInput(positionOffset, groupIds, valuesBlock);
79+
}
80+
81+
@Override
82+
public void close() {}
83+
};
84+
}
85+
86+
private void addRawInput(int positionOffset, IntVector groups, Block values) {
87+
int position = positionOffset;
88+
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) {
89+
if (values.isNull(position)) {
90+
continue;
91+
}
92+
int groupId = groups.getInt(groupPosition);
93+
state.set(groupId, Math.min(state.getOrDefault(groupId) + values.getValueCount(position), 1));
94+
}
95+
}
96+
97+
private void addRawInput(int positionOffset, IntArrayBlock groups, Block values) {
98+
int position = positionOffset;
99+
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) {
100+
if (groups.isNull(groupPosition) || values.isNull(position)) {
101+
continue;
102+
}
103+
int groupStart = groups.getFirstValueIndex(groupPosition);
104+
int groupEnd = groupStart + groups.getValueCount(groupPosition);
105+
for (int g = groupStart; g < groupEnd; g++) {
106+
int groupId = groups.getInt(g);
107+
state.set(groupId, Math.min(state.getOrDefault(groupId) + values.getValueCount(position), 1));
108+
}
109+
}
110+
}
111+
112+
private void addRawInput(int positionOffset, IntBigArrayBlock groups, Block values) {
113+
int position = positionOffset;
114+
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++, position++) {
115+
if (groups.isNull(groupPosition) || values.isNull(position)) {
116+
continue;
117+
}
118+
int groupStart = groups.getFirstValueIndex(groupPosition);
119+
int groupEnd = groupStart + groups.getValueCount(groupPosition);
120+
for (int g = groupStart; g < groupEnd; g++) {
121+
int groupId = groups.getInt(g);
122+
state.set(groupId, Math.min(state.getOrDefault(groupId) + values.getValueCount(position), 1));
123+
}
124+
}
125+
}
126+
127+
@Override
128+
public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
129+
state.enableGroupIdTracking(seenGroupIds);
130+
}
131+
132+
@Override
133+
public void addIntermediateInput(int positionOffset, IntArrayBlock groups, Page page) {
134+
assert channels.size() == intermediateBlockCount();
135+
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
136+
state.enableGroupIdTracking(new SeenGroupIds.Empty());
137+
IntVector count = page.<IntBlock>getBlock(channels.get(0)).asVector();
138+
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
139+
assert count.getPositionCount() == seen.getPositionCount();
140+
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
141+
if (groups.isNull(groupPosition)) {
142+
continue;
143+
}
144+
int groupStart = groups.getFirstValueIndex(groupPosition);
145+
int groupEnd = groupStart + groups.getValueCount(groupPosition);
146+
for (int g = groupStart; g < groupEnd; g++) {
147+
int groupId = groups.getInt(g);
148+
state.set(groupId, Math.min(state.getOrDefault(groupId) + count.getInt(groupPosition + positionOffset), 1));
149+
}
150+
}
151+
}
152+
153+
@Override
154+
public void addIntermediateInput(int positionOffset, IntBigArrayBlock groups, Page page) {
155+
assert channels.size() == intermediateBlockCount();
156+
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
157+
state.enableGroupIdTracking(new SeenGroupIds.Empty());
158+
IntVector count = page.<IntBlock>getBlock(channels.get(0)).asVector();
159+
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
160+
assert count.getPositionCount() == seen.getPositionCount();
161+
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
162+
if (groups.isNull(groupPosition)) {
163+
continue;
164+
}
165+
int groupStart = groups.getFirstValueIndex(groupPosition);
166+
int groupEnd = groupStart + groups.getValueCount(groupPosition);
167+
for (int g = groupStart; g < groupEnd; g++) {
168+
int groupId = groups.getInt(g);
169+
state.set(groupId, Math.min(state.getOrDefault(groupId) + count.getInt(groupPosition + positionOffset), 1));
170+
}
171+
}
172+
}
173+
174+
@Override
175+
public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
176+
assert channels.size() == intermediateBlockCount();
177+
assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
178+
state.enableGroupIdTracking(new SeenGroupIds.Empty());
179+
IntVector count = page.<IntBlock>getBlock(channels.get(0)).asVector();
180+
BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
181+
assert count.getPositionCount() == seen.getPositionCount();
182+
for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
183+
int groupId = groups.getInt(groupPosition);
184+
state.set(groupId, Math.min(state.getOrDefault(groupId) + count.getInt(groupPosition + positionOffset), 1));
185+
}
186+
}
187+
188+
@Override
189+
public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
190+
state.toIntermediate(blocks, offset, selected, driverContext);
191+
}
192+
193+
@Override
194+
public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evaluationContext) {
195+
try (IntVector.Builder builder = evaluationContext.blockFactory().newIntVectorFixedBuilder(selected.getPositionCount())) {
196+
for (int i = 0; i < selected.getPositionCount(); i++) {
197+
int si = selected.getInt(i);
198+
builder.appendInt(state.hasValue(si) ? state.getOrDefault(si) : 0);
199+
}
200+
blocks[offset] = builder.build().asBlock();
201+
}
202+
}
203+
204+
@Override
205+
public String toString() {
206+
StringBuilder sb = new StringBuilder();
207+
sb.append(this.getClass().getSimpleName()).append("[");
208+
sb.append("channels=").append(channels);
209+
sb.append("]");
210+
return sb.toString();
211+
}
212+
213+
@Override
214+
public void close() {
215+
state.close();
216+
}
217+
}

0 commit comments

Comments
 (0)