1616import org .apache .lucene .search .ScoreMode ;
1717import org .apache .lucene .search .Scorer ;
1818import org .apache .lucene .search .Weight ;
19- import org .apache .lucene .util .ArrayUtil ;
2019import org .apache .lucene .util .Bits ;
20+ import org .elasticsearch .common .CheckedBiConsumer ;
2121import org .elasticsearch .compute .data .Block ;
2222import org .elasticsearch .compute .data .BlockFactory ;
2323import org .elasticsearch .compute .data .BooleanVector ;
3232
3333import java .io .IOException ;
3434import java .io .UncheckedIOException ;
35- import java .util .function .BiFunction ;
35+ import java .util .ArrayList ;
36+ import java .util .Collections ;
37+ import java .util .List ;
38+ import java .util .function .Consumer ;
3639
3740/**
3841 * {@link EvalOperator.ExpressionEvaluator} to run a Lucene {@link Query} during
4144 * {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
4245 * this evaluator is here to save the day.
4346 */
44- public abstract class LuceneQueryEvaluator implements Releasable {
45-
46- public static final double NO_MATCH_SCORE = 0.0 ;
47+ public abstract class LuceneQueryEvaluator <T extends Vector .Builder > implements Releasable {
4748
4849 public record ShardConfig (Query query , IndexSearcher searcher ) {}
4950
5051 private final BlockFactory blockFactory ;
5152 private final ShardConfig [] shards ;
52- private final BiFunction <BlockFactory , Integer , ScoreVectorBuilder > scoreVectorBuilderSupplier ;
5353
54- private ShardState [] perShardState = EMPTY_SHARD_STATES ;
54+ private final List < ShardState > perShardState ;
5555
5656 protected LuceneQueryEvaluator (
5757 BlockFactory blockFactory ,
58- ShardConfig [] shards ,
59- BiFunction <BlockFactory , Integer , ScoreVectorBuilder > scoreVectorBuilderSupplier
58+ ShardConfig [] shards
6059 ) {
6160 this .blockFactory = blockFactory ;
6261 this .shards = shards ;
63- this .scoreVectorBuilderSupplier = scoreVectorBuilderSupplier ;
62+ this .perShardState = new ArrayList <>( Collections . nCopies ( shards . length , null )) ;
6463 }
6564
6665 public Block executeQuery (Page page ) {
@@ -115,7 +114,7 @@ private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException
115114 int min = docs .docs ().getInt (0 );
116115 int max = docs .docs ().getInt (docs .getPositionCount () - 1 );
117116 int length = max - min + 1 ;
118- try (ScoreVectorBuilder scoreBuilder = scoreVectorBuilderSupplier . apply (blockFactory , length )) {
117+ try (T scoreBuilder = createBuilder (blockFactory , length )) {
119118 if (length == docs .getPositionCount () && length > 1 ) {
120119 return segmentState .scoreDense (scoreBuilder , min , max );
121120 }
@@ -143,8 +142,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
143142 int prevShard = -1 ;
144143 int prevSegment = -1 ;
145144 SegmentState segmentState = null ;
146- try (ScoreVectorBuilder scoreBuilder = scoreVectorBuilderSupplier .apply (blockFactory , docs .getPositionCount ())) {
147- scoreBuilder .initVector ();
145+ try (T scoreBuilder = createBuilder (blockFactory , docs .getPositionCount ())) {
148146 for (int i = 0 ; i < docs .getPositionCount (); i ++) {
149147 int shard = docs .shards ().getInt (docs .shards ().getInt (map [i ]));
150148 int segment = docs .segments ().getInt (map [i ]);
@@ -155,7 +153,7 @@ private Vector evalSlow(DocVector docs) throws IOException {
155153 prevSegment = segment ;
156154 }
157155 if (segmentState .noMatch ) {
158- scoreBuilder . appendNoMatch ();
156+ appendNoMatch (scoreBuilder );
159157 } else {
160158 segmentState .scoreSingleDocWithScorer (scoreBuilder , docs .docs ().getInt (map [i ]));
161159 }
@@ -170,40 +168,39 @@ private Vector evalSlow(DocVector docs) throws IOException {
170168 public void close () {
171169 }
172170
173- protected abstract ScoreMode scoreMode ();
174-
175171 private ShardState shardState (int shard ) throws IOException {
176- if (shard >= perShardState .length ) {
177- perShardState = ArrayUtil .grow (perShardState , shard + 1 );
178- } else if (perShardState [shard ] != null ) {
179- return perShardState [shard ];
172+ ShardState shardState = perShardState .get (shard );
173+ if (shardState != null ) {
174+ return shardState ;
180175 }
181- perShardState [shard ] = new ShardState (shards [shard ]);
182- return perShardState [shard ];
176+ shardState = new ShardState (shards [shard ]);
177+ perShardState .set (shard , shardState );
178+ return shardState ;
183179 }
184180
185181 private class ShardState {
186182 private final Weight weight ;
187183 private final IndexSearcher searcher ;
188- private SegmentState [] perSegmentState = EMPTY_SEGMENT_STATES ;
184+ private final List < SegmentState > perSegmentState ;
189185
190186 ShardState (ShardConfig config ) throws IOException {
191187 weight = config .searcher .createWeight (config .query , scoreMode (), 1.0f );
192188 searcher = config .searcher ;
189+ perSegmentState = new ArrayList <>(Collections .nCopies (searcher .getLeafContexts ().size (), null ));
193190 }
194191
195192 SegmentState segmentState (int segment ) throws IOException {
196- if (segment >= perSegmentState .length ) {
197- perSegmentState = ArrayUtil .grow (perSegmentState , segment + 1 );
198- } else if (perSegmentState [segment ] != null ) {
199- return perSegmentState [segment ];
193+ SegmentState segmentState = perSegmentState .get (segment );
194+ if (segmentState != null ) {
195+ return segmentState ;
200196 }
201- perSegmentState [segment ] = new SegmentState (weight , searcher .getLeafContexts ().get (segment ));
202- return perSegmentState [segment ];
197+ segmentState = new SegmentState (weight , searcher .getLeafContexts ().get (segment ));
198+ perSegmentState .set (segment , segmentState );
199+ return segmentState ;
203200 }
204201 }
205202
206- private static class SegmentState {
203+ private class SegmentState {
207204 private final Weight weight ;
208205 private final LeafReaderContext ctx ;
209206
@@ -244,9 +241,9 @@ private SegmentState(Weight weight, LeafReaderContext ctx) {
244241 * Score a range using the {@link BulkScorer}. This should be faster
245242 * than using {@link #scoreSparse} for dense doc ids.
246243 */
247- Vector scoreDense (ScoreVectorBuilder scoreBuilder , int min , int max ) throws IOException {
244+ Vector scoreDense (T scoreBuilder , int min , int max ) throws IOException {
248245 if (noMatch ) {
249- return scoreBuilder . createNoMatchVector ();
246+ return createNoMatchVector (blockFactory , max - min + 1 );
250247 }
251248 if (bulkScorer == null || // The bulkScorer wasn't initialized
252249 Thread .currentThread () != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -255,10 +252,12 @@ Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOEx
255252 bulkScorer = weight .bulkScorer (ctx );
256253 if (bulkScorer == null ) {
257254 noMatch = true ;
258- return scoreBuilder . createNoMatchVector ();
255+ return createNoMatchVector (blockFactory , max - min + 1 );
259256 }
260257 }
261- try (DenseCollector collector = new DenseCollector (min , max , scoreBuilder )) {
258+ try (DenseCollector <T > collector = new DenseCollector <>(min , max , scoreBuilder ,
259+ LuceneQueryEvaluator .this ::appendNoMatch ,
260+ LuceneQueryEvaluator .this ::appendMatch )) {
262261 bulkScorer .score (collector , ctx .reader ().getLiveDocs (), min , max + 1 );
263262 return collector .build ();
264263 }
@@ -268,12 +267,11 @@ Vector scoreDense(ScoreVectorBuilder scoreBuilder, int min, int max) throws IOEx
268267 * Score a vector of doc ids using {@link Scorer}. If you have a dense range of
269268 * doc ids it'd be faster to use {@link #scoreDense}.
270269 */
271- Vector scoreSparse (ScoreVectorBuilder scoreBuilder , IntVector docs ) throws IOException {
270+ Vector scoreSparse (T scoreBuilder , IntVector docs ) throws IOException {
272271 initScorer (docs .getInt (0 ));
273272 if (noMatch ) {
274- return scoreBuilder . createNoMatchVector ();
273+ return createNoMatchVector (blockFactory , docs . getPositionCount () );
275274 }
276- scoreBuilder .initVector ();
277275 for (int i = 0 ; i < docs .getPositionCount (); i ++) {
278276 scoreSingleDocWithScorer (scoreBuilder , docs .getInt (i ));
279277 }
@@ -296,41 +294,47 @@ private void initScorer(int minDocId) throws IOException {
296294 }
297295 }
298296
299- private void scoreSingleDocWithScorer (ScoreVectorBuilder builder , int doc ) throws IOException {
297+ private void scoreSingleDocWithScorer (T builder , int doc ) throws IOException {
300298 if (scorer .iterator ().docID () == doc ) {
301- builder . appendMatch (scorer );
299+ appendMatch (builder , scorer );
302300 } else if (scorer .iterator ().docID () > doc ) {
303- builder . appendNoMatch ();
301+ appendNoMatch (builder );
304302 } else {
305303 if (scorer .iterator ().advance (doc ) == doc ) {
306- builder . appendMatch (scorer );
304+ appendMatch (builder , scorer );
307305 } else {
308- builder . appendNoMatch ();
306+ appendNoMatch (builder );
309307 }
310308 }
311309 }
312310 }
313311
314- private static final ShardState [] EMPTY_SHARD_STATES = new ShardState [0 ];
315- private static final SegmentState [] EMPTY_SEGMENT_STATES = new SegmentState [0 ];
316-
317312 /**
318313 * Collects matching information for dense range of doc ids. This assumes that
319314 * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
320315 * which isn't documented, but @jpountz swears is true.
321316 */
322- static class DenseCollector implements LeafCollector , Releasable {
323- private final ScoreVectorBuilder scoreBuilder ;
317+ static class DenseCollector < U extends Vector . Builder > implements LeafCollector , Releasable {
318+ private final U scoreBuilder ;
324319 private final int max ;
325- private Scorable scorer ;
320+ private final Consumer <U > appendNoMatch ;
321+ private final CheckedBiConsumer <U , Scorable , IOException > appendMatch ;
326322
323+ private Scorable scorer ;
327324 int next ;
328325
329- DenseCollector (int min , int max , ScoreVectorBuilder scoreBuilder ) {
326+ DenseCollector (
327+ int min ,
328+ int max ,
329+ U scoreBuilder ,
330+ Consumer <U > appendNoMatch ,
331+ CheckedBiConsumer <U , Scorable , IOException > appendMatch
332+ ) {
330333 this .scoreBuilder = scoreBuilder ;
331- scoreBuilder .initVector ();
332334 this .max = max ;
333335 next = min ;
336+ this .appendNoMatch = appendNoMatch ;
337+ this .appendMatch = appendMatch ;
334338 }
335339
336340 @ Override
@@ -341,9 +345,9 @@ public void setScorer(Scorable scorable) {
341345 @ Override
342346 public void collect (int doc ) throws IOException {
343347 while (next ++ < doc ) {
344- scoreBuilder . appendNoMatch ( );
348+ appendNoMatch . accept ( scoreBuilder );
345349 }
346- scoreBuilder . appendMatch ( scorer );
350+ appendMatch . accept ( scoreBuilder , scorer );
347351 }
348352
349353 public Vector build () {
@@ -353,7 +357,7 @@ public Vector build() {
353357 @ Override
354358 public void finish () {
355359 while (next ++ <= max ) {
356- scoreBuilder . appendNoMatch ( );
360+ appendNoMatch . accept ( scoreBuilder );
357361 }
358362 }
359363
@@ -363,15 +367,13 @@ public void close() {
363367 }
364368 }
365369
366- public interface ScoreVectorBuilder extends Releasable {
367- Vector createNoMatchVector ();
370+ protected abstract ScoreMode scoreMode ();
368371
369- void initVector ( );
372+ protected abstract Vector createNoMatchVector ( BlockFactory blockFactory , int size );
370373
371- void appendNoMatch ( );
374+ protected abstract T createBuilder ( BlockFactory blockFactory , int size );
372375
373- void appendMatch ( Scorable scorer ) throws IOException ;
376+ protected abstract void appendNoMatch ( T builder ) ;
374377
375- Vector build ();
376- }
378+ protected abstract void appendMatch (T builder , Scorable scorer ) throws IOException ;
377379}
0 commit comments