Skip to content

Commit 5300744

Browse files
committed
better test coverage + polish code
1 parent e95c295 commit 5300744

File tree

5 files changed

+94
-49
lines changed

5 files changed

+94
-49
lines changed

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 74 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -63,12 +63,6 @@
6363

6464
public class CategorizeBlockHashTests extends BlockHashTestCase {
6565

66-
private static final BlockHash.CategorizeDef CATEGORIZE_DEF = new BlockHash.CategorizeDef(
67-
null,
68-
BlockHash.CategorizeDef.OutputFormat.REGEX,
69-
70
70-
);
71-
7266
private AnalysisRegistry analysisRegistry;
7367

7468
@Before
@@ -82,7 +76,13 @@ private void initAnalysisRegistry() throws IOException {
8276
).getAnalysisRegistry();
8377
}
8478

79+
private BlockHash.CategorizeDef getCategorizeDef() {
80+
return new BlockHash.CategorizeDef(null, randomFrom(BlockHash.CategorizeDef.OutputFormat.values()), 70);
81+
}
82+
8583
public void testCategorizeRaw() {
84+
BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
85+
8686
final Page page;
8787
boolean withNull = randomBoolean();
8888
final int positions = 7 + (withNull ? 1 : 0);
@@ -104,7 +104,7 @@ public void testCategorizeRaw() {
104104
page = new Page(builder.build());
105105
}
106106

107-
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, CATEGORIZE_DEF, analysisRegistry)) {
107+
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) {
108108
for (int i = randomInt(2); i < 3; i++) {
109109
hash.add(page, new GroupingAggregatorFunction.AddInput() {
110110
private void addBlock(int positionOffset, IntBlock groupIds) {
@@ -143,14 +143,19 @@ public void close() {
143143
}
144144
});
145145

146-
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
146+
switch (categorizeDef.outputFormat()) {
147+
case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
148+
case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected");
149+
}
147150
}
148151
} finally {
149152
page.releaseBlocks();
150153
}
151154
}
152155

153156
public void testCategorizeRawMultivalue() {
157+
BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
158+
154159
final Page page;
155160
boolean withNull = randomBoolean();
156161
final int positions = 3 + (withNull ? 1 : 0);
@@ -176,7 +181,7 @@ public void testCategorizeRawMultivalue() {
176181
page = new Page(builder.build());
177182
}
178183

179-
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, CATEGORIZE_DEF, analysisRegistry)) {
184+
try (var hash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.SINGLE, categorizeDef, analysisRegistry)) {
180185
for (int i = randomInt(2); i < 3; i++) {
181186
hash.add(page, new GroupingAggregatorFunction.AddInput() {
182187
private void addBlock(int positionOffset, IntBlock groupIds) {
@@ -222,14 +227,19 @@ public void close() {
222227
}
223228
});
224229

225-
assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
230+
switch (categorizeDef.outputFormat()) {
231+
case REGEX -> assertHashState(hash, withNull, ".*?Connected.+?to.*?", ".*?Connection.+?error.*?", ".*?Disconnected.*?");
232+
case TOKENS -> assertHashState(hash, withNull, "Connected to", "Connection error", "Disconnected");
233+
}
226234
}
227235
} finally {
228236
page.releaseBlocks();
229237
}
230238
}
231239

232240
public void testCategorizeIntermediate() {
241+
BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
242+
233243
Page page1;
234244
boolean withNull = randomBoolean();
235245
int positions1 = 7 + (withNull ? 1 : 0);
@@ -265,8 +275,8 @@ public void testCategorizeIntermediate() {
265275

266276
// Fill intermediatePages with the intermediate state from the raw hashes
267277
try (
268-
BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, CATEGORIZE_DEF, analysisRegistry);
269-
BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, CATEGORIZE_DEF, analysisRegistry);
278+
BlockHash rawHash1 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry);
279+
BlockHash rawHash2 = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.INITIAL, categorizeDef, analysisRegistry);
270280
) {
271281
rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
272282
private void addBlock(int positionOffset, IntBlock groupIds) {
@@ -341,7 +351,7 @@ public void close() {
341351
page2.releaseBlocks();
342352
}
343353

344-
try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, CATEGORIZE_DEF, null)) {
354+
try (var intermediateHash = new CategorizeBlockHash(blockFactory, 0, AggregatorMode.FINAL, categorizeDef, null)) {
345355
intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
346356
private void addBlock(int positionOffset, IntBlock groupIds) {
347357
List<Integer> values = IntStream.range(0, groupIds.getPositionCount())
@@ -409,14 +419,24 @@ public void close() {
409419
}
410420
});
411421

412-
assertHashState(
413-
intermediateHash,
414-
withNull,
415-
".*?Connected.+?to.*?",
416-
".*?Connection.+?error.*?",
417-
".*?Disconnected.*?",
418-
".*?System.+?shutdown.*?"
419-
);
422+
switch (categorizeDef.outputFormat()) {
423+
case REGEX -> assertHashState(
424+
intermediateHash,
425+
withNull,
426+
".*?Connected.+?to.*?",
427+
".*?Connection.+?error.*?",
428+
".*?Disconnected.*?",
429+
".*?System.+?shutdown.*?"
430+
);
431+
case TOKENS -> assertHashState(
432+
intermediateHash,
433+
withNull,
434+
"Connected to",
435+
"Connection error",
436+
"Disconnected",
437+
"System shutdown"
438+
);
439+
}
420440
}
421441
} finally {
422442
intermediatePage1.releaseBlocks();
@@ -425,6 +445,9 @@ public void close() {
425445
}
426446

427447
public void testCategorize_withDriver() {
448+
BlockHash.CategorizeDef categorizeDef = getCategorizeDef();
449+
BlockHash.GroupSpec groupSpec = new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef);
450+
428451
BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
429452
CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
430453
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
@@ -483,7 +506,7 @@ public void testCategorize_withDriver() {
483506
new LocalSourceOperator(input1),
484507
List.of(
485508
new HashAggregationOperator.HashAggregationOperatorFactory(
486-
List.of(makeGroupSpec()),
509+
List.of(groupSpec),
487510
AggregatorMode.INITIAL,
488511
List.of(
489512
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)),
@@ -502,7 +525,7 @@ public void testCategorize_withDriver() {
502525
new LocalSourceOperator(input2),
503526
List.of(
504527
new HashAggregationOperator.HashAggregationOperatorFactory(
505-
List.of(makeGroupSpec()),
528+
List.of(groupSpec),
506529
AggregatorMode.INITIAL,
507530
List.of(
508531
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)),
@@ -523,7 +546,7 @@ public void testCategorize_withDriver() {
523546
new CannedSourceOperator(intermediateOutput.iterator()),
524547
List.of(
525548
new HashAggregationOperator.HashAggregationOperatorFactory(
526-
List.of(makeGroupSpec()),
549+
List.of(groupSpec),
527550
AggregatorMode.FINAL,
528551
List.of(
529552
new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(1, 2)),
@@ -550,23 +573,36 @@ public void testCategorize_withDriver() {
550573
sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputSums.getLong(i));
551574
maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i));
552575
}
576+
List<String> keys = switch (categorizeDef.outputFormat()) {
577+
case REGEX -> List.of(
578+
".*?aaazz.*?",
579+
".*?bbbzz.*?",
580+
".*?ccczz.*?",
581+
".*?dddzz.*?",
582+
".*?eeezz.*?",
583+
".*?words.+?words.+?words.+?goodbye.*?",
584+
".*?words.+?words.+?words.+?hello.*?"
585+
);
586+
case TOKENS -> List.of("aaazz", "bbbzz", "ccczz", "dddzz", "eeezz", "words words words goodbye", "words words words hello");
587+
};
588+
553589
assertThat(
554590
sums,
555591
equalTo(
556592
Map.of(
557-
".*?aaazz.*?",
593+
keys.get(0),
558594
1L,
559-
".*?bbbzz.*?",
595+
keys.get(1),
560596
2L,
561-
".*?ccczz.*?",
597+
keys.get(2),
562598
33L,
563-
".*?dddzz.*?",
599+
keys.get(3),
564600
44L,
565-
".*?eeezz.*?",
601+
keys.get(4),
566602
5L,
567-
".*?words.+?words.+?words.+?goodbye.*?",
603+
keys.get(5),
568604
8888L,
569-
".*?words.+?words.+?words.+?hello.*?",
605+
keys.get(6),
570606
999L
571607
)
572608
)
@@ -575,30 +611,26 @@ public void testCategorize_withDriver() {
575611
maxs,
576612
equalTo(
577613
Map.of(
578-
".*?aaazz.*?",
614+
keys.get(0),
579615
1L,
580-
".*?bbbzz.*?",
616+
keys.get(1),
581617
2L,
582-
".*?ccczz.*?",
618+
keys.get(2),
583619
30L,
584-
".*?dddzz.*?",
620+
keys.get(3),
585621
40L,
586-
".*?eeezz.*?",
622+
keys.get(4),
587623
5L,
588-
".*?words.+?words.+?words.+?goodbye.*?",
624+
keys.get(5),
589625
8000L,
590-
".*?words.+?words.+?words.+?hello.*?",
626+
keys.get(6),
591627
900L
592628
)
593629
)
594630
);
595631
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
596632
}
597633

598-
private BlockHash.GroupSpec makeGroupSpec() {
599-
return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, CATEGORIZE_DEF);
600-
}
601-
602634
private void assertHashState(CategorizeBlockHash hash, boolean withNull, String... expectedKeys) {
603635
// Check the keys
604636
Block[] blocks = null;

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,11 @@ public void testCategorize_withDriver() {
7474
DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
7575
boolean withNull = randomBoolean();
7676
boolean withMultivalues = randomBoolean();
77-
BlockHash.CategorizeDef categorizeDef = new BlockHash.CategorizeDef(null, BlockHash.CategorizeDef.OutputFormat.REGEX, 70);
77+
BlockHash.CategorizeDef categorizeDef = new BlockHash.CategorizeDef(
78+
null,
79+
randomFrom(BlockHash.CategorizeDef.OutputFormat.values()),
80+
70
81+
);
7882

7983
List<BlockHash.GroupSpec> groupSpecs = List.of(
8084
new BlockHash.GroupSpec(0, ElementType.BYTES_REF, categorizeDef),
@@ -219,8 +223,12 @@ public void testCategorize_withDriver() {
219223
}
220224
Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
221225

226+
List<String> keys = switch (categorizeDef.outputFormat()) {
227+
case REGEX -> List.of(".*?connected.+?to.*?", ".*?connection.+?error.*?", ".*?disconnected.*?");
228+
case TOKENS -> List.of("connected to", "connection error", "disconnected");
229+
};
222230
Map<String, Map<Integer, Set<String>>> expectedResult = Map.of(
223-
".*?connected.+?to.*?",
231+
keys.get(0),
224232
Map.of(
225233
7,
226234
Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"),
@@ -229,9 +237,9 @@ public void testCategorize_withDriver() {
229237
111,
230238
Set.of("connected to 2.1.1")
231239
),
232-
".*?connection.+?error.*?",
240+
keys.get(1),
233241
Map.of(7, Set.of("connection error"), 42, Set.of("connection error")),
234-
".*?disconnected.*?",
242+
keys.get(2),
235243
Map.of(7, Set.of("disconnected"))
236244
);
237245
if (withNull) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,13 +107,13 @@ public Categorize(
107107
),
108108
@MapParam.MapParamEntry(
109109
name = "output_format",
110-
type = "boolean",
110+
type = "keyword",
111111
valueHint = { "regex", "tokens" },
112112
description = "The output format of the categories. Defaults to regex."
113113
),
114114
@MapParam.MapParamEntry(
115115
name = "similarity_threshold",
116-
type = "boolean",
116+
type = "integer",
117117
valueHint = { "70" },
118118
description = "The minimum percentage of token weight that must match for text to be added to the category bucket. "
119119
+ "Must be between 1 and 100. The larger the value the narrower the categories. "

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1966,7 +1966,9 @@ public void testCategorizeOptionOutputFormat() {
19661966
assumeTrue("categorize options must be enabled", EsqlCapabilities.Cap.CATEGORIZE_OPTIONS.isEnabled());
19671967

19681968
query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"regex\" })");
1969+
query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"REGEX\" })");
19691970
query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"tokens\" })");
1971+
query("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"ToKeNs\" })");
19701972
assertEquals(
19711973
"1:31: invalid output format [blah], expecting one of [REGEX, TOKENS]",
19721974
error("FROM test | STATS COUNT(*) BY CATEGORIZE(last_name, { \"output_format\": \"blah\" })")

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/SerializableTokenListCategory.java

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,10 @@ public BytesRef[] getKeyTokens() {
163163
}
164164

165165
public String getKeyTokensString() {
166-
return Arrays.stream(getKeyTokens()).map(BytesRef::utf8ToString).collect(Collectors.joining(" "));
166+
return Arrays.stream(keyTokenIndexes)
167+
.mapToObj(index -> baseTokens[index])
168+
.map(BytesRef::utf8ToString)
169+
.collect(Collectors.joining(" "));
167170
}
168171

169172
public String getRegex() {

0 commit comments

Comments
 (0)