1616import org .apache .lucene .search .KnnCollector ;
1717import org .apache .lucene .store .IndexInput ;
1818import org .apache .lucene .util .ArrayUtil ;
19+ import org .apache .lucene .util .Bits ;
1920import org .apache .lucene .util .VectorUtil ;
2021import org .apache .lucene .util .hnsw .NeighborQueue ;
2122import org .elasticsearch .index .codec .vectors .reflect .OffHeapStats ;
2526
2627import java .io .IOException ;
2728import java .util .Map ;
28- import java .util .function .IntPredicate ;
2929
3030import static org .apache .lucene .codecs .lucene102 .Lucene102BinaryQuantizedVectorsFormat .QUERY_BITS ;
3131import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
@@ -294,11 +294,10 @@ private static void score(
294294 }
295295
296296 @ Override
297- PostingVisitor getPostingVisitor (FieldInfo fieldInfo , IndexInput indexInput , float [] target , IntPredicate needsScoring )
298- throws IOException {
297+ PostingVisitor getPostingVisitor (FieldInfo fieldInfo , IndexInput indexInput , float [] target , Bits acceptDocs ) throws IOException {
299298 FieldEntry entry = fields .get (fieldInfo .number );
300299 final int maxPostingListSize = indexInput .readVInt ();
301- return new MemorySegmentPostingsVisitor (target , indexInput , entry , fieldInfo , maxPostingListSize , needsScoring );
300+ return new MemorySegmentPostingsVisitor (target , indexInput , entry , fieldInfo , maxPostingListSize , acceptDocs );
302301 }
303302
304303 @ Override
@@ -312,7 +311,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
312311 final float [] target ;
313312 final FieldEntry entry ;
314313 final FieldInfo fieldInfo ;
315- final IntPredicate needsScoring ;
314+ final Bits acceptDocs ;
316315 private final ES91OSQVectorsScorer osqVectorsScorer ;
317316 final float [] scores = new float [BULK_SIZE ];
318317 final float [] correctionsLower = new float [BULK_SIZE ];
@@ -342,13 +341,13 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
342341 FieldEntry entry ,
343342 FieldInfo fieldInfo ,
344343 int maxPostingListSize ,
345- IntPredicate needsScoring
344+ Bits acceptDocs
346345 ) throws IOException {
347346 this .target = target ;
348347 this .indexInput = indexInput ;
349348 this .entry = entry ;
350349 this .fieldInfo = fieldInfo ;
351- this .needsScoring = needsScoring ;
350+ this .acceptDocs = acceptDocs ;
352351 centroid = new float [fieldInfo .getVectorDimension ()];
353352 scratch = new float [target .length ];
354353 quantizationScratch = new int [target .length ];
@@ -419,11 +418,12 @@ private float scoreIndividually(int offset) throws IOException {
419418 return maxScore ;
420419 }
421420
422- private static int docToBulkScore (int [] docIds , int offset , IntPredicate needsScoring ) {
421+ private static int docToBulkScore (int [] docIds , int offset , Bits acceptDocs ) {
422+ assert acceptDocs != null : "acceptDocs must not be null" ;
423423 int docToScore = ES91OSQVectorsScorer .BULK_SIZE ;
424424 for (int i = 0 ; i < ES91OSQVectorsScorer .BULK_SIZE ; i ++) {
425425 final int idx = offset + i ;
426- if (needsScoring . test (docIds [idx ]) == false ) {
426+ if (acceptDocs . get (docIds [idx ]) == false ) {
427427 docIds [idx ] = -1 ;
428428 docToScore --;
429429 }
@@ -447,7 +447,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
447447 int limit = vectors - BULK_SIZE + 1 ;
448448 int i = 0 ;
449449 for (; i < limit ; i += BULK_SIZE ) {
450- final int docsToBulkScore = docToBulkScore (docIdsScratch , i , needsScoring );
450+ final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore (docIdsScratch , i , acceptDocs );
451451 if (docsToBulkScore == 0 ) {
452452 continue ;
453453 }
@@ -476,7 +476,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
476476 // process tail
477477 for (; i < vectors ; i ++) {
478478 int doc = docIdsScratch [i ];
479- if (needsScoring . test (doc )) {
479+ if (acceptDocs == null || acceptDocs . get (doc )) {
480480 quantizeQueryIfNecessary ();
481481 indexInput .seek (slicePos + i * quantizedByteLength );
482482 float qcDist = osqVectorsScorer .quantizeScore (quantizedQueryScratch );
0 commit comments