16
16
import org .apache .lucene .search .KnnCollector ;
17
17
import org .apache .lucene .store .IndexInput ;
18
18
import org .apache .lucene .util .ArrayUtil ;
19
+ import org .apache .lucene .util .Bits ;
19
20
import org .apache .lucene .util .VectorUtil ;
20
21
import org .apache .lucene .util .hnsw .NeighborQueue ;
21
22
import org .elasticsearch .index .codec .vectors .reflect .OffHeapStats ;
25
26
26
27
import java .io .IOException ;
27
28
import java .util .Map ;
28
- import java .util .function .IntPredicate ;
29
29
30
30
import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .QUERY_BITS ;
31
31
import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
@@ -294,11 +294,10 @@ private static void score(
294
294
}
295
295
296
296
@ Override
297
- PostingVisitor getPostingVisitor (FieldInfo fieldInfo , IndexInput indexInput , float [] target , IntPredicate needsScoring )
298
- throws IOException {
297
+ PostingVisitor getPostingVisitor (FieldInfo fieldInfo , IndexInput indexInput , float [] target , Bits acceptDocs ) throws IOException {
299
298
FieldEntry entry = fields .get (fieldInfo .number );
300
299
final int maxPostingListSize = indexInput .readVInt ();
301
- return new MemorySegmentPostingsVisitor (target , indexInput , entry , fieldInfo , maxPostingListSize , needsScoring );
300
+ return new MemorySegmentPostingsVisitor (target , indexInput , entry , fieldInfo , maxPostingListSize , acceptDocs );
302
301
}
303
302
304
303
@ Override
@@ -312,7 +311,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
312
311
final float [] target ;
313
312
final FieldEntry entry ;
314
313
final FieldInfo fieldInfo ;
315
- final IntPredicate needsScoring ;
314
+ final Bits acceptDocs ;
316
315
private final ES91OSQVectorsScorer osqVectorsScorer ;
317
316
final float [] scores = new float [BULK_SIZE ];
318
317
final float [] correctionsLower = new float [BULK_SIZE ];
@@ -342,13 +341,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
342
341
FieldEntry entry ,
343
342
FieldInfo fieldInfo ,
344
343
int maxPostingListSize ,
345
- IntPredicate needsScoring
344
+ Bits acceptDocs
346
345
) throws IOException {
347
346
this .target = target ;
348
347
this .indexInput = indexInput ;
349
348
this .entry = entry ;
350
349
this .fieldInfo = fieldInfo ;
351
- this .needsScoring = needsScoring ;
350
+ this .acceptDocs = acceptDocs ;
352
351
centroid = new float [fieldInfo .getVectorDimension ()];
353
352
scratch = new float [target .length ];
354
353
quantizationScratch = new int [target .length ];
@@ -419,11 +418,12 @@ private float scoreIndividually(int offset) throws IOException {
419
418
return maxScore ;
420
419
}
421
420
422
- private static int docToBulkScore (int [] docIds , int offset , IntPredicate needsScoring ) {
421
+ private static int docToBulkScore (int [] docIds , int offset , Bits acceptDocs ) {
422
+ assert acceptDocs != null : "acceptDocs must not be null" ;
423
423
int docToScore = ES91OSQVectorsScorer .BULK_SIZE ;
424
424
for (int i = 0 ; i < ES91OSQVectorsScorer .BULK_SIZE ; i ++) {
425
425
final int idx = offset + i ;
426
- if (needsScoring . test (docIds [idx ]) == false ) {
426
+ if (acceptDocs . get (docIds [idx ]) == false ) {
427
427
docIds [idx ] = -1 ;
428
428
docToScore --;
429
429
}
@@ -447,7 +447,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
447
447
int limit = vectors - BULK_SIZE + 1 ;
448
448
int i = 0 ;
449
449
for (; i < limit ; i += BULK_SIZE ) {
450
- final int docsToBulkScore = docToBulkScore (docIdsScratch , i , needsScoring );
450
+ final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore (docIdsScratch , i , acceptDocs );
451
451
if (docsToBulkScore == 0 ) {
452
452
continue ;
453
453
}
@@ -476,7 +476,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
476
476
// process tail
477
477
for (; i < vectors ; i ++) {
478
478
int doc = docIdsScratch [i ];
479
- if (needsScoring . test (doc )) {
479
+ if (acceptDocs == null || acceptDocs . get (doc )) {
480
480
quantizeQueryIfNecessary ();
481
481
indexInput .seek (slicePos + i * quantizedByteLength );
482
482
float qcDist = osqVectorsScorer .quantizeScore (quantizedQueryScratch );
0 commit comments