Commit 3f94143

Unit test with Driver
1 parent a234326 commit 3f94143

File tree

11 files changed: +205, -26 lines


x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

Lines changed: 18 additions & 17 deletions

@@ -7,12 +7,14 @@
 
 package org.elasticsearch.compute.aggregation.blockhash;
 
+import org.apache.lucene.util.BytesRefBuilder;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
@@ -46,38 +48,25 @@ protected int channel() {
 
     @Override
     public Block[] getKeys() {
-        if (outputPartial) {
-            return new Block[] { buildIntermediateBlock() };
-            // NOCOMMIT load partial
-            // Block state = null;
-            // Block keys; // NOCOMMIT do we even need to send the keys? it's just going to be 0 to the length of state
-            // return new Block[] {new CompositeBlock()};
-            // return null;
-        }
-
-        // NOCOMMIT load final
-        return new Block[0];
+        return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
     }
 
     @Override
     public IntVector nonEmpty() {
-        // TODO
-        return null;
+        return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
     }
 
     @Override
     public BitArray seenGroupIds(BigArrays bigArrays) {
-        // TODO
-        return null;
+        throw new UnsupportedOperationException();
    }
 
     @Override
     public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
         throw new UnsupportedOperationException();
     }
 
-    // visible for testing
-    Block buildIntermediateBlock() {
+    private Block buildIntermediateBlock() {
         if (categorizer.getCategoryCount() == 0) {
             return blockFactory.newConstantNullBlock(1);
         }
@@ -92,4 +81,16 @@ Block buildIntermediateBlock() {
             throw new RuntimeException(e);
         }
     }
+
+    private Block buildFinalBlock() {
+        try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
+            BytesRefBuilder scratch = new BytesRefBuilder();
+            for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
+                scratch.copyChars(category.getRegex());
+                result.appendBytesRef(scratch.get());
+                scratch.clear();
+            }
+            return result.build().asBlock();
+        }
+    }
 }
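In short: getKeys() now has exactly two output shapes, chosen by outputPartial, and the NOCOMMIT placeholders are gone. A consumer-side sketch of the contract as implemented above (hash and blockFactory assumed in scope; illustrative only, not part of the commit):

// outputPartial == true  (raw phase): a single block carrying the serialized
// categorizer state, from buildIntermediateBlock().
// outputPartial == false (final phase): one BytesRef per category, holding
// that category's regex, from buildFinalBlock().
Block[] keys = hash.getKeys();

// nonEmpty() is a real implementation now too: category ids are dense,
// so the non-empty groups are simply the range 0..categoryCount.
IntVector nonEmpty = hash.nonEmpty();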

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

Lines changed: 2 additions & 0 deletions

@@ -133,6 +133,8 @@ private static BlockHash newForElementType(int channel, ElementType type, BlockF
             case LONG -> new LongBlockHash(channel, blockFactory);
             case DOUBLE -> new DoubleBlockHash(channel, blockFactory);
             case BYTES_REF -> new BytesRefBlockHash(channel, blockFactory);
+            case CATEGORY_RAW -> new CategorizeRawBlockHash(channel, blockFactory, true);
+            case CATEGORY_INTERMEDIATE -> new CategorizedIntermediateBlockHash(channel, blockFactory, false);
             default -> throw new IllegalArgumentException("unsupported grouping element type [" + type + "]");
         };
     }
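Note that the two new cases hard-wire the outputPartial flag: CATEGORY_RAW gets true (first phase, ships intermediate categorizer state) and CATEGORY_INTERMEDIATE gets false (final phase, emits the per-category regexes). A minimal sketch of what the dispatch constructs, using only constructors from this commit (channel 0 is an arbitrary example):

// Same calls as the switch above, spelled out; the boolean is outputPartial.
BlockHash raw = new CategorizeRawBlockHash(0, blockFactory, true);
BlockHash intermediate = new CategorizedIntermediateBlockHash(0, blockFactory, false);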

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 31 additions & 1 deletion

@@ -7,7 +7,9 @@
 
 package org.elasticsearch.compute.aggregation.blockhash;
 
+import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.ann.Fixed;
 import org.elasticsearch.compute.data.Block;
@@ -19,15 +21,43 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.CharFilterFactory;
+import org.elasticsearch.index.analysis.CustomAnalyzer;
+import org.elasticsearch.index.analysis.TokenFilterFactory;
+import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
 public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
     private final CategorizeEvaluator evaluator;
 
+    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+        this(
+            channel,
+            blockFactory,
+            outputPartial,
+            new CategorizationAnalyzer(
+                // TODO: should be the same analyzer as used in Production
+                new CustomAnalyzer(
+                    TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+                    new CharFilterFactory[0],
+                    new TokenFilterFactory[0]
+                ),
+                true
+            ),
+            new TokenListCategorizer.CloseableTokenListCategorizer(
+                new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+                CategorizationPartOfSpeechDictionary.getInstance(),
+                0.70f
+            )
+        );
+    }
+
     CategorizeRawBlockHash(
-        BlockFactory blockFactory,
         int channel,
+        BlockFactory blockFactory,
         boolean outputPartial,
         CategorizationAnalyzer analyzer,
         TokenListCategorizer.CloseableTokenListCategorizer categorizer
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

Lines changed: 17 additions & 1 deletion

@@ -10,11 +10,14 @@
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
 import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 
@@ -26,9 +29,22 @@
 public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
     private final IntBlockHash hash;
 
+    CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+        this(
+            channel,
+            blockFactory,
+            outputPartial,
+            new TokenListCategorizer.CloseableTokenListCategorizer(
+                new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+                CategorizationPartOfSpeechDictionary.getInstance(),
+                0.70f
+            )
+        );
+    }
+
     CategorizedIntermediateBlockHash(
-        BlockFactory blockFactory,
         int channel,
+        BlockFactory blockFactory,
         boolean outputPartial,
         TokenListCategorizer.CloseableTokenListCategorizer categorizer
     ) {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/BlockUtils.java

Lines changed: 1 addition & 0 deletions

@@ -276,6 +276,7 @@ private static Object valueAtOffset(Block block, int offset) {
                 DocVector v = ((DocBlock) block).asVector();
                 yield new Doc(v.shards().getInt(offset), v.segments().getInt(offset), v.docs().getInt(offset));
             }
+            case CATEGORY_RAW, CATEGORY_INTERMEDIATE -> throw new IllegalArgumentException("can't read values from category blocks");
             case COMPOSITE -> throw new IllegalArgumentException("can't read values from composite blocks");
             case UNKNOWN -> throw new IllegalArgumentException("can't read values from [" + block + "]");
         };

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/ElementType.java

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,8 @@ public enum ElementType {
     NULL("Null", (blockFactory, estimatedSize) -> new ConstantNullBlock.Builder(blockFactory)),
 
     BYTES_REF("BytesRef", BlockFactory::newBytesRefBlockBuilder),
+    CATEGORY_RAW("CategoryRaw", BlockFactory::newBytesRefBlockBuilder),
+    CATEGORY_INTERMEDIATE("CategoryIntermediate", BlockFactory::newBytesRefBlockBuilder),
 
     /**
      * Blocks that reference individual lucene documents.
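Both new constants bind BlockFactory::newBytesRefBlockBuilder, so a category block is physically an ordinary BytesRef block; the distinct element types exist only so BlockHash.newForElementType can dispatch to the categorize hashes. A small sketch of that equivalence (sample value borrowed from the test below; the estimated size is arbitrary):

// Input destined for CATEGORY_RAW grouping is built with the ordinary
// BytesRef machinery; nothing category-specific happens until the
// BlockHash sees the block.
try (BytesRefBlock.Builder b = blockFactory.newBytesRefBlockBuilder(1)) {
    b.appendBytesRef(new BytesRef("words words words hello jan"));
    try (Block block = b.build()) {
        // wrap in a Page and feed to a hash keyed on ElementType.CATEGORY_RAW
    }
}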

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperatorFactories.java

Lines changed: 3 additions & 1 deletion

@@ -147,7 +147,9 @@ static List<GroupingAggregator.Factory> valuesAggregatorForGroupings(List<BlockH
             case INT -> new org.elasticsearch.compute.aggregation.ValuesIntAggregatorFunctionSupplier(channels);
             case LONG -> new org.elasticsearch.compute.aggregation.ValuesLongAggregatorFunctionSupplier(channels);
             case BOOLEAN -> new org.elasticsearch.compute.aggregation.ValuesBooleanAggregatorFunctionSupplier(channels);
-            case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException("unsupported grouping type");
+            case CATEGORY_RAW, CATEGORY_INTERMEDIATE, FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new IllegalArgumentException(
+                "unsupported grouping type"
+            );
         });
         aggregators.add(aggregatorSupplier.groupingAggregatorFactory(AggregatorMode.SINGLE));
     }

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 128 additions & 4 deletions

@@ -9,12 +9,32 @@
 
 import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.CannedSourceOperator;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.HashAggregationOperator;
+import org.elasticsearch.compute.operator.LocalSourceOperator;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.analysis.CharFilterFactory;
 import org.elasticsearch.index.analysis.CustomAnalyzer;
 import org.elasticsearch.index.analysis.TokenFilterFactory;
@@ -24,10 +44,17 @@
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer.CloseableTokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
 public class CategorizeBlockHashTests extends BlockHashTestCase {
 
     /**
@@ -47,7 +74,7 @@ public void testCategorizeRaw() {
             page = new Page(builder.build());
         }
         // final int emitBatchSize = between(positions, 10 * 1024);
-        try (BlockHash hash = new CategorizeRawBlockHash(blockFactory, 0, true, createAnalyzer(), createCategorizer())) {
+        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true, createAnalyzer(), createCategorizer())) {
             hash.add(page, new GroupingAggregatorFunction.AddInput() {
                 @Override
                 public void add(int positionOffset, IntBlock groupIds) {
@@ -107,9 +134,9 @@ public void testCategorizeIntermediate() {
         }
         // final int emitBatchSize = between(positions, 10 * 1024);
         try (
-            BlockHash rawHash1 = new CategorizeRawBlockHash(blockFactory, 0, true, createAnalyzer(), createCategorizer());
-            BlockHash rawHash2 = new CategorizeRawBlockHash(blockFactory, 0, true, createAnalyzer(), createCategorizer());
-            BlockHash intermediateHash = new CategorizedIntermediateBlockHash(blockFactory, 0, true, createCategorizer())
+            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true, createAnalyzer(), createCategorizer());
+            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true, createAnalyzer(), createCategorizer());
+            BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true, createCategorizer())
         ) {
             rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
                 @Override
@@ -211,6 +238,103 @@ public void close() {
         }
     }
 
+    public void testCategorize_withDriver() {
+        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
+        CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
+        DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
+
+        LocalSourceOperator.BlockSupplier input1 = () -> {
+            try (BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
+                textsBuilder.appendBytesRef(new BytesRef("a"));
+                textsBuilder.appendBytesRef(new BytesRef("b"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
+                textsBuilder.appendBytesRef(new BytesRef("c"));
+                return new Block[] { textsBuilder.build().asBlock() };
+            }
+        };
+        LocalSourceOperator.BlockSupplier input2 = () -> {
+            try (BytesRefVector.Builder builder = driverContext.blockFactory().newBytesRefVectorBuilder(10)) {
+                builder.appendBytesRef(new BytesRef("words words words hello nik"));
+                builder.appendBytesRef(new BytesRef("c"));
+                builder.appendBytesRef(new BytesRef("words words words goodbye chris"));
+                builder.appendBytesRef(new BytesRef("d"));
+                builder.appendBytesRef(new BytesRef("e"));
+                return new Block[] { builder.build().asBlock() };
+            }
+        };
+        List<Page> intermediateOutput = new ArrayList<>();
+        List<Page> finalOutput = new ArrayList<>();
+
+        Driver driver = new Driver(
+            driverContext,
+            new LocalSourceOperator(input1),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    List.of(new BlockHash.GroupSpec(0, ElementType.CATEGORY_RAW)),
+                    List.of(),
+                    16 * 1024
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(intermediateOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        driver = new Driver(
+            driverContext,
+            new LocalSourceOperator(input2),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    List.of(new BlockHash.GroupSpec(0, ElementType.CATEGORY_RAW)),
+                    List.of(),
+                    16 * 1024
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(intermediateOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        driver = new Driver(
+            driverContext,
+            new CannedSourceOperator(intermediateOutput.iterator()),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    List.of(new BlockHash.GroupSpec(0, ElementType.CATEGORY_INTERMEDIATE)),
+                    List.of(),
+                    16 * 1024
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(finalOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        assertThat(finalOutput, hasSize(1));
+        assertThat(finalOutput.get(0).getBlockCount(), equalTo(1));
+        BytesRefBlock block = finalOutput.get(0).getBlock(0);
+        BytesRefVector vector = block.asVector();
+        List<String> values = new ArrayList<>();
+        for (int p = 0; p < vector.getPositionCount(); p++) {
+            values.add(vector.getBytesRef(p, new BytesRef()).utf8ToString());
+        }
+        assertThat(
+            values,
+            containsInAnyOrder(
+                ".*?a.*?",
+                ".*?b.*?",
+                ".*?c.*?",
+                ".*?d.*?",
+                ".*?e.*?",
+                ".*?words.+?words.+?words.+?goodbye.*?",
+                ".*?words.+?words.+?words.+?hello.*?"
+            )
+        );
+        Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
+    }
+
     private static CategorizationAnalyzer createAnalyzer() {
         return new CategorizationAnalyzer(
             // TODO: should be the same analyzer as used in Production
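The new test mirrors the distributed plan: two raw-phase drivers each categorize one input page (GroupSpec on CATEGORY_RAW, which BlockHash.newForElementType wires with outputPartial = true), and a third driver replays the collected intermediate pages through a CATEGORY_INTERMEDIATE hash, merging both categorizers' states and emitting the final regexes. Informally, the expected grouping is:

// One-off messages form singleton categories:
//   "a", "b", "c", "d", "e"           -> ".*?a.*?", ".*?b.*?", ..., ".*?e.*?"
// The five "words words words ..." messages collapse to two categories:
//   ... goodbye jan / nik / chris     -> ".*?words.+?words.+?words.+?goodbye.*?"
//   ... hello jan / nik               -> ".*?words.+?words.+?words.+?hello.*?"
// "c" appears in both inputs but yields a single category after the merge.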

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockTestUtils.java

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ public static Object randomValue(ElementType e) {
         case FLOAT -> randomFloat();
         case DOUBLE -> randomDouble();
         case BYTES_REF -> new BytesRef(randomRealisticUnicodeOfCodepointLengthBetween(0, 5)); // TODO: also test spatial WKB
+        case CATEGORY_RAW, CATEGORY_INTERMEDIATE -> throw new IllegalArgumentException("can't make random values for category");
         case BOOLEAN -> randomBoolean();
         case DOC -> new BlockUtils.Doc(randomInt(), randomInt(), between(0, Integer.MAX_VALUE));
         case NULL -> null;

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/VectorBuilderTests.java

Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ public void testCranky() {
 
     private Vector.Builder vectorBuilder(int estimatedSize, BlockFactory blockFactory) {
         return switch (elementType) {
-            case NULL, DOC, COMPOSITE, UNKNOWN -> throw new UnsupportedOperationException();
+            case NULL, DOC, COMPOSITE, UNKNOWN, CATEGORY_RAW, CATEGORY_INTERMEDIATE -> throw new UnsupportedOperationException();
             case BOOLEAN -> blockFactory.newBooleanVectorBuilder(estimatedSize);
             case BYTES_REF -> blockFactory.newBytesRefVectorBuilder(estimatedSize);
             case FLOAT -> blockFactory.newFloatVectorBuilder(estimatedSize);
