Skip to content

Commit 01a6493

Browse files
authored
ESQL: Categorize grouping function testing improvements (#118013) (#118262)
Added some extra tests on the CategorizeBlockHash. Added NullFold rule comments, and forced nullable() to TRUE on Categorize.
1 parent 1b303fc commit 01a6493

File tree

5 files changed

+229
-58
lines changed

5 files changed

+229
-58
lines changed

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Nullability.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77
package org.elasticsearch.xpack.esql.core.expression;
88

99
public enum Nullability {
10-
TRUE, // Whether the expression can become null
11-
FALSE, // The expression can never become null
12-
UNKNOWN // Cannot determine if the expression supports possible null folding
10+
/**
11+
* Whether the expression can become null
12+
*/
13+
TRUE,
14+
15+
/**
16+
* The expression can never become null
17+
*/
18+
FALSE,
19+
20+
/**
21+
* Cannot determine if the expression supports possible null folding
22+
*/
23+
UNKNOWN
1324
}

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 182 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@
5050
import java.util.HashMap;
5151
import java.util.List;
5252
import java.util.Map;
53-
import java.util.Set;
5453
import java.util.stream.Collectors;
5554
import java.util.stream.IntStream;
5655

5756
import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
57+
import static org.hamcrest.Matchers.arrayWithSize;
5858
import static org.hamcrest.Matchers.equalTo;
5959
import static org.hamcrest.Matchers.hasSize;
6060

@@ -95,41 +95,114 @@ public void testCategorizeRaw() {
9595
page = new Page(builder.build());
9696
}
9797

98-
try (BlockHash hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, analysisRegistry)) {
99-
hash.add(page, new GroupingAggregatorFunction.AddInput() {
100-
@Override
101-
public void add(int positionOffset, IntBlock groupIds) {
102-
assertEquals(groupIds.getPositionCount(), positions);
98+
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
99+
for (int i = randomInt(2); i < 3; i++) {
100+
hash.add(page, new GroupingAggregatorFunction.AddInput() {
101+
@Override
102+
public void add(int positionOffset, IntBlock groupIds) {
103+
assertEquals(groupIds.getPositionCount(), positions);
104+
105+
assertEquals(1, groupIds.getInt(0));
106+
assertEquals(2, groupIds.getInt(1));
107+
assertEquals(2, groupIds.getInt(2));
108+
assertEquals(2, groupIds.getInt(3));
109+
assertEquals(3, groupIds.getInt(4));
110+
assertEquals(1, groupIds.getInt(5));
111+
assertEquals(1, groupIds.getInt(6));
112+
if (withNull) {
113+
assertEquals(0, groupIds.getInt(7));
114+
}
115+
}
103116

104-
assertEquals(1, groupIds.getInt(0));
105-
assertEquals(2, groupIds.getInt(1));
106-
assertEquals(2, groupIds.getInt(2));
107-
assertEquals(2, groupIds.getInt(3));
108-
assertEquals(3, groupIds.getInt(4));
109-
assertEquals(1, groupIds.getInt(5));
110-
assertEquals(1, groupIds.getInt(6));
111-
if (withNull) {
112-
assertEquals(0, groupIds.getInt(7));
117+
@Override
118+
public void add(int positionOffset, IntVector groupIds) {
119+
add(positionOffset, groupIds.asBlock());
113120
}
114-
}
115121

116-
@Override
117-
public void add(int positionOffset, IntVector groupIds) {
118-
add(positionOffset, groupIds.asBlock());
119-
}
122+
@Override
123+
public void close() {
124+
fail("hashes should not close AddInput");
125+
}
126+
});
120127

121-
@Override
122-
public void close() {
123-
fail("hashes should not close AddInput");
124-
}
125-
});
128+
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
129+
}
126130
} finally {
127131
page.releaseBlocks();
128132
}
129133

130-
// TODO: randomize and try multiple pages.
131-
// TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
132-
// TODO: also test the lookup method and other stuff.
134+
// TODO: randomize values? May give wrong results
135+
// TODO: assert the categorizer state after adding pages.
136+
}
137+
138+
public void testCategorizeRawMultivalue() {
139+
final Page page;
140+
boolean withNull = randomBoolean();
141+
final int positions = 3 + (withNull ? 1 : 0);
142+
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
143+
builder.beginPositionEntry();
144+
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
145+
builder.appendBytesRef(new BytesRef("Connection error"));
146+
builder.appendBytesRef(new BytesRef("Connection error"));
147+
builder.appendBytesRef(new BytesRef("Connection error"));
148+
builder.endPositionEntry();
149+
builder.appendBytesRef(new BytesRef("Disconnected"));
150+
builder.beginPositionEntry();
151+
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
152+
builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
153+
builder.endPositionEntry();
154+
if (withNull) {
155+
if (randomBoolean()) {
156+
builder.appendNull();
157+
} else {
158+
builder.appendBytesRef(new BytesRef(""));
159+
}
160+
}
161+
page = new Page(builder.build());
162+
}
163+
164+
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, analysisRegistry)) {
165+
for (int i = randomInt(2); i < 3; i++) {
166+
hash.add(page, new GroupingAggregatorFunction.AddInput() {
167+
@Override
168+
public void add(int positionOffset, IntBlock groupIds) {
169+
assertEquals(groupIds.getPositionCount(), positions);
170+
171+
assertThat(groupIds.getFirstValueIndex(0), equalTo(0));
172+
assertThat(groupIds.getValueCount(0), equalTo(4));
173+
assertThat(groupIds.getFirstValueIndex(1), equalTo(4));
174+
assertThat(groupIds.getValueCount(1), equalTo(1));
175+
assertThat(groupIds.getFirstValueIndex(2), equalTo(5));
176+
assertThat(groupIds.getValueCount(2), equalTo(2));
177+
178+
assertEquals(1, groupIds.getInt(0));
179+
assertEquals(2, groupIds.getInt(1));
180+
assertEquals(2, groupIds.getInt(2));
181+
assertEquals(2, groupIds.getInt(3));
182+
assertEquals(3, groupIds.getInt(4));
183+
assertEquals(1, groupIds.getInt(5));
184+
assertEquals(1, groupIds.getInt(6));
185+
if (withNull) {
186+
assertEquals(0, groupIds.getInt(7));
187+
}
188+
}
189+
190+
@Override
191+
public void add(int positionOffset, IntVector groupIds) {
192+
add(positionOffset, groupIds.asBlock());
193+
}
194+
195+
@Override
196+
public void close() {
197+
fail("hashes should not close AddInput");
198+
}
199+
});
200+
201+
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
202+
}
203+
} finally {
204+
page.releaseBlocks();
205+
}
133206
}
134207

135208
public void testCategorizeIntermediate() {
@@ -226,18 +299,18 @@ public void close() {
226299
page2.releaseBlocks();
227300
}
228301

229-
try (BlockHash intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INTERMEDIATE, null)) {
302+
try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, null)) {
230303
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
231304
@Override
232305
public void add(int positionOffset, IntBlock groupIds) {
233-
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
306+
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
234307
.map(groupIds::getInt)
235308
.boxed()
236-
.collect(Collectors.toSet());
309+
.collect(Collectors.toList());
237310
if (withNull) {
238-
assertEquals(Set.of(0, 1, 2), values);
311+
assertEquals(List.of(0, 1, 2), values);
239312
} else {
240-
assertEquals(Set.of(1, 2), values);
313+
assertEquals(List.of(1, 2), values);
241314
}
242315
}
243316

@@ -252,28 +325,39 @@ public void close() {
252325
}
253326
});
254327

255-
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
256-
@Override
257-
public void add(int positionOffset, IntBlock groupIds) {
258-
Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
259-
.map(groupIds::getInt)
260-
.boxed()
261-
.collect(Collectors.toSet());
262-
// The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
263-
// 0 matches an existing category (Connected to ...), and the others are new.
264-
assertEquals(Set.of(1, 3, 4), values);
265-
}
328+
for (int i = randomInt(2); i < 3; i++) {
329+
intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
330+
@Override
331+
public void add(int positionOffset, IntBlock groupIds) {
332+
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
333+
.map(groupIds::getInt)
334+
.boxed()
335+
.collect(Collectors.toList());
336+
// The category IDs {1, 2, 3} should map to groups {1, 3, 4}, because
337+
// 1 matches an existing category (Connected to ...), and the others are new.
338+
assertEquals(List.of(3, 1, 4), values);
339+
}
266340

267-
@Override
268-
public void add(int positionOffset, IntVector groupIds) {
269-
add(positionOffset, groupIds.asBlock());
270-
}
341+
@Override
342+
public void add(int positionOffset, IntVector groupIds) {
343+
add(positionOffset, groupIds.asBlock());
344+
}
271345

272-
@Override
273-
public void close() {
274-
fail("hashes should not close AddInput");
275-
}
276-
});
346+
@Override
347+
public void close() {
348+
fail("hashes should not close AddInput");
349+
}
350+
});
351+
352+
assertHashState(
353+
intermediateHash,
354+
withNull,
355+
".*?Connected.+?to.*?",
356+
".*?Connection.+?error.*?",
357+
".*?Disconnected.*?",
358+
".*?System.+?shutdown.*?"
359+
);
360+
}
277361
} finally {
278362
intermediatePage1.releaseBlocks();
279363
intermediatePage2.releaseBlocks();
@@ -457,4 +541,49 @@ public void testCategorize_withDriver() {
457541
private BlockHash.GroupSpec makeGroupSpec() {
458542
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
459543
}
544+
545+
private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
546+
// Check the keys
547+
Block[] blocks = null;
548+
try {
549+
blocks = hash.getKeys();
550+
assertThat(blocks, arrayWithSize(1));
551+
552+
var keysBlock = (BytesRefBlock) blocks[0];
553+
assertThat(keysBlock.getPositionCount(), equalTo(expectedKeys.length + (withNull ? 1 : 0)));
554+
555+
if (withNull) {
556+
assertTrue(keysBlock.isNull(0));
557+
}
558+
559+
for (int i = 0; i < expectedKeys.length; i++) {
560+
int position = i + (withNull ? 1 : 0);
561+
String key = keysBlock.getBytesRef(position, new BytesRef()).utf8ToString();
562+
assertThat(key, equalTo(expectedKeys[i]));
563+
}
564+
} finally {
565+
if (blocks != null) {
566+
Releasables.close(blocks);
567+
}
568+
}
569+
570+
// Check the nonEmpty() result
571+
try (IntVector nonEmptyKeys = hash.nonEmpty()) {
572+
int oneIfNull = withNull ? 1 : 0;
573+
assertThat(nonEmptyKeys.getPositionCount(), equalTo(expectedKeys.length + oneIfNull));
574+
575+
for (int i = 0; i < expectedKeys.length + oneIfNull; i++) {
576+
assertThat(nonEmptyKeys.getInt(i), equalTo(i + 1 - oneIfNull));
577+
}
578+
}
579+
580+
// Check seenGroupIds()
581+
try (var seenGroupIds = hash.seenGroupIds(blockFactory.bigArrays())) {
582+
assertThat(seenGroupIds.get(0), equalTo(withNull));
583+
584+
for (int i = 1; i <= expectedKeys.length; i++) {
585+
assertThat(seenGroupIds.get(i), equalTo(true));
586+
}
587+
}
588+
}
460589
}

x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,29 @@ COUNT():long | category:keyword
374374
7 | null
375375
;
376376

377+
on const null
378+
required_capability: categorize_v5
379+
380+
FROM sample_data
381+
| STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(null)
382+
| SORT category
383+
;
384+
385+
COUNT():long | SUM(event_duration):long | category:keyword
386+
7 | 23231327 | null
387+
;
388+
389+
on null row
390+
required_capability: categorize_v5
391+
392+
ROW message = null, str = ["a", "b", "c"]
393+
| STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
394+
;
395+
396+
COUNT():long | VALUES(str):keyword | category:keyword
397+
1 | [a, b, c] | null
398+
;
399+
377400
filtering out all data
378401
required_capability: categorize_v5
379402

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
1414
import org.elasticsearch.xpack.esql.capabilities.Validatable;
1515
import org.elasticsearch.xpack.esql.core.expression.Expression;
16+
import org.elasticsearch.xpack.esql.core.expression.Nullability;
1617
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1718
import org.elasticsearch.xpack.esql.core.tree.Source;
1819
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -92,6 +93,12 @@ public boolean foldable() {
9293
return false;
9394
}
9495

96+
@Override
97+
public Nullability nullable() {
98+
// Both nulls and empty strings result in null values
99+
return Nullability.TRUE;
100+
}
101+
95102
@Override
96103
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
97104
throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ public Expression rule(Expression e) {
4141
if (Expressions.isGuaranteedNull(in.value())) {
4242
return Literal.of(in, null);
4343
}
44-
} else if (e instanceof Alias == false
45-
&& e.nullable() == Nullability.TRUE
44+
} else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE
45+
// Categorize function stays as a STATS grouping (It isn't moved to an early EVAL like other groupings),
46+
// so folding it to null would currently break the plan, as we don't create an attribute/channel for that null value.
4647
&& e instanceof Categorize == false
4748
&& Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) {
4849
return Literal.of(e, null);

0 commit comments

Comments
 (0)