Skip to content

Commit a234326

Browse files
committed
Unit test for CategorizedIntermediateBlockHash.
1 parent 82cc74a commit a234326

File tree

6 files changed

+198
-96
lines changed

6 files changed

+198
-96
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99

1010
import org.elasticsearch.common.io.stream.BytesStreamOutput;
1111
import org.elasticsearch.common.unit.ByteSizeValue;
12+
import org.elasticsearch.common.util.BigArrays;
13+
import org.elasticsearch.common.util.BitArray;
1214
import org.elasticsearch.compute.data.Block;
1315
import org.elasticsearch.compute.data.BlockFactory;
1416
import org.elasticsearch.compute.data.IntBlock;
17+
import org.elasticsearch.compute.data.IntVector;
1518
import org.elasticsearch.compute.data.Page;
1619
import org.elasticsearch.core.ReleasableIterator;
1720
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
@@ -44,33 +47,47 @@ protected int channel() {
4447
@Override
4548
public Block[] getKeys() {
4649
if (outputPartial) {
50+
return new Block[] { buildIntermediateBlock() };
4751
// NOCOMMIT load partial
48-
Block state = null;
49-
Block keys; // NOCOMMIT do we even need to send the keys? it's just going to be 0 to the length of state
52+
// Block state = null;
53+
// Block keys; // NOCOMMIT do we even need to send the keys? it's just going to be 0 to the length of state
5054
// return new Block[] {new CompositeBlock()};
51-
return null;
55+
// return null;
5256
}
5357

5458
// NOCOMMIT load final
5559
return new Block[0];
5660
}
5761

62+
@Override
63+
public IntVector nonEmpty() {
64+
// TODO
65+
return null;
66+
}
67+
68+
@Override
69+
public BitArray seenGroupIds(BigArrays bigArrays) {
70+
// TODO
71+
return null;
72+
}
73+
5874
@Override
5975
public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
6076
throw new UnsupportedOperationException();
6177
}
6278

63-
private Block buildIntermediateBlock(BlockFactory blockFactory, int positionCount) {
79+
// visible for testing
80+
Block buildIntermediateBlock() {
6481
if (categorizer.getCategoryCount() == 0) {
65-
return blockFactory.newConstantNullBlock(positionCount);
82+
return blockFactory.newConstantNullBlock(1);
6683
}
6784
try (BytesStreamOutput out = new BytesStreamOutput()) {
6885
// TODO be more careful here.
6986
out.writeVInt(categorizer.getCategoryCount());
7087
for (SerializableTokenListCategory category : categorizer.toCategories(categorizer.getCategoryCount())) {
7188
category.writeTo(out);
7289
}
73-
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
90+
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), 1);
7491
} catch (IOException e) {
7592
throw new RuntimeException(e);
7693
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77

88
package org.elasticsearch.compute.aggregation.blockhash;
99

10-
import org.apache.lucene.analysis.TokenStream;
1110
import org.apache.lucene.util.BytesRef;
12-
import org.elasticsearch.common.util.BigArrays;
13-
import org.elasticsearch.common.util.BitArray;
1411
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
1512
import org.elasticsearch.compute.ann.Fixed;
1613
import org.elasticsearch.compute.data.Block;
@@ -25,8 +22,6 @@
2522
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
2623
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
2724

28-
import java.io.IOException;
29-
3025
public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
3126
private final CategorizeEvaluator evaluator;
3227

@@ -47,18 +42,6 @@ public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
4742
addInput.add(0, result);
4843
}
4944

50-
@Override
51-
public IntVector nonEmpty() {
52-
// TODO
53-
return null;
54-
}
55-
56-
@Override
57-
public BitArray seenGroupIds(BigArrays bigArrays) {
58-
// TODO
59-
return null;
60-
}
61-
6245
@Override
6346
public void close() {
6447
evaluator.close();
@@ -79,12 +62,7 @@ static int process(
7962
@Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
8063
@Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
8164
) {
82-
String s = v.utf8ToString();
83-
try (TokenStream ts = analyzer.tokenStream("text", s)) {
84-
return categorizer.computeCategory(ts, s.length(), 1).getId();
85-
} catch (IOException e) {
86-
throw new RuntimeException(e);
87-
}
65+
return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
8866
}
8967

9068
public CategorizeEvaluator(

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,10 @@
1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.common.bytes.BytesArray;
1212
import org.elasticsearch.common.io.stream.StreamInput;
13-
import org.elasticsearch.common.util.BigArrays;
14-
import org.elasticsearch.common.util.BitArray;
1513
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
1614
import org.elasticsearch.compute.data.BlockFactory;
1715
import org.elasticsearch.compute.data.BytesRefBlock;
18-
import org.elasticsearch.compute.data.CompositeBlock;
1916
import org.elasticsearch.compute.data.IntBlock;
20-
import org.elasticsearch.compute.data.IntVector;
2117
import org.elasticsearch.compute.data.Page;
2218
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
2319
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
@@ -32,27 +28,24 @@ public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHas
3228

3329
CategorizedIntermediateBlockHash(
3430
BlockFactory blockFactory,
31+
int channel,
3532
boolean outputPartial,
36-
TokenListCategorizer.CloseableTokenListCategorizer categorizer,
37-
IntBlockHash hash,
38-
int channel
33+
TokenListCategorizer.CloseableTokenListCategorizer categorizer
3934
) {
4035
super(blockFactory, channel, outputPartial, categorizer);
41-
this.hash = hash;
36+
this.hash = new IntBlockHash(channel, blockFactory);
4237
}
4338

4439
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
45-
CompositeBlock block = page.getBlock(channel());
46-
BytesRefBlock groupingState = block.getBlock(0);
47-
BytesRefBlock groups = block.getBlock(0);
40+
BytesRefBlock categorizerState = page.getBlock(channel());
4841
Map<Integer, Integer> idMap;
49-
if (groupingState.areAllValuesNull() == false) {
50-
idMap = readIntermediate(groupingState.getBytesRef(0, new BytesRef()));
42+
if (categorizerState.areAllValuesNull() == false) {
43+
idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
5144
} else {
5245
idMap = Collections.emptyMap();
5346
}
54-
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(groups.getTotalValueCount())) {
55-
for (int i = 0; i < groups.getTotalValueCount(); i++) {
47+
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
48+
for (int i = 0; i < idMap.size(); i++) {
5649
newIdsBuilder.appendInt(idMap.get(i));
5750
}
5851
IntBlock newIds = newIdsBuilder.build();
@@ -76,18 +69,9 @@ private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
7669
}
7770
}
7871

79-
@Override
80-
public IntVector nonEmpty() {
81-
return hash.nonEmpty();
82-
}
83-
84-
@Override
85-
public BitArray seenGroupIds(BigArrays bigArrays) {
86-
return hash.seenGroupIds(bigArrays);
87-
}
88-
8972
@Override
9073
public void close() {
91-
74+
categorizer.close();
75+
hash.close();
9276
}
9377
}

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,8 @@
1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1212

13-
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
1413
import org.apache.lucene.util.BytesRef;
15-
import org.elasticsearch.common.breaker.CircuitBreaker;
1614
import org.elasticsearch.common.unit.ByteSizeValue;
17-
import org.elasticsearch.common.util.BigArrays;
18-
import org.elasticsearch.common.util.BytesRefHash;
19-
import org.elasticsearch.common.util.MockBigArrays;
20-
import org.elasticsearch.common.util.PageCacheRecycler;
2115
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
2216
import org.elasticsearch.compute.data.Block;
2317
import org.elasticsearch.compute.data.BooleanBlock;
@@ -27,22 +21,11 @@
2721
import org.elasticsearch.compute.data.IntBlock;
2822
import org.elasticsearch.compute.data.IntVector;
2923
import org.elasticsearch.compute.data.LongBlock;
30-
import org.elasticsearch.compute.data.MockBlockFactory;
3124
import org.elasticsearch.compute.data.Page;
3225
import org.elasticsearch.compute.data.TestBlockFactory;
3326
import org.elasticsearch.core.Releasable;
3427
import org.elasticsearch.core.ReleasableIterator;
3528
import org.elasticsearch.core.Releasables;
36-
import org.elasticsearch.index.analysis.CharFilterFactory;
37-
import org.elasticsearch.index.analysis.CustomAnalyzer;
38-
import org.elasticsearch.index.analysis.TokenFilterFactory;
39-
import org.elasticsearch.index.analysis.TokenizerFactory;
40-
import org.elasticsearch.indices.breaker.CircuitBreakerService;
41-
import org.elasticsearch.test.ESTestCase;
42-
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
43-
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
44-
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
45-
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
4629
import org.junit.After;
4730

4831
import java.util.ArrayList;
@@ -61,8 +44,6 @@
6144
import static org.hamcrest.Matchers.greaterThan;
6245
import static org.hamcrest.Matchers.is;
6346
import static org.hamcrest.Matchers.startsWith;
64-
import static org.mockito.Mockito.mock;
65-
import static org.mockito.Mockito.when;
6647

6748
public class BlockHashTests extends BlockHashTestCase {
6849

0 commit comments

Comments
 (0)