Skip to content

Commit a253a24

Browse files
shubhamvishujpountzChrisHegarty
committed
Add AcceptDocs abstraction for accepted KNN docs (#15011)
Co-authored-by: Adrien Grand <[email protected]> Co-authored-by: Chris Hegarty <[email protected]>
1 parent a71a3e2 commit a253a24

File tree

53 files changed

+733
-270
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+733
-270
lines changed

lucene/CHANGES.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ API Changes
2323
* GITHUB#14978: Add a bulk scoring interface to RandomVectorScorer
2424
(Trevor McCulloch, Chris Hegarty)
2525

26+
* GITHUB#15011: LeafReader#searchNearestVectors now accepts an AcceptDocs
27+
instance instead of a Bits instance to identify document IDs to filter.
28+
(Shubham Chaudhary, Adrien Grand)
29+
2630
New Features
2731
---------------------
2832
* GITHUB#15015: MultiIndexMergeScheduler: a production multi-tenant merge scheduler (Shawn Yarbrough)

lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.apache.lucene.index.SegmentReadState;
3636
import org.apache.lucene.index.VectorSimilarityFunction;
3737
import org.apache.lucene.internal.hppc.IntObjectHashMap;
38+
import org.apache.lucene.search.AcceptDocs;
3839
import org.apache.lucene.search.KnnCollector;
3940
import org.apache.lucene.search.VectorScorer;
4041
import org.apache.lucene.store.ChecksumIndexInput;
@@ -242,7 +243,7 @@ public ByteVectorValues getByteVectorValues(String field) {
242243
}
243244

244245
@Override
245-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
246+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
246247
throws IOException {
247248
final FieldEntry fieldEntry = getFieldEntry(field);
248249
if (fieldEntry.size() == 0) {
@@ -260,7 +261,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
260261
vectorValues,
261262
fieldEntry.similarityFunction,
262263
getGraphValues(fieldEntry),
263-
getAcceptOrds(acceptDocs, fieldEntry),
264+
getAcceptOrds(acceptDocs.bits(), fieldEntry),
264265
knnCollector.visitLimit(),
265266
random);
266267
knnCollector.incVisitedCount(results.visitedCount());
@@ -273,7 +274,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
273274
}
274275

275276
@Override
276-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
277+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
277278
throws IOException {
278279
throw new UnsupportedOperationException();
279280
}

lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.lucene.index.SegmentReadState;
3838
import org.apache.lucene.index.VectorSimilarityFunction;
3939
import org.apache.lucene.internal.hppc.IntObjectHashMap;
40+
import org.apache.lucene.search.AcceptDocs;
4041
import org.apache.lucene.search.DocIdSetIterator;
4142
import org.apache.lucene.search.KnnCollector;
4243
import org.apache.lucene.search.VectorScorer;
@@ -238,7 +239,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
238239
}
239240

240241
@Override
241-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
242+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
242243
throws IOException {
243244
final FieldEntry fieldEntry = getFieldEntry(field);
244245
if (fieldEntry.size() == 0) {
@@ -253,11 +254,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
253254
scorer,
254255
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
255256
getGraph(fieldEntry),
256-
getAcceptOrds(acceptDocs, fieldEntry));
257+
getAcceptOrds(acceptDocs.bits(), fieldEntry));
257258
}
258259

259260
@Override
260-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
261+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
261262
throws IOException {
262263
throw new UnsupportedOperationException();
263264
}

lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@
3636
import org.apache.lucene.index.SegmentReadState;
3737
import org.apache.lucene.index.VectorSimilarityFunction;
3838
import org.apache.lucene.internal.hppc.IntObjectHashMap;
39+
import org.apache.lucene.search.AcceptDocs;
3940
import org.apache.lucene.search.KnnCollector;
4041
import org.apache.lucene.store.ChecksumIndexInput;
4142
import org.apache.lucene.store.DataInput;
4243
import org.apache.lucene.store.IndexInput;
43-
import org.apache.lucene.util.Bits;
4444
import org.apache.lucene.util.IOUtils;
4545
import org.apache.lucene.util.hnsw.HnswGraph;
4646
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
@@ -236,7 +236,7 @@ public ByteVectorValues getByteVectorValues(String field) {
236236
}
237237

238238
@Override
239-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
239+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
240240
throws IOException {
241241
final FieldEntry fieldEntry = getFieldEntry(field);
242242
if (fieldEntry.size() == 0) {
@@ -251,11 +251,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
251251
scorer,
252252
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
253253
getGraph(fieldEntry),
254-
vectorValues.getAcceptOrds(acceptDocs));
254+
vectorValues.getAcceptOrds(acceptDocs.bits()));
255255
}
256256

257257
@Override
258-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
258+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
259259
throws IOException {
260260
throw new UnsupportedOperationException();
261261
}

lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@
3737
import org.apache.lucene.index.VectorEncoding;
3838
import org.apache.lucene.index.VectorSimilarityFunction;
3939
import org.apache.lucene.internal.hppc.IntObjectHashMap;
40+
import org.apache.lucene.search.AcceptDocs;
4041
import org.apache.lucene.search.KnnCollector;
4142
import org.apache.lucene.store.ChecksumIndexInput;
4243
import org.apache.lucene.store.DataInput;
4344
import org.apache.lucene.store.IndexInput;
44-
import org.apache.lucene.util.Bits;
4545
import org.apache.lucene.util.IOUtils;
4646
import org.apache.lucene.util.hnsw.HnswGraph;
4747
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
@@ -270,7 +270,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
270270
}
271271

272272
@Override
273-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
273+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
274274
throws IOException {
275275
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
276276
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
@@ -285,11 +285,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
285285
scorer,
286286
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
287287
getGraph(fieldEntry),
288-
vectorValues.getAcceptOrds(acceptDocs));
288+
vectorValues.getAcceptOrds(acceptDocs.bits()));
289289
}
290290

291291
@Override
292-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
292+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
293293
throws IOException {
294294
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
295295
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
@@ -304,7 +304,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
304304
scorer,
305305
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
306306
getGraph(fieldEntry),
307-
vectorValues.getAcceptOrds(acceptDocs));
307+
vectorValues.getAcceptOrds(acceptDocs.bits()));
308308
}
309309

310310
private HnswGraph getGraph(FieldEntry entry) throws IOException {

lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@
4141
import org.apache.lucene.index.VectorEncoding;
4242
import org.apache.lucene.index.VectorSimilarityFunction;
4343
import org.apache.lucene.internal.hppc.IntObjectHashMap;
44+
import org.apache.lucene.search.AcceptDocs;
4445
import org.apache.lucene.search.KnnCollector;
4546
import org.apache.lucene.store.ChecksumIndexInput;
4647
import org.apache.lucene.store.DataInput;
4748
import org.apache.lucene.store.IndexInput;
4849
import org.apache.lucene.store.RandomAccessInput;
4950
import org.apache.lucene.util.ArrayUtil;
50-
import org.apache.lucene.util.Bits;
5151
import org.apache.lucene.util.IOUtils;
5252
import org.apache.lucene.util.hnsw.HnswGraph;
5353
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
@@ -296,7 +296,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
296296
}
297297

298298
@Override
299-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
299+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
300300
throws IOException {
301301
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
302302
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
@@ -320,11 +320,11 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
320320
scorer,
321321
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
322322
getGraph(fieldEntry),
323-
vectorValues.getAcceptOrds(acceptDocs));
323+
vectorValues.getAcceptOrds(acceptDocs.bits()));
324324
}
325325

326326
@Override
327-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
327+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
328328
throws IOException {
329329
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
330330
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
@@ -348,7 +348,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
348348
scorer,
349349
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
350350
getGraph(fieldEntry),
351-
vectorValues.getAcceptOrds(acceptDocs));
351+
vectorValues.getAcceptOrds(acceptDocs.bits()));
352352
}
353353

354354
/** Get knn graph values; used for testing */

lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@
3737
import org.apache.lucene.index.SegmentReadState;
3838
import org.apache.lucene.index.VectorSimilarityFunction;
3939
import org.apache.lucene.internal.hppc.IntObjectHashMap;
40+
import org.apache.lucene.search.AcceptDocs;
4041
import org.apache.lucene.search.DocIdSetIterator;
4142
import org.apache.lucene.search.KnnCollector;
4243
import org.apache.lucene.search.VectorScorer;
4344
import org.apache.lucene.store.BufferedChecksumIndexInput;
4445
import org.apache.lucene.store.ChecksumIndexInput;
4546
import org.apache.lucene.store.IOContext;
4647
import org.apache.lucene.store.IndexInput;
47-
import org.apache.lucene.util.Bits;
4848
import org.apache.lucene.util.BytesRef;
4949
import org.apache.lucene.util.BytesRefBuilder;
5050
import org.apache.lucene.util.IOUtils;
@@ -181,7 +181,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
181181
}
182182

183183
@Override
184-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
184+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
185185
throws IOException {
186186
FloatVectorValues values = getFloatVectorValues(field);
187187
if (target.length != values.dimension()) {
@@ -195,7 +195,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
195195
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
196196
for (int ord = 0; ord < values.size(); ord++) {
197197
int doc = values.ordToDoc(ord);
198-
if (acceptDocs != null && acceptDocs.get(doc) == false) {
198+
if (acceptDocs.bits() != null && acceptDocs.bits().get(doc) == false) {
199199
continue;
200200
}
201201

@@ -211,7 +211,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
211211
}
212212

213213
@Override
214-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
214+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
215215
throws IOException {
216216
ByteVectorValues values = getByteVectorValues(field);
217217
if (target.length != values.dimension()) {
@@ -226,7 +226,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
226226

227227
for (int ord = 0; ord < values.size(); ord++) {
228228
int doc = values.ordToDoc(ord);
229-
if (acceptDocs != null && acceptDocs.get(doc) == false) {
229+
if (acceptDocs.bits() != null && acceptDocs.bits().get(doc) == false) {
230230
continue;
231231
}
232232

lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.apache.lucene.index.LeafReader;
3434
import org.apache.lucene.index.StoredFields;
3535
import org.apache.lucene.index.VectorSimilarityFunction;
36+
import org.apache.lucene.search.AcceptDocs;
3637
import org.apache.lucene.search.TopDocs;
3738
import org.apache.lucene.search.TopKnnCollector;
3839
import org.apache.lucene.store.Directory;
@@ -84,7 +85,8 @@ public void testIndexAndSearchBitVectors() throws IOException {
8485
try (IndexReader reader = DirectoryReader.open(w)) {
8586
LeafReader r = getOnlyLeafReader(reader);
8687
TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE);
87-
r.searchNearestVectors("v1", vectors[0], collector, null);
88+
r.searchNearestVectors(
89+
"v1", vectors[0], collector, AcceptDocs.fromLiveDocs(null, r.maxDoc()));
8890
TopDocs topDocs = collector.topDocs();
8991
assertEquals(3, topDocs.scoreDocs.length);
9092

lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import org.apache.lucene.index.FloatVectorValues;
2626
import org.apache.lucene.index.SegmentReadState;
2727
import org.apache.lucene.index.SegmentWriteState;
28+
import org.apache.lucene.search.AcceptDocs;
2829
import org.apache.lucene.search.KnnCollector;
29-
import org.apache.lucene.util.Bits;
3030
import org.apache.lucene.util.NamedSPILoader;
3131

3232
/**
@@ -140,13 +140,13 @@ public ByteVectorValues getByteVectorValues(String field) {
140140

141141
@Override
142142
public void search(
143-
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {
143+
String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
144144
throw new UnsupportedOperationException();
145145
}
146146

147147
@Override
148148
public void search(
149-
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {
149+
String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
150150
throw new UnsupportedOperationException();
151151
}
152152

lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.index.ByteVectorValues;
2626
import org.apache.lucene.index.FieldInfo;
2727
import org.apache.lucene.index.FloatVectorValues;
28+
import org.apache.lucene.search.AcceptDocs;
2829
import org.apache.lucene.search.KnnCollector;
2930
import org.apache.lucene.search.ScoreDoc;
3031
import org.apache.lucene.search.TopDocs;
@@ -88,7 +89,8 @@ protected KnnVectorsReader() {}
8889
* if they are all allowed to match.
8990
*/
9091
public abstract void search(
91-
String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException;
92+
String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
93+
throws IOException;
9294

9395
/**
9496
* Return the k nearest neighbor documents as determined by comparison of their vector values for
@@ -116,7 +118,8 @@ public abstract void search(
116118
* if they are all allowed to match.
117119
*/
118120
public abstract void search(
119-
String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException;
121+
String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs)
122+
throws IOException;
120123

121124
/**
122125
* Returns an instance optimized for merging. This instance may only be consumed in the thread

0 commit comments

Comments
 (0)