
Commit 239d159

Add aggregator to unit test
1 parent 3f94143 commit 239d159

4 files changed: +70 −31 lines

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

Lines changed: 3 additions & 3 deletions
@@ -73,10 +73,10 @@ private Block buildIntermediateBlock() {
         try (BytesStreamOutput out = new BytesStreamOutput()) {
             // TODO be more careful here.
             out.writeVInt(categorizer.getCategoryCount());
-            for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
+            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 category.writeTo(out);
             }
-            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), 1);
+            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
@@ -85,7 +85,7 @@ private Block buildIntermediateBlock() {
     private Block buildFinalBlock() {
         try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
             BytesRefBuilder scratch = new BytesRefBuilder();
-            for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
+            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 scratch.copyChars(category.getRegex());
                 result.appendBytesRef(scratch.get());
                 scratch.clear();
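
With this change the intermediate block is a count-prefixed byte layout, written in category-id order, and its position count now equals the category count instead of a constant 1. As a rough illustration only (not code from this commit), the matching read side would pull the count back out and deserialize one category per entry; here `intermediateBlock` is a hypothetical variable for the block built above, and a `SerializableTokenListCategory(StreamInput)` constructor is assumed per the usual Writeable pattern:

    // Hedged sketch: decode the intermediate state produced by buildIntermediateBlock().
    // Assumes SerializableTokenListCategory(StreamInput) exists (standard Writeable pattern).
    BytesRef serialized = intermediateBlock.getBytesRef(0, new BytesRef());
    try (StreamInput in = new BytesArray(serialized).streamInput()) {
        int categoryCount = in.readVInt();                          // matches out.writeVInt(...) above
        List<SerializableTokenListCategory> categories = new ArrayList<>(categoryCount);
        for (int i = 0; i < categoryCount; i++) {
            categories.add(new SerializableTokenListCategory(in));  // entry i corresponds to category id i
        }
    }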

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,7 @@ public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash
         this.hash = new IntBlockHash(channel, blockFactory);
     }
 
+    @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
         BytesRefBlock categorizerState = page.getBlock(channel());
         Map<Integer, Integer> idMap;
@@ -60,6 +61,8 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
         } else {
            idMap = Collections.emptyMap();
         }
+        // TODO: when there are aggregators running, this renumbering doesn't work.
+        // This should renumber the destination IDs only, but it also renumbers the source IDs.
         try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
             for (int i = 0; i < idMap.size(); i++) {
                 newIdsBuilder.appendInt(idMap.get(i));
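
The idMap built just above translates the category ids embedded in this page's serialized state into the ids of the combined categorizer, and the loop emits one remapped id per position. A standalone illustration of that remapping with plain collections (not the block-based code in this file), using a hypothetical three-category page:

    // Standalone sketch: local (per-page) category id -> id in the combined categorizer.
    Map<Integer, Integer> idMap = Map.of(0, 4, 1, 2, 2, 7);   // hypothetical mapping
    int[] remappedIds = new int[idMap.size()];
    for (int i = 0; i < idMap.size(); i++) {
        remappedIds[i] = idMap.get(i);                        // position i carried local id i; emit its combined id
    }
    // Per the TODO above, once aggregators consume these ids the source ids must not be renumbered this way.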

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 56 additions & 28 deletions
@@ -26,6 +26,7 @@
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.CannedSourceOperator;
@@ -45,13 +46,14 @@
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
-import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
@@ -244,24 +246,41 @@ public void testCategorize_withDriver() {
         DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
 
         LocalSourceOperator.BlockSupplier input1 = () -> {
-            try (BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
+            try (
+                BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
+                LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
+            ) {
                 textsBuilder.appendBytesRef(new BytesRef("a"));
                 textsBuilder.appendBytesRef(new BytesRef("b"));
                 textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
                 textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
                 textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
                 textsBuilder.appendBytesRef(new BytesRef("c"));
-                return new Block[] { textsBuilder.build().asBlock() };
+                countsBuilder.appendLong(11);
+                countsBuilder.appendLong(22);
+                countsBuilder.appendLong(800);
+                countsBuilder.appendLong(80);
+                countsBuilder.appendLong(900);
+                countsBuilder.appendLong(30);
+                return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
             }
         };
         LocalSourceOperator.BlockSupplier input2 = () -> {
-            try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
-                builder.appendBytesRef(new BytesRef("words words words hello nik"));
-                builder.appendBytesRef(new BytesRef("c"));
-                builder.appendBytesRef(new BytesRef("words words words goodbye chris"));
-                builder.appendBytesRef(new BytesRef("d"));
-                builder.appendBytesRef(new BytesRef("e"));
-                return new Block[] { builder.build().asBlock() };
+            try (
+                BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
+                LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
+            ) {
+                textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
+                textsBuilder.appendBytesRef(new BytesRef("c"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
+                textsBuilder.appendBytesRef(new BytesRef("d"));
+                textsBuilder.appendBytesRef(new BytesRef("e"));
+                countsBuilder.appendLong(99);
+                countsBuilder.appendLong(3);
+                countsBuilder.appendLong(8);
+                countsBuilder.appendLong(44);
+                countsBuilder.appendLong(55);
+                return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
             }
         };
         List<Page> intermediateOutput = new ArrayList<>();
@@ -273,7 +292,7 @@ public void testCategorize_withDriver() {
             List.of(
                 new HashAggregationOperator.HashAggregationOperatorFactory(
                     List.of(new BlockHash.GroupSpec(0, ElementType.CATEGORY_RAW)),
-                    List.of(),
+                    List.of(new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
                     16 * 1024
                 ).get(driverContext)
             ),
@@ -288,7 +307,7 @@ public void testCategorize_withDriver() {
             List.of(
                 new HashAggregationOperator.HashAggregationOperatorFactory(
                     List.of(new BlockHash.GroupSpec(0, ElementType.CATEGORY_RAW)),
-                    List.of(),
+                    List.of(new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
                     16 * 1024
                 ).get(driverContext)
             ),
@@ -303,7 +322,7 @@ public void testCategorize_withDriver() {
             List.of(
                 new HashAggregationOperator.HashAggregationOperatorFactory(
                     List.of(new BlockHash.GroupSpec(0, ElementType.CATEGORY_INTERMEDIATE)),
-                    List.of(),
+                    List.of(new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
                     16 * 1024
                 ).get(driverContext)
             ),
@@ -313,23 +332,32 @@ public void testCategorize_withDriver() {
         runDriver(driver);
 
         assertThat(finalOutput, hasSize(1));
-        assertThat(finalOutput.get(0).getBlockCount(), equalTo(1));
-        BytesRefBlock block = finalOutput.get(0).getBlock(0);
-        BytesRefVector vector = block.asVector();
-        List<String> values = new ArrayList<>();
-        for (int p = 0; p < vector.getPositionCount(); p++) {
-            values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString());
+        assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
+        BytesRefVector textsVector = ((BytesRefBlock) finalOutput.get(0).getBlock(0)).asVector();
+        LongVector countsVector = ((LongBlock) finalOutput.get(0).getBlock(1)).asVector();
+        Map<String, Long> counts = new HashMap<>();
+        for (int i = 0; i < countsVector.getPositionCount(); i++) {
+            counts.put(textsVector.getBytesRef(i, new BytesRef()).utf8ToString(), countsVector.getLong(i));
         }
         assertThat(
-            values,
-            containsInAnyOrder(
-                ".*?a.*?",
-                ".*?b.*?",
-                ".*?c.*?",
-                ".*?d.*?",
-                ".*?e.*?",
-                ".*?words.+?words.+?words.+?goodbye.*?",
-                ".*?words.+?words.+?words.+?hello.*?"
+            counts,
+            equalTo(
+                Map.of(
+                    ".*?a.*?",
+                    11,
+                    ".*?b.*?",
+                    22,
+                    ".*?c.*?",
+                    33,
+                    ".*?d.*?",
+                    44,
+                    ".*?e.*?",
+                    55,
+                    ".*?words.+?words.+?words.+?goodbye.*?",
+                    888,
+                    ".*?words.+?words.+?words.+?hello.*?",
+                    999
+                )
             )
         );
         Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
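
The expected sums line up with the two input pages above: "c" appears in both inputs with counts 30 and 3 (33), the goodbye messages carry 800 + 80 + 8 = 888, the hello messages carry 900 + 99 = 999, and "a", "b", "d" and "e" each occur once with counts 11, 22, 44 and 55.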

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java

Lines changed: 8 additions & 0 deletions
@@ -84,6 +84,8 @@ public void close() {
     @Nullable
     private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
 
+    private final List<TokenListCategory> categoriesById;
+
     /**
      * Categories stored in such a way that the most common are accessed first.
      * This is implemented as an {@link ArrayList} with bespoke ordering rather
@@ -109,6 +111,7 @@ public TokenListCategorizer(
         this.lowerThreshold = threshold;
         this.upperThreshold = (1.0f + threshold) / 2.0f;
         this.categoriesByNumMatches = new ArrayList<>();
+        this.categoriesById = new ArrayList<>();
         cacheRamUsage(0);
     }
 
@@ -310,6 +313,7 @@ private synchronized TokenListCategory computeCategory(
             maxUnfilteredStringLen,
             numDocs
         );
+        categoriesById.add(newCategory);
         categoriesByNumMatches.add(newCategory);
         cacheRamUsage(newCategory.ramBytesUsed());
         return repositionCategory(newCategory, newIndex);
@@ -428,6 +432,10 @@ public List<SerializableTokenListCategory> toCategories(int size) {
             .toList();
     }
 
+    public List<SerializableTokenListCategory> toCategoriesById() {
+        return categoriesById.stream().map(category -> new SerializableTokenListCategory(category, bytesRefHash)).toList();
+    }
+
     public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
         return categoriesByNumMatches.stream()
             .limit(size)
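
Unlike toCategories(int), which walks categoriesByNumMatches and therefore reorders as match counts change, toCategoriesById() walks the new append-only categoriesById list, so its iteration order matches creation order and stays stable across calls; that stability is presumably what lets the block hash above serialize categories positionally. A standalone sketch of the property (plain Java, not code from this commit):

    // An append-only list keeps index == creation order, so positions can serve as stable ids.
    List<String> categoriesById = new ArrayList<>();
    categoriesById.add("category 0");   // created first, never moves
    categoriesById.add("category 1");   // created second, never moves
    // A most-matched-first list (like categoriesByNumMatches) reshuffles, so its positions are not stable ids.
    for (int id = 0; id < categoriesById.size(); id++) {
        System.out.println(id + " -> " + categoriesById.get(id));
    }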
