
Commit 82cc74a

Move Categorize BlockHash tests to separate file
1 parent 31e9e20 commit 82cc74a

3 files changed, +135 -92 lines changed

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.data.MockBlockFactory;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class BlockHashTestCase extends ESTestCase {
+
+    final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
+    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
+    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
+
+    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
+    private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
+        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
+        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
+        return breakerService;
+    }
+}
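For orientation, a minimal sketch of how a test class can build on these shared fixtures. The class name, method name, and block contents are hypothetical, not part of this commit; the sketch only uses members visible in this diff (the inherited blockFactory and the newBytesRefBlockBuilder call exercised in the tests below), and it sits in the same package so the package-private fields are accessible.

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.BytesRefBlock;

// Hypothetical subclass, for illustration only. Allocations made through the
// inherited blockFactory are tracked against the shared 1 GB test breaker.
public class ExampleBlockHashTests extends BlockHashTestCase {
    public void testBuildAndReleaseBlock() {
        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(1)) {
            builder.appendBytesRef(new BytesRef("example"));
            builder.build().close(); // release the built block so the breaker balances out
        }
    }
}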

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

Lines changed: 1 addition & 92 deletions
@@ -64,11 +64,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
-public class BlockHashTests extends ESTestCase {
-
-    final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
-    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
-    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
+public class BlockHashTests extends BlockHashTestCase {
 
     @ParametersFactory
     public static List<Object[]> params() {
@@ -1219,86 +1215,6 @@ public void close() {
         }
     }
 
-    /**
-     * Replicate the existing csv test, using sample_data.csv
-     */
-    public void testCategorizeRaw() {
-        final Page page;
-        final int positions = 7;
-        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
-            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
-            builder.appendBytesRef(new BytesRef("Connection error"));
-            builder.appendBytesRef(new BytesRef("Connection error"));
-            builder.appendBytesRef(new BytesRef("Connection error"));
-            builder.appendBytesRef(new BytesRef("Disconnected"));
-            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
-            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
-            page = new Page(builder.build());
-        }
-        // final int emitBatchSize = between(positions, 10 * 1024);
-        try (
-            BlockHash hash = new CategorizeRawBlockHash(
-                blockFactory,
-                0,
-                true,
-                new CategorizationAnalyzer(
-                    // TODO: should be the same analyzer as used in Production
-                    new CustomAnalyzer(
-                        TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
-                        new CharFilterFactory[0],
-                        new TokenFilterFactory[0]
-                    ),
-                    true
-                ),
-                new TokenListCategorizer.CloseableTokenListCategorizer(
-                    new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
-                    CategorizationPartOfSpeechDictionary.getInstance(),
-                    0.70f
-                )
-            );
-        ) {
-            hash.add(page, new GroupingAggregatorFunction.AddInput() {
-                @Override
-                public void add(int positionOffset, IntBlock groupIds) {
-                    groupIds.incRef();
-                    assertEquals(groupIds.getPositionCount(), positions);
-
-                    assertEquals(0, groupIds.getInt(0));
-                    assertEquals(1, groupIds.getInt(1));
-                    assertEquals(1, groupIds.getInt(2));
-                    assertEquals(1, groupIds.getInt(3));
-                    assertEquals(2, groupIds.getInt(4));
-                    assertEquals(0, groupIds.getInt(5));
-                    assertEquals(0, groupIds.getInt(6));
-                }
-
-                @Override
-                public void add(int positionOffset, IntVector groupIds) {
-                    groupIds.incRef();
-                    assertEquals(groupIds.getPositionCount(), positions);
-
-                    assertEquals(0, groupIds.getInt(0));
-                    assertEquals(1, groupIds.getInt(1));
-                    assertEquals(1, groupIds.getInt(2));
-                    assertEquals(1, groupIds.getInt(3));
-                    assertEquals(2, groupIds.getInt(4));
-                    assertEquals(0, groupIds.getInt(5));
-                    assertEquals(0, groupIds.getInt(6));
-                }
-
-                @Override
-                public void close() {
-                    fail("hashes should not close AddInput");
-                }
-            });
-        } finally {
-            page.releaseBlocks();
-        }
-        // TODO: randomize and try multiple pages.
-        // TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
-        // TODO: also test the lookup method and other stuff.
-    }
-
     record OrdsAndKeys(String description, int positionOffset, IntBlock ords, Block[] keys, IntVector nonEmpty) {}
 
     /**
@@ -1493,13 +1409,6 @@ private void assertKeys(Block[] actualKeys, Object[][] expectedKeys) {
         }
     }
 
-    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
-    static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
-        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
-        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
-        return breakerService;
-    }
-
     IntVector intRange(int startInclusive, int endExclusive) {
         return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
     }
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.analysis.core.WhitespaceTokenizer;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.index.analysis.CharFilterFactory;
+import org.elasticsearch.index.analysis.CustomAnalyzer;
+import org.elasticsearch.index.analysis.TokenFilterFactory;
+import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
+
+public class CategorizeBlockHashTests extends BlockHashTestCase {
+
+    /**
+     * Replicate the existing csv test, using sample_data.csv
+     */
+    public void testCategorizeRaw() {
+        final Page page;
+        final int positions = 7;
+        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Disconnected"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
+            page = new Page(builder.build());
+        }
+        // final int emitBatchSize = between(positions, 10 * 1024);
+        try (
+            BlockHash hash = new CategorizeRawBlockHash(
+                blockFactory,
+                0,
+                true,
+                new CategorizationAnalyzer(
+                    // TODO: should be the same analyzer as used in Production
+                    new CustomAnalyzer(
+                        TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+                        new CharFilterFactory[0],
+                        new TokenFilterFactory[0]
+                    ),
+                    true
+                ),
+                new TokenListCategorizer.CloseableTokenListCategorizer(
+                    new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+                    CategorizationPartOfSpeechDictionary.getInstance(),
+                    0.70f
+                )
+            );
+        ) {
+            hash.add(page, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    groupIds.incRef();
+                    assertEquals(groupIds.getPositionCount(), positions);
+
+                    assertEquals(0, groupIds.getInt(0));
+                    assertEquals(1, groupIds.getInt(1));
+                    assertEquals(1, groupIds.getInt(2));
+                    assertEquals(1, groupIds.getInt(3));
+                    assertEquals(2, groupIds.getInt(4));
+                    assertEquals(0, groupIds.getInt(5));
+                    assertEquals(0, groupIds.getInt(6));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    add(positionOffset, groupIds.asBlock());
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+        } finally {
+            page.releaseBlocks();
+        }
+
+        // TODO: randomize and try multiple pages.
+        // TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
+        // TODO: also test the lookup method and other stuff.
+    }
+}
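One simplification in the moved copy: the IntVector overload of AddInput now delegates to the IntBlock overload via asBlock(), so the expected group IDs are asserted in a single place rather than duplicated. As a standalone idiom that looks like the fragment below (illustrative only; the enclosing test class is implied, and fail comes from ESTestCase):

// Funnel both AddInput overloads through one assertion path,
// as the moved test above does.
GroupingAggregatorFunction.AddInput input = new GroupingAggregatorFunction.AddInput() {
    @Override
    public void add(int positionOffset, IntBlock groupIds) {
        groupIds.incRef();
        // assert on groupIds here; both overloads end up in this method
    }

    @Override
    public void add(int positionOffset, IntVector groupIds) {
        add(positionOffset, groupIds.asBlock()); // reuse the block path
    }

    @Override
    public void close() {
        fail("hashes should not close AddInput");
    }
};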
