1212import org .apache .lucene .index .DirectoryReader ;
1313import org .apache .lucene .index .IndexReader ;
1414import org .apache .lucene .index .Term ;
15+ import org .apache .lucene .search .ConstantScoreQuery ;
1516import org .apache .lucene .search .IndexSearcher ;
1617import org .apache .lucene .search .MatchAllDocsQuery ;
1718import org .apache .lucene .search .MultiTermQuery ;
2425import org .apache .lucene .util .BytesRef ;
2526import org .elasticsearch .compute .OperatorTests ;
2627import org .elasticsearch .compute .data .BlockFactory ;
27- import org .elasticsearch .compute .data .BooleanBlock ;
28- import org .elasticsearch .compute .data .BooleanVector ;
29- import org .elasticsearch .compute .data .BytesRefBlock ;
30- import org .elasticsearch .compute .data .BytesRefVector ;
3128import org .elasticsearch .compute .data .DocBlock ;
3229import org .elasticsearch .compute .data .DoubleBlock ;
3330import org .elasticsearch .compute .data .ElementType ;
3431import org .elasticsearch .compute .data .Page ;
3532import org .elasticsearch .compute .data .Vector ;
3633import org .elasticsearch .compute .operator .Driver ;
3734import org .elasticsearch .compute .operator .DriverContext ;
38- import org .elasticsearch .compute .operator .EvalOperator ;
3935import org .elasticsearch .compute .operator .Operator ;
4036import org .elasticsearch .compute .operator .ShuffleDocsOperator ;
4137import org .elasticsearch .compute .test .ComputeTestCase ;
5450import java .util .TreeSet ;
5551
5652import static org .elasticsearch .compute .test .OperatorTestCase .randomPageSize ;
57- import static org .hamcrest .Matchers .equalTo ;
5853
54+ /**
55+ * Base class for testing Lucene query evaluators.
56+ */
5957public abstract class LuceneQueryEvaluatorTests <T extends Vector , U extends Vector .Builder > extends ComputeTestCase {
6058
6159 private static final String FIELD = "g" ;
62- // Scores are not interesting to this test, but enabled conditionally and effectively ignored just for coverage.
63- protected final boolean useScoring = randomBoolean ();
64-
65- private static LuceneOperator .Factory luceneOperatorFactory (IndexReader reader , Query query , boolean scoring ) {
66- final ShardContext searchContext = new LuceneSourceOperatorTests .MockShardContext (reader , 0 );
67- return new LuceneSourceOperator .Factory (
68- List .of (searchContext ),
69- ctx -> query ,
70- randomFrom (DataPartitioning .values ()),
71- randomIntBetween (1 , 10 ),
72- randomPageSize (),
73- LuceneOperator .NO_LIMIT ,
74- scoring
75- );
76- }
7760
7861 @ SuppressWarnings ("unchecked" )
7962 public void testDenseCollectorSmall () throws IOException {
80- try (LuceneQueryEvaluator .DenseCollector <U > collector = createDensecollector (0 , 2 )) {
63+ try (LuceneQueryEvaluator .DenseCollector <U > collector = createDenseCollector (0 , 2 )) {
8164 collector .setScorer (getScorer ());
8265 collector .collect (0 );
8366 collector .collect (1 );
@@ -93,7 +76,7 @@ public void testDenseCollectorSmall() throws IOException {
9376
9477 @ SuppressWarnings ("unchecked" )
9578 public void testDenseCollectorSimple () throws IOException {
96- try (LuceneQueryEvaluator .DenseCollector <U > collector = createDensecollector (0 , 10 )) {
79+ try (LuceneQueryEvaluator .DenseCollector <U > collector = createDenseCollector (0 , 10 )) {
9780 collector .setScorer (getScorer ());
9881 collector .collect (2 );
9982 collector .collect (5 );
@@ -112,7 +95,7 @@ public void testDenseCollector() throws IOException {
11295 int min = between (0 , Integer .MAX_VALUE - length - 1 );
11396 int max = min + length ;
11497 boolean [] expected = new boolean [length ];
115- try (LuceneQueryEvaluator .DenseCollector <U > collector = createDensecollector (min , max )) {
98+ try (LuceneQueryEvaluator .DenseCollector <U > collector = createDenseCollector (min , max )) {
11699 collector .setScorer (getScorer ());
117100 for (int i = 0 ; i < length ; i ++) {
118101 expected [i ] = randomBoolean ();
@@ -133,31 +116,14 @@ public void testTermQuery() throws IOException {
133116 Set <String > values = values ();
134117 String term = values .iterator ().next ();
135118 List <Page > results = runQuery (values , new TermQuery (new Term (FIELD , term )), false );
136- assertTermQuery ( term , results );
119+ assertTermsQuery ( results , Set . of ( term ), 1 );
137120 }
138121
139122 public void testTermQueryShuffled () throws IOException {
140123 Set <String > values = values ();
141124 String term = values .iterator ().next ();
142- List <Page > results = runQuery (values , new TermQuery (new Term (FIELD , term )), true );
143- assertTermQuery (term , results );
144- }
145-
146- private void assertTermQuery (String term , List <Page > results ) {
147- int matchCount = 0 ;
148- for (Page page : results ) {
149- int initialBlockIndex = initialBlockIndex (page );
150- BytesRefVector terms = page .<BytesRefBlock >getBlock (initialBlockIndex ).asVector ();
151- BooleanVector matches = page .<BooleanBlock >getBlock (initialBlockIndex + 1 ).asVector ();
152- for (int i = 0 ; i < page .getPositionCount (); i ++) {
153- BytesRef termAtPosition = terms .getBytesRef (i , new BytesRef ());
154- assertThat (matches .getBoolean (i ), equalTo (termAtPosition .utf8ToString ().equals (term )));
155- if (matches .getBoolean (i )) {
156- matchCount ++;
157- }
158- }
159- }
160- assertThat (matchCount , equalTo (1 ));
125+ List <Page > results = runQuery (values , new ConstantScoreQuery (new TermQuery (new Term (FIELD , term ))), true );
126+ assertTermsQuery (results , Set .of (term ), 1 );
161127 }
162128
163129 public void testTermsQuery () throws IOException {
@@ -180,28 +146,10 @@ private void testTermsQuery(boolean shuffleDocs) throws IOException {
180146 matchingBytes .add (new BytesRef (v ));
181147 }
182148 List <Page > results = runQuery (values , new TermInSetQuery (MultiTermQuery .CONSTANT_SCORE_REWRITE , FIELD , matchingBytes ), shuffleDocs );
183- int matchCount = 0 ;
184- for (Page page : results ) {
185- int initialBlockIndex = initialBlockIndex (page );
186- BytesRefVector terms = page .<BytesRefBlock >getBlock (initialBlockIndex ).asVector ();
187- BooleanVector matches = page .<BooleanBlock >getBlock (initialBlockIndex + 1 ).asVector ();
188- for (int i = 0 ; i < page .getPositionCount (); i ++) {
189- BytesRef termAtPosition = terms .getBytesRef (i , new BytesRef ());
190- assertThat (matches .getBoolean (i ), equalTo (matching .contains (termAtPosition .utf8ToString ())));
191- if (matches .getBoolean (i )) {
192- matchCount ++;
193- }
194- }
195- }
196- assertThat (matchCount , equalTo (expectedMatchCount ));
149+ assertTermsQuery (results , matching , expectedMatchCount );
197150 }
198151
199- protected Operator createOperator (BlockFactory blockFactory , LuceneQueryEvaluator .ShardConfig [] shards ) {
200- return new EvalOperator (blockFactory , new LuceneQueryExpressionEvaluator (
201- blockFactory ,
202- shards
203- ));
204- }
152+ protected abstract void assertTermsQuery (List <Page > results , Set <String > matching , int expectedMatchCount );
205153
206154 private List <Page > runQuery (Set <String > values , Query query , boolean shuffleDocs ) throws IOException {
207155 DriverContext driverContext = driverContext ();
@@ -240,7 +188,7 @@ private List<Page> runQuery(Set<String> values, Query query, boolean shuffleDocs
240188 List <Page > results = new ArrayList <>();
241189 Driver driver = TestDriverFactory .create (
242190 driverContext ,
243- LuceneQueryEvaluatorTests .luceneOperatorFactory (reader , new MatchAllDocsQuery (), useScoring )
191+ LuceneQueryEvaluatorTests .luceneOperatorFactory (reader , new MatchAllDocsQuery (), usesScoring () )
244192 .get (driverContext ),
245193 operators ,
246194 new TestResultPageSinkOperator (results ::add )
@@ -282,23 +230,64 @@ private DriverContext driverContext() {
282230 }
283231
284232 // Returns the initial block index, ignoring the score block if scoring is enabled
285- private int initialBlockIndex (Page page ) {
233+ protected int termsBlockIndex (Page page ) {
286234 assert page .getBlock (0 ) instanceof DocBlock : "expected doc block at index 0" ;
287- if (useScoring ) {
235+ if (usesScoring () ) {
288236 assert page .getBlock (1 ) instanceof DoubleBlock : "expected double block at index 1" ;
289237 return 2 ;
290238 } else {
291239 return 1 ;
292240 }
293241 }
294242
295- protected abstract LuceneQueryEvaluator .DenseCollector <U > createDensecollector (int min , int max );
243+ private static LuceneOperator .Factory luceneOperatorFactory (IndexReader reader , Query query , boolean scoring ) {
244+ final ShardContext searchContext = new LuceneSourceOperatorTests .MockShardContext (reader , 0 );
245+ return new LuceneSourceOperator .Factory (
246+ List .of (searchContext ),
247+ ctx -> query ,
248+ randomFrom (DataPartitioning .values ()),
249+ randomIntBetween (1 , 10 ),
250+ randomPageSize (),
251+ LuceneOperator .NO_LIMIT ,
252+ scoring
253+ );
254+ }
255+
256+ // Returns the block index for the results to check
257+ protected abstract int resultsBlockIndex (Page page );
258+
259+ /**
260+ * Create a dense collector for the given range.
261+ */
262+ protected abstract LuceneQueryEvaluator .DenseCollector <U > createDenseCollector (int min , int max );
296263
264+ /**
265+ * Returns a test scorer to use for scoring docs. Can be null
266+ */
297267 protected abstract Scorable getScorer ();
298268
269+ /**
270+ * Retrieves the value at a given index from the vector. Need to do this as Vector does not export a generic get() method
271+ */
299272 protected abstract Object getValueAt (T vector , int i );
300273
274+ /**
275+ * Value that should be returned for a matching doc from the resulting vector
276+ */
301277 protected abstract Object valueForMatch ();
302278
279+ /**
280+ * Value that should be returned for a non-matching doc from the resulting vector
281+ */
303282 protected abstract Object valueForNoMatch ();
283+
284+ /**
285+ * Create the operator to test
286+ */
287+ protected abstract Operator createOperator (BlockFactory blockFactory , LuceneQueryEvaluator .ShardConfig [] shards );
288+
289+ /**
290+ * Should the test use scoring?
291+ */
292+ protected abstract boolean usesScoring ();
304293}
0 commit comments