Skip to content

Commit 74cf2c6

Browse files
authored
ESQL: Add nulls support to Categorize (#117655) (#117716)
Handle nulls and empty strings (Which resolve to null) on Categorize grouping function. Also, implement `seenGroupIds()`, which would fail some queries with nulls otherwise.
1 parent 1885134 commit 74cf2c6

File tree

10 files changed

+189
-96
lines changed

10 files changed

+189
-96
lines changed

docs/changelog/117655.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117655
2+
summary: Add nulls support to Categorize
3+
area: ES|QL
4+
type: enhancement
5+
issues: []

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import org.elasticsearch.common.util.BigArrays;
1414
import org.elasticsearch.common.util.BitArray;
1515
import org.elasticsearch.common.util.BytesRefHash;
16+
import org.elasticsearch.compute.aggregation.SeenGroupIds;
1617
import org.elasticsearch.compute.data.Block;
1718
import org.elasticsearch.compute.data.BlockFactory;
19+
import org.elasticsearch.compute.data.BytesRefBlock;
1820
import org.elasticsearch.compute.data.BytesRefVector;
1921
import org.elasticsearch.compute.data.IntBlock;
2022
import org.elasticsearch.compute.data.IntVector;
@@ -31,11 +33,21 @@
3133
* Base BlockHash implementation for {@code Categorize} grouping function.
3234
*/
3335
public abstract class AbstractCategorizeBlockHash extends BlockHash {
36+
protected static final int NULL_ORD = 0;
37+
3438
// TODO: this should probably also take an emitBatchSize
3539
private final int channel;
3640
private final boolean outputPartial;
3741
protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
3842

43+
/**
44+
* Store whether we've seen any {@code null} values.
45+
* <p>
46+
* Null gets the {@link #NULL_ORD} ord.
47+
* </p>
48+
*/
49+
protected boolean seenNull = false;
50+
3951
AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
4052
super(blockFactory);
4153
this.channel = channel;
@@ -58,12 +70,12 @@ public Block[] getKeys() {
5870

5971
@Override
6072
public IntVector nonEmpty() {
61-
return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
73+
return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory);
6274
}
6375

6476
@Override
6577
public BitArray seenGroupIds(BigArrays bigArrays) {
66-
throw new UnsupportedOperationException();
78+
return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
6779
}
6880

6981
@Override
@@ -76,24 +88,39 @@ public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue target
7688
*/
7789
private Block buildIntermediateBlock() {
7890
if (categorizer.getCategoryCount() == 0) {
79-
return blockFactory.newConstantNullBlock(0);
91+
return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
8092
}
8193
try (BytesStreamOutput out = new BytesStreamOutput()) {
8294
// TODO be more careful here.
95+
out.writeBoolean(seenNull);
8396
out.writeVInt(categorizer.getCategoryCount());
8497
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
8598
category.writeTo(out);
8699
}
87100
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
88-
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
101+
int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
102+
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
89103
} catch (IOException e) {
90104
throw new RuntimeException(e);
91105
}
92106
}
93107

94108
private Block buildFinalBlock() {
109+
BytesRefBuilder scratch = new BytesRefBuilder();
110+
111+
if (seenNull) {
112+
try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
113+
result.appendNull();
114+
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
115+
scratch.copyChars(category.getRegex());
116+
result.appendBytesRef(scratch.get());
117+
scratch.clear();
118+
}
119+
return result.build();
120+
}
121+
}
122+
95123
try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
96-
BytesRefBuilder scratch = new BytesRefBuilder();
97124
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
98125
scratch.copyChars(category.getRegex());
99126
result.appendBytesRef(scratch.get());

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public void close() {
6464
/**
6565
* Similar implementation to an Evaluator.
6666
*/
67-
public static final class CategorizeEvaluator implements Releasable {
67+
public final class CategorizeEvaluator implements Releasable {
6868
private final CategorizationAnalyzer analyzer;
6969

7070
private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
@@ -95,7 +95,8 @@ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
9595
BytesRef vScratch = new BytesRef();
9696
for (int p = 0; p < positionCount; p++) {
9797
if (vBlock.isNull(p)) {
98-
result.appendNull();
98+
seenNull = true;
99+
result.appendInt(NULL_ORD);
99100
continue;
100101
}
101102
int first = vBlock.getFirstValueIndex(p);
@@ -126,7 +127,12 @@ public IntVector eval(int positionCount, BytesRefVector vVector) {
126127
}
127128

128129
private int process(BytesRef v) {
129-
return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
130+
var category = categorizer.computeCategory(v.utf8ToString(), analyzer);
131+
if (category == null) {
132+
seenNull = true;
133+
return NULL_ORD;
134+
}
135+
return category.getId() + 1;
130136
}
131137

132138
@Override

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,19 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
4040
return;
4141
}
4242
BytesRefBlock categorizerState = page.getBlock(channel());
43+
if (categorizerState.areAllValuesNull()) {
44+
seenNull = true;
45+
try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
46+
addInput.add(0, newIds);
47+
}
48+
return;
49+
}
50+
4351
Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
4452
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
45-
for (int i = 0; i < idMap.size(); i++) {
53+
int fromId = idMap.containsKey(0) ? 0 : 1;
54+
int toId = fromId + idMap.size();
55+
for (int i = fromId; i < toId; i++) {
4656
newIdsBuilder.appendInt(idMap.get(i));
4757
}
4858
try (IntBlock newIds = newIdsBuilder.build()) {
@@ -59,10 +69,15 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
5969
private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
6070
Map<Integer, Integer> idMap = new HashMap<>();
6171
try (StreamInput in = new BytesArray(bytes).streamInput()) {
72+
if (in.readBoolean()) {
73+
seenNull = true;
74+
idMap.put(NULL_ORD, NULL_ORD);
75+
}
6276
int count = in.readVInt();
6377
for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
6478
int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
65-
idMap.put(oldCategoryId, newCategoryId);
79+
// +1 because the 0 ordinal is reserved for null
80+
idMap.put(oldCategoryId + 1, newCategoryId + 1);
6681
}
6782
return idMap;
6883
} catch (IOException e) {

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
5252

5353
public void testCategorizeRaw() {
5454
final Page page;
55-
final int positions = 7;
55+
boolean withNull = randomBoolean();
56+
final int positions = 7 + (withNull ? 1 : 0);
5657
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
5758
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
5859
builder.appendBytesRef(new BytesRef("Connection error"));
@@ -61,6 +62,13 @@ public void testCategorizeRaw() {
6162
builder.appendBytesRef(new BytesRef("Disconnected"));
6263
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
6364
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
65+
if (withNull) {
66+
if (randomBoolean()) {
67+
builder.appendNull();
68+
} else {
69+
builder.appendBytesRef(new BytesRef(""));
70+
}
71+
}
6472
page = new Page(builder.build());
6573
}
6674

@@ -70,13 +78,16 @@ public void testCategorizeRaw() {
7078
public void add(int positionOffset, IntBlock groupIds) {
7179
assertEquals(groupIds.getPositionCount(), positions);
7280

73-
assertEquals(0, groupIds.getInt(0));
74-
assertEquals(1, groupIds.getInt(1));
75-
assertEquals(1, groupIds.getInt(2));
76-
assertEquals(1, groupIds.getInt(3));
77-
assertEquals(2, groupIds.getInt(4));
78-
assertEquals(0, groupIds.getInt(5));
79-
assertEquals(0, groupIds.getInt(6));
81+
assertEquals(1, groupIds.getInt(0));
82+
assertEquals(2, groupIds.getInt(1));
83+
assertEquals(2, groupIds.getInt(2));
84+
assertEquals(2, groupIds.getInt(3));
85+
assertEquals(3, groupIds.getInt(4));
86+
assertEquals(1, groupIds.getInt(5));
87+
assertEquals(1, groupIds.getInt(6));
88+
if (withNull) {
89+
assertEquals(0, groupIds.getInt(7));
90+
}
8091
}
8192

8293
@Override
@@ -100,7 +111,8 @@ public void close() {
100111

101112
public void testCategorizeIntermediate() {
102113
Page page1;
103-
int positions1 = 7;
114+
boolean withNull = randomBoolean();
115+
int positions1 = 7 + (withNull ? 1 : 0);
104116
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
105117
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
106118
builder.appendBytesRef(new BytesRef("Connection error"));
@@ -109,6 +121,13 @@ public void testCategorizeIntermediate() {
109121
builder.appendBytesRef(new BytesRef("Connection error"));
110122
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
111123
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
124+
if (withNull) {
125+
if (randomBoolean()) {
126+
builder.appendNull();
127+
} else {
128+
builder.appendBytesRef(new BytesRef(""));
129+
}
130+
}
112131
page1 = new Page(builder.build());
113132
}
114133
Page page2;
@@ -133,13 +152,16 @@ public void testCategorizeIntermediate() {
133152
@Override
134153
public void add(int positionOffset, IntBlock groupIds) {
135154
assertEquals(groupIds.getPositionCount(), positions1);
136-
assertEquals(0, groupIds.getInt(0));
137-
assertEquals(1, groupIds.getInt(1));
138-
assertEquals(1, groupIds.getInt(2));
139-
assertEquals(0, groupIds.getInt(3));
140-
assertEquals(1, groupIds.getInt(4));
141-
assertEquals(0, groupIds.getInt(5));
142-
assertEquals(0, groupIds.getInt(6));
155+
assertEquals(1, groupIds.getInt(0));
156+
assertEquals(2, groupIds.getInt(1));
157+
assertEquals(2, groupIds.getInt(2));
158+
assertEquals(1, groupIds.getInt(3));
159+
assertEquals(2, groupIds.getInt(4));
160+
assertEquals(1, groupIds.getInt(5));
161+
assertEquals(1, groupIds.getInt(6));
162+
if (withNull) {
163+
assertEquals(0, groupIds.getInt(7));
164+
}
143165
}
144166

145167
@Override
@@ -158,11 +180,11 @@ public void close() {
158180
@Override
159181
public void add(int positionOffset, IntBlock groupIds) {
160182
assertEquals(groupIds.getPositionCount(), positions2);
161-
assertEquals(0, groupIds.getInt(0));
162-
assertEquals(1, groupIds.getInt(1));
163-
assertEquals(0, groupIds.getInt(2));
164-
assertEquals(1, groupIds.getInt(3));
165-
assertEquals(2, groupIds.getInt(4));
183+
assertEquals(1, groupIds.getInt(0));
184+
assertEquals(2, groupIds.getInt(1));
185+
assertEquals(1, groupIds.getInt(2));
186+
assertEquals(2, groupIds.getInt(3));
187+
assertEquals(3, groupIds.getInt(4));
166188
}
167189

168190
@Override
@@ -189,7 +211,11 @@ public void add(int positionOffset, IntBlock groupIds) {
189211
.map(groupIds::getInt)
190212
.boxed()
191213
.collect(Collectors.toSet());
192-
assertEquals(values, Set.of(0, 1));
214+
if (withNull) {
215+
assertEquals(Set.of(0, 1, 2), values);
216+
} else {
217+
assertEquals(Set.of(1, 2), values);
218+
}
193219
}
194220

195221
@Override
@@ -212,7 +238,7 @@ public void add(int positionOffset, IntBlock groupIds) {
212238
.collect(Collectors.toSet());
213239
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
214240
// 0 matches an existing category (Connected to ...), and the others are new.
215-
assertEquals(values, Set.of(0, 2, 3));
241+
assertEquals(Set.of(1, 3, 4), values);
216242
}
217243

218244
@Override

0 commit comments

Comments
 (0)