
Commit e50d5b9

Add almost passing test
1 parent: a309133

4 files changed: +122 −37 lines changed


x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

Lines changed: 8 additions & 0 deletions
@@ -20,19 +20,27 @@
 import java.io.IOException;
 
 public abstract class AbstractCategorizeBlockHash extends BlockHash {
+    // TODO: this should probably also take an emitBatchSize
+    private final int channel;
     private final boolean outputPartial;
     protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
 
     AbstractCategorizeBlockHash(
         BlockFactory blockFactory,
+        int channel,
         boolean outputPartial,
         TokenListCategorizer.CloseableTokenListCategorizer categorizer
     ) {
         super(blockFactory);
+        this.channel = channel;
         this.outputPartial = outputPartial;
         this.categorizer = categorizer;
     }
 
+    protected int channel() {
+        return channel;
+    }
+
     @Override
     public Block[] getKeys() {
         if (outputPartial) {
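
With the channel now stored on the base class, both subclasses share the same input-selection pattern: add() pulls its input block out of the incoming Page via the new channel() accessor. A minimal sketch of that pattern (the variable names are illustrative, not from the commit):

    @Override
    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
        Block input = page.getBlock(channel()); // channel() is the new accessor on AbstractCategorizeBlockHash
        // ... categorize `input` and hand the resulting group ids to addInput ...
    }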

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

Lines changed: 22 additions & 33 deletions
@@ -12,7 +12,6 @@
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
-import org.elasticsearch.compute.aggregation.Warnings;
 import org.elasticsearch.compute.ann.Fixed;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
@@ -21,8 +20,7 @@
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
 import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
@@ -34,17 +32,18 @@ public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
 
     CategorizeRawBlockHash(
         BlockFactory blockFactory,
+        int channel,
         boolean outputPartial,
-        TokenListCategorizer.CloseableTokenListCategorizer categorizer,
-        CategorizeEvaluator evaluator
+        CategorizationAnalyzer analyzer,
+        TokenListCategorizer.CloseableTokenListCategorizer categorizer
     ) {
-        super(blockFactory, outputPartial, categorizer);
-        this.evaluator = evaluator;
+        super(blockFactory, channel, outputPartial, categorizer);
+        this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
     }
 
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        IntBlock result = (IntBlock) evaluator.eval(page);
+        IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()));
         addInput.add(0, result);
     }
 
@@ -66,18 +65,14 @@ public void close() {
     }
 
     /**
-     * NOCOMMIT: Super-duper copy-pasted.
+     * NOCOMMIT: Super-duper copy-pasted from the actually generated evaluator; needs cleanup.
      */
-    public static final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator {
-        private final Warnings warnings;
-
-        private final EvalOperator.ExpressionEvaluator v;
-
+    public static final class CategorizeEvaluator implements Releasable {
         private final CategorizationAnalyzer analyzer;
 
         private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
 
-        private final DriverContext driverContext;
+        private final BlockFactory blockFactory;
 
         static int process(
             BytesRef v,
@@ -93,31 +88,25 @@ static int process(
         }
 
         public CategorizeEvaluator(
-            EvalOperator.ExpressionEvaluator v,
             CategorizationAnalyzer analyzer,
             TokenListCategorizer.CloseableTokenListCategorizer categorizer,
-            DriverContext driverContext
+            BlockFactory blockFactory
         ) {
-            this.v = v;
             this.analyzer = analyzer;
             this.categorizer = categorizer;
-            this.driverContext = driverContext;
-            this.warnings = Warnings.createWarnings(driverContext.warningsMode(), -1, -1, "");
+            this.blockFactory = blockFactory;
         }
 
-        @Override
-        public Block eval(Page page) {
-            try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) {
-                BytesRefVector vVector = vBlock.asVector();
-                if (vVector == null) {
-                    return eval(page.getPositionCount(), vBlock);
-                }
-                return eval(page.getPositionCount(), vVector).asBlock();
+        public Block eval(BytesRefBlock vBlock) {
+            BytesRefVector vVector = vBlock.asVector();
+            if (vVector == null) {
+                return eval(vBlock.getPositionCount(), vBlock);
             }
+            return eval(vBlock.getPositionCount(), vVector).asBlock();
         }
 
         public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
-            try (IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) {
+            try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
                 BytesRef vScratch = new BytesRef();
                 position: for (int p = 0; p < positionCount; p++) {
                     if (vBlock.isNull(p)) {
@@ -126,7 +115,7 @@ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
                     }
                     if (vBlock.getValueCount(p) != 1) {
                         if (vBlock.getValueCount(p) > 1) {
-                            warnings.registerException(new IllegalArgumentException("single-value function encountered multi-value"));
+                            // TODO: handle multi-values
                         }
                         result.appendNull();
                         continue position;
@@ -138,7 +127,7 @@ public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
         }
 
         public IntVector eval(int positionCount, BytesRefVector vVector) {
-            try (IntVector.FixedBuilder result = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) {
+            try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
                 BytesRef vScratch = new BytesRef();
                 position: for (int p = 0; p < positionCount; p++) {
                     result.appendInt(p, process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer));
@@ -149,12 +138,12 @@ public IntVector eval(int positionCount, BytesRefVector vVector) {
 
         @Override
         public String toString() {
-            return "CategorizeEvaluator[" + "v=" + v + "]";
+            return "CategorizeEvaluator";
        }
 
         @Override
         public void close() {
-            Releasables.closeExpectNoException(v, analyzer, categorizer);
+            Releasables.closeExpectNoException(analyzer, categorizer);
         }
     }
 }
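
The evaluator refactor is worth spelling out: CategorizeEvaluator no longer implements EvalOperator.ExpressionEvaluator (whose eval takes a whole Page and resolves its own input expression); it is now a plain Releasable whose eval takes the BytesRefBlock directly, leaving channel selection to the enclosing BlockHash. The resulting call flow inside CategorizeRawBlockHash.add looks like this (a sketch; the local variable names are invented here):

    BytesRefBlock messages = page.getBlock(channel()); // the hash picks the input column
    IntBlock categories = (IntBlock) evaluator.eval(messages); // the evaluator only sees that column
    addInput.add(0, categories); // group ids are handed over at position offset 0

As in the generated evaluators this was copied from, eval dispatches to the IntVector fast path when asVector() succeeds (no nulls or multi-values) and falls back to the per-position IntBlock path otherwise.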

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

Lines changed: 2 additions & 4 deletions
@@ -29,7 +29,6 @@
 
 public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
     private final IntBlockHash hash;
-    private final int channel;
 
     CategorizedIntermediateBlockHash(
         BlockFactory blockFactory,
@@ -38,13 +37,12 @@ public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
         IntBlockHash hash,
         int channel
     ) {
-        super(blockFactory, outputPartial, categorizer);
+        super(blockFactory, channel, outputPartial, categorizer);
         this.hash = hash;
-        this.channel = channel;
     }
 
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        CompositeBlock block = page.getBlock(channel);
+        CompositeBlock block = page.getBlock(channel());
         BytesRefBlock groupingState = block.getBlock(0);
         BytesRefBlock groups = block.getBlock(0);
         Map<Integer, Integer> idMap;

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

Lines changed: 90 additions & 0 deletions
@@ -10,10 +10,12 @@
 import com.carrotsearch.randomizedtesting.annotations.Name;
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
+import org.apache.lucene.analysis.core.WhitespaceTokenizer;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
@@ -31,8 +33,16 @@
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
 import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.CharFilterFactory;
+import org.elasticsearch.index.analysis.CustomAnalyzer;
+import org.elasticsearch.index.analysis.TokenFilterFactory;
+import org.elasticsearch.index.analysis.TokenizerFactory;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 import org.junit.After;
 
 import java.util.ArrayList;
@@ -1209,6 +1219,86 @@ public void close() {
         }
     }
 
+    /**
+     * Replicate the existing csv test, using sample_data.csv
+     */
+    public void testCategorizeRaw() {
+        final Page page;
+        final int positions = 7;
+        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Disconnected"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
+            page = new Page(builder.build());
+        }
+        // final int emitBatchSize = between(positions, 10 * 1024);
+        try (
+            BlockHash hash = new CategorizeRawBlockHash(
+                blockFactory,
+                0,
+                true,
+                new CategorizationAnalyzer(
+                    // TODO: should be the same analyzer as used in Production
+                    new CustomAnalyzer(
+                        TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+                        new CharFilterFactory[0],
+                        new TokenFilterFactory[0]
+                    ),
+                    true
+                ),
+                new TokenListCategorizer.CloseableTokenListCategorizer(
+                    new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+                    CategorizationPartOfSpeechDictionary.getInstance(),
+                    0.70f
+                )
+            );
+        ) {
+            hash.add(page, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    groupIds.incRef();
+                    assertEquals(groupIds.getPositionCount(), positions);
+
+                    assertEquals(0, groupIds.getInt(0));
+                    assertEquals(1, groupIds.getInt(1));
+                    assertEquals(1, groupIds.getInt(2));
+                    assertEquals(1, groupIds.getInt(3));
+                    assertEquals(2, groupIds.getInt(4));
+                    assertEquals(0, groupIds.getInt(5));
+                    assertEquals(0, groupIds.getInt(6));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    groupIds.incRef();
+                    assertEquals(groupIds.getPositionCount(), positions);
+
+                    assertEquals(0, groupIds.getInt(0));
+                    assertEquals(1, groupIds.getInt(1));
+                    assertEquals(1, groupIds.getInt(2));
+                    assertEquals(1, groupIds.getInt(3));
+                    assertEquals(2, groupIds.getInt(4));
+                    assertEquals(0, groupIds.getInt(5));
+                    assertEquals(0, groupIds.getInt(6));
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+        } finally {
+            page.releaseBlocks();
+        }
+        // TODO: randomize and try multiple pages.
+        // TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
+        // TODO: also test the lookup method and other stuff.
+    }
+
     record OrdsAndKeys(String description, int positionOffset, IntBlock ords, Block[] keys, IntVector nonEmpty) {}
 
     /**
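
To run just the new test, a standard Gradle test filter should work (the task path below is an assumption from the module layout, not something stated in this commit):

    ./gradlew ':x-pack:plugin:esql:compute:test' --tests 'org.elasticsearch.compute.aggregation.blockhash.BlockHashTests.testCategorizeRaw'

Per the commit message the test is only almost passing, so expect a failure until the remaining TODOs (multi-value handling, emitBatchSize, the production analyzer) are addressed.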
