Skip to content

Commit 6755ab2

Browse files
committed
Refactor tests
1 parent 0008559 commit 6755ab2

File tree

3 files changed

+160
-78
lines changed

3 files changed

+160
-78
lines changed

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java

Lines changed: 57 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.index.DirectoryReader;
1313
import org.apache.lucene.index.IndexReader;
1414
import org.apache.lucene.index.Term;
15+
import org.apache.lucene.search.ConstantScoreQuery;
1516
import org.apache.lucene.search.IndexSearcher;
1617
import org.apache.lucene.search.MatchAllDocsQuery;
1718
import org.apache.lucene.search.MultiTermQuery;
@@ -24,18 +25,13 @@
2425
import org.apache.lucene.util.BytesRef;
2526
import org.elasticsearch.compute.OperatorTests;
2627
import org.elasticsearch.compute.data.BlockFactory;
27-
import org.elasticsearch.compute.data.BooleanBlock;
28-
import org.elasticsearch.compute.data.BooleanVector;
29-
import org.elasticsearch.compute.data.BytesRefBlock;
30-
import org.elasticsearch.compute.data.BytesRefVector;
3128
import org.elasticsearch.compute.data.DocBlock;
3229
import org.elasticsearch.compute.data.DoubleBlock;
3330
import org.elasticsearch.compute.data.ElementType;
3431
import org.elasticsearch.compute.data.Page;
3532
import org.elasticsearch.compute.data.Vector;
3633
import org.elasticsearch.compute.operator.Driver;
3734
import org.elasticsearch.compute.operator.DriverContext;
38-
import org.elasticsearch.compute.operator.EvalOperator;
3935
import org.elasticsearch.compute.operator.Operator;
4036
import org.elasticsearch.compute.operator.ShuffleDocsOperator;
4137
import org.elasticsearch.compute.test.ComputeTestCase;
@@ -54,30 +50,17 @@
5450
import java.util.TreeSet;
5551

5652
import static org.elasticsearch.compute.test.OperatorTestCase.randomPageSize;
57-
import static org.hamcrest.Matchers.equalTo;
5853

54+
/**
55+
* Base class for testing Lucene query evaluators.
56+
*/
5957
public abstract class LuceneQueryEvaluatorTests<T extends Vector, U extends Vector.Builder> extends ComputeTestCase {
6058

6159
private static final String FIELD = "g";
62-
// Scores are not interesting to this test, but enabled conditionally and effectively ignored just for coverage.
63-
protected final boolean useScoring = randomBoolean();
64-
65-
private static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, boolean scoring) {
66-
final ShardContext searchContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
67-
return new LuceneSourceOperator.Factory(
68-
List.of(searchContext),
69-
ctx -> query,
70-
randomFrom(DataPartitioning.values()),
71-
randomIntBetween(1, 10),
72-
randomPageSize(),
73-
LuceneOperator.NO_LIMIT,
74-
scoring
75-
);
76-
}
7760

7861
@SuppressWarnings("unchecked")
7962
public void testDenseCollectorSmall() throws IOException {
80-
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDensecollector(0, 2)) {
63+
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(0, 2)) {
8164
collector.setScorer(getScorer());
8265
collector.collect(0);
8366
collector.collect(1);
@@ -93,7 +76,7 @@ public void testDenseCollectorSmall() throws IOException {
9376

9477
@SuppressWarnings("unchecked")
9578
public void testDenseCollectorSimple() throws IOException {
96-
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDensecollector(0, 10)) {
79+
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(0, 10)) {
9780
collector.setScorer(getScorer());
9881
collector.collect(2);
9982
collector.collect(5);
@@ -112,7 +95,7 @@ public void testDenseCollector() throws IOException {
11295
int min = between(0, Integer.MAX_VALUE - length - 1);
11396
int max = min + length;
11497
boolean[] expected = new boolean[length];
115-
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDensecollector(min, max)) {
98+
try (LuceneQueryEvaluator.DenseCollector<U> collector = createDenseCollector(min, max)) {
11699
collector.setScorer(getScorer());
117100
for (int i = 0; i < length; i++) {
118101
expected[i] = randomBoolean();
@@ -133,31 +116,14 @@ public void testTermQuery() throws IOException {
133116
Set<String> values = values();
134117
String term = values.iterator().next();
135118
List<Page> results = runQuery(values, new TermQuery(new Term(FIELD, term)), false);
136-
assertTermQuery(term, results);
119+
assertTermsQuery(results, Set.of(term), 1);
137120
}
138121

139122
public void testTermQueryShuffled() throws IOException {
140123
Set<String> values = values();
141124
String term = values.iterator().next();
142-
List<Page> results = runQuery(values, new TermQuery(new Term(FIELD, term)), true);
143-
assertTermQuery(term, results);
144-
}
145-
146-
private void assertTermQuery(String term, List<Page> results) {
147-
int matchCount = 0;
148-
for (Page page : results) {
149-
int initialBlockIndex = initialBlockIndex(page);
150-
BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
151-
BooleanVector matches = page.<BooleanBlock>getBlock(initialBlockIndex + 1).asVector();
152-
for (int i = 0; i < page.getPositionCount(); i++) {
153-
BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
154-
assertThat(matches.getBoolean(i), equalTo(termAtPosition.utf8ToString().equals(term)));
155-
if (matches.getBoolean(i)) {
156-
matchCount++;
157-
}
158-
}
159-
}
160-
assertThat(matchCount, equalTo(1));
125+
List<Page> results = runQuery(values, new ConstantScoreQuery(new TermQuery(new Term(FIELD, term))), true);
126+
assertTermsQuery(results, Set.of(term), 1);
161127
}
162128

163129
public void testTermsQuery() throws IOException {
@@ -180,28 +146,10 @@ private void testTermsQuery(boolean shuffleDocs) throws IOException {
180146
matchingBytes.add(new BytesRef(v));
181147
}
182148
List<Page> results = runQuery(values, new TermInSetQuery(MultiTermQuery.CONSTANT_SCORE_REWRITE, FIELD, matchingBytes), shuffleDocs);
183-
int matchCount = 0;
184-
for (Page page : results) {
185-
int initialBlockIndex = initialBlockIndex(page);
186-
BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
187-
BooleanVector matches = page.<BooleanBlock>getBlock(initialBlockIndex + 1).asVector();
188-
for (int i = 0; i < page.getPositionCount(); i++) {
189-
BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
190-
assertThat(matches.getBoolean(i), equalTo(matching.contains(termAtPosition.utf8ToString())));
191-
if (matches.getBoolean(i)) {
192-
matchCount++;
193-
}
194-
}
195-
}
196-
assertThat(matchCount, equalTo(expectedMatchCount));
149+
assertTermsQuery(results, matching, expectedMatchCount);
197150
}
198151

199-
protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
200-
return new EvalOperator(blockFactory, new LuceneQueryExpressionEvaluator(
201-
blockFactory,
202-
shards
203-
));
204-
}
152+
protected abstract void assertTermsQuery(List<Page> results, Set<String> matching, int expectedMatchCount);
205153

206154
private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs) throws IOException {
207155
DriverContext driverContext = driverContext();
@@ -240,7 +188,7 @@ private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs
240188
List<Page> results = new ArrayList<>();
241189
Driver driver = TestDriverFactory.create(
242190
driverContext,
243-
LuceneQueryEvaluatorTests.luceneOperatorFactory(reader, new MatchAllDocsQuery(), useScoring)
191+
LuceneQueryEvaluatorTests.luceneOperatorFactory(reader, new MatchAllDocsQuery(), usesScoring())
244192
.get(driverContext),
245193
operators,
246194
new TestResultPageSinkOperator(results::add)
@@ -282,23 +230,64 @@ private DriverContext driverContext() {
282230
}
283231

284232
// Returns the initial block index, ignoring the score block if scoring is enabled
285-
private int initialBlockIndex(Page page) {
233+
protected int termsBlockIndex(Page page) {
286234
assert page.getBlock(0) instanceof DocBlock : "expected doc block at index 0";
287-
if (useScoring) {
235+
if (usesScoring()) {
288236
assert page.getBlock(1) instanceof DoubleBlock : "expected double block at index 1";
289237
return 2;
290238
} else {
291239
return 1;
292240
}
293241
}
294242

295-
protected abstract LuceneQueryEvaluator.DenseCollector<U> createDensecollector(int min, int max);
243+
private static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, boolean scoring) {
244+
final ShardContext searchContext = new LuceneSourceOperatorTests.MockShardContext(reader, 0);
245+
return new LuceneSourceOperator.Factory(
246+
List.of(searchContext),
247+
ctx -> query,
248+
randomFrom(DataPartitioning.values()),
249+
randomIntBetween(1, 10),
250+
randomPageSize(),
251+
LuceneOperator.NO_LIMIT,
252+
scoring
253+
);
254+
}
255+
256+
// Returns the block index for the results to check
257+
protected abstract int resultsBlockIndex(Page page);
258+
259+
/**
260+
* Create a dense collector for the given range.
261+
*/
262+
protected abstract LuceneQueryEvaluator.DenseCollector<U> createDenseCollector(int min, int max);
296263

264+
/**
265+
* Returns a test scorer to use for scoring docs. Can be null
266+
*/
297267
protected abstract Scorable getScorer();
298268

269+
/**
270+
* Retrieves the value at a given index from the vector. Need to do this as Vector does not export a generic get() method
271+
*/
299272
protected abstract Object getValueAt(T vector, int i);
300273

274+
/**
275+
* Value that should be returned for a matching doc from the resulting vector
276+
*/
301277
protected abstract Object valueForMatch();
302278

279+
/**
280+
* Value that should be returned for a non-matching doc from the resulting vector
281+
*/
303282
protected abstract Object valueForNoMatch();
283+
284+
/**
285+
* Create the operator to test
286+
*/
287+
protected abstract Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards);
288+
289+
/**
290+
* Should the test use scoring?
291+
*/
292+
protected abstract boolean usesScoring();
304293
}

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,28 @@
88
package org.elasticsearch.compute.lucene;
99

1010
import org.apache.lucene.search.Scorable;
11+
import org.apache.lucene.util.BytesRef;
12+
import org.elasticsearch.compute.data.BlockFactory;
13+
import org.elasticsearch.compute.data.BooleanBlock;
1114
import org.elasticsearch.compute.data.BooleanVector;
15+
import org.elasticsearch.compute.data.BytesRefBlock;
16+
import org.elasticsearch.compute.data.BytesRefVector;
17+
import org.elasticsearch.compute.data.Page;
1218
import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.DenseCollector;
19+
import org.elasticsearch.compute.operator.EvalOperator;
20+
import org.elasticsearch.compute.operator.Operator;
21+
22+
import java.util.List;
23+
import java.util.Set;
24+
25+
import static org.hamcrest.Matchers.equalTo;
1326

1427
public class LuceneQueryExpressionEvaluatorTests extends LuceneQueryEvaluatorTests<BooleanVector, BooleanVector.Builder> {
1528

29+
private final boolean useScoring = randomBoolean();
30+
1631
@Override
17-
protected DenseCollector<BooleanVector.Builder> createDensecollector(int min, int max) {
32+
protected DenseCollector<BooleanVector.Builder> createDenseCollector(int min, int max) {
1833
return new LuceneQueryEvaluator.DenseCollector<>(
1934
min,
2035
max,
@@ -43,5 +58,40 @@ protected Object valueForNoMatch() {
4358
return false;
4459
}
4560

61+
@Override
62+
protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
63+
return new EvalOperator(blockFactory, new LuceneQueryExpressionEvaluator(
64+
blockFactory,
65+
shards
66+
));
67+
}
4668

69+
@Override
70+
protected boolean usesScoring() {
71+
// Be consistent for a single test execution
72+
return useScoring;
73+
}
74+
75+
@Override
76+
protected int resultsBlockIndex(Page page) {
77+
return page.getBlockCount() - 1;
78+
}
79+
80+
@Override
81+
protected void assertTermsQuery(List<Page> results, Set<String> matching, int expectedMatchCount) {
82+
int matchCount = 0;
83+
for (Page page : results) {
84+
int initialBlockIndex = termsBlockIndex(page);
85+
BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
86+
BooleanVector matches = page.<BooleanBlock>getBlock(initialBlockIndex + 1).asVector();
87+
for (int i = 0; i < page.getPositionCount(); i++) {
88+
BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
89+
assertThat(matches.getBoolean(i), equalTo(matching.contains(termAtPosition.utf8ToString())));
90+
if (matches.getBoolean(i)) {
91+
matchCount++;
92+
}
93+
}
94+
}
95+
assertThat(matchCount, equalTo(expectedMatchCount));
96+
}
4797
}

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryScoreEvaluatorTests.java

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,30 @@
88
package org.elasticsearch.compute.lucene;
99

1010
import org.apache.lucene.search.Scorable;
11+
import org.apache.lucene.util.BytesRef;
12+
import org.elasticsearch.compute.data.BlockFactory;
13+
import org.elasticsearch.compute.data.BytesRefBlock;
14+
import org.elasticsearch.compute.data.BytesRefVector;
1115
import org.elasticsearch.compute.data.DoubleVector;
16+
import org.elasticsearch.compute.data.Page;
17+
import org.elasticsearch.compute.operator.Operator;
18+
import org.elasticsearch.compute.operator.ScoreOperator;
1219

1320
import java.io.IOException;
21+
import java.util.List;
22+
import java.util.Set;
1423

1524
import static org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator.NO_MATCH_SCORE;
25+
import static org.hamcrest.Matchers.equalTo;
26+
import static org.hamcrest.Matchers.greaterThan;
1627

1728
public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<DoubleVector, DoubleVector.Builder> {
18-
private static final String FIELD = "g";
19-
public static final Scorable CONSTANT_SCORER = new Scorable() {
20-
@Override
21-
public float score() throws IOException {
22-
return TEST_SCORE;
23-
}
24-
};
25-
public static final float TEST_SCORE = 1.5f;
29+
30+
private static final float TEST_SCORE = 1.5f;
31+
private static final Double DEFAULT_SCORE = 1.0;
2632

2733
@Override
28-
protected LuceneQueryEvaluator.DenseCollector<DoubleVector.Builder> createDensecollector(int min, int max) {
34+
protected LuceneQueryEvaluator.DenseCollector<DoubleVector.Builder> createDenseCollector(int min, int max) {
2935
return new LuceneQueryEvaluator.DenseCollector<>(
3036
min,
3137
max,
@@ -58,4 +64,41 @@ protected Object valueForMatch() {
5864
protected Object valueForNoMatch() {
5965
return NO_MATCH_SCORE;
6066
}
67+
68+
@Override
69+
protected Operator createOperator(BlockFactory blockFactory, LuceneQueryEvaluator.ShardConfig[] shards) {
70+
return new ScoreOperator(blockFactory, new LuceneQueryScoreEvaluator(blockFactory, shards), 1);
71+
}
72+
73+
@Override
74+
protected boolean usesScoring() {
75+
return true;
76+
}
77+
78+
@Override
79+
protected int resultsBlockIndex(Page page) {
80+
// Uses the score block
81+
return 1;
82+
}
83+
84+
@Override
85+
protected void assertTermsQuery(List<Page> results, Set<String> matching, int expectedMatchCount) {
86+
int matchCount = 0;
87+
for (Page page : results) {
88+
int initialBlockIndex = termsBlockIndex(page);
89+
BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
90+
DoubleVector matches = (DoubleVector) page.getBlock(resultsBlockIndex(page)).asVector();
91+
for (int i = 0; i < page.getPositionCount(); i++) {
92+
BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
93+
if (matching.contains(termAtPosition.utf8ToString())) {
94+
assertThat(matches.getDouble(i), greaterThan((double) TEST_SCORE));
95+
matchCount++;
96+
} else {
97+
// Default score, as Lucene docs gets retrieved with a implicit score
98+
assertThat(matches.getDouble(i), equalTo(DEFAULT_SCORE));
99+
}
100+
}
101+
}
102+
assertThat(matchCount, equalTo(expectedMatchCount));
103+
}
61104
}

0 commit comments

Comments
 (0)