Skip to content

Commit e40fb2f

Browse files
authored
Adjust knn reader interfaces to use new AcceptDocs api (#133501)
1 parent 247b6f4 commit e40fb2f

File tree

23 files changed

+150
-87
lines changed

23 files changed

+150
-87
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.index.SegmentReadState;
2626
import org.apache.lucene.index.SegmentWriteState;
2727
import org.apache.lucene.index.Sorter;
28+
import org.apache.lucene.search.AcceptDocs;
2829
import org.apache.lucene.search.KnnCollector;
2930
import org.apache.lucene.util.Bits;
3031
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
@@ -128,13 +129,14 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
128129
}
129130

130131
@Override
131-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
132+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
132133
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
133134
}
134135

135-
private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs, RandomVectorScorer scorer) throws IOException {
136+
private void collectAllMatchingDocs(KnnCollector knnCollector, AcceptDocs acceptDocs, RandomVectorScorer scorer)
137+
throws IOException {
136138
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
137-
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
139+
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
138140
for (int i = 0; i < scorer.maxOrd(); i++) {
139141
if (acceptedOrds == null || acceptedOrds.get(i)) {
140142
collector.collect(i, scorer.score(i));
@@ -145,7 +147,7 @@ private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs,
145147
}
146148

147149
@Override
148-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
150+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
149151
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
150152
}
151153

server/src/main/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormat.java

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import org.apache.lucene.index.SegmentReadState;
2424
import org.apache.lucene.index.SegmentWriteState;
2525
import org.apache.lucene.index.Sorter;
26+
import org.apache.lucene.search.AcceptDocs;
2627
import org.apache.lucene.search.KnnCollector;
2728
import org.apache.lucene.util.Bits;
2829
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
@@ -136,13 +137,14 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
136137
}
137138

138139
@Override
139-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
140+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
140141
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
141142
}
142143

143-
private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs, RandomVectorScorer scorer) throws IOException {
144+
private void collectAllMatchingDocs(KnnCollector knnCollector, AcceptDocs acceptDocs, RandomVectorScorer scorer)
145+
throws IOException {
144146
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
145-
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
147+
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
146148
for (int i = 0; i < scorer.maxOrd(); i++) {
147149
if (acceptedOrds == null || acceptedOrds.get(i)) {
148150
collector.collect(i, scorer.score(i));
@@ -153,7 +155,7 @@ private void collectAllMatchingDocs(KnnCollector knnCollector, Bits acceptDocs,
153155
}
154156

155157
@Override
156-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
158+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
157159
collectAllMatchingDocs(knnCollector, acceptDocs, reader.getRandomVectorScorer(field, target));
158160
}
159161

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,12 @@
2222
import org.apache.lucene.index.VectorEncoding;
2323
import org.apache.lucene.index.VectorSimilarityFunction;
2424
import org.apache.lucene.internal.hppc.IntObjectHashMap;
25+
import org.apache.lucene.search.AcceptDocs;
2526
import org.apache.lucene.search.KnnCollector;
2627
import org.apache.lucene.store.ChecksumIndexInput;
2728
import org.apache.lucene.store.DataInput;
2829
import org.apache.lucene.store.IOContext;
2930
import org.apache.lucene.store.IndexInput;
30-
import org.apache.lucene.util.BitSet;
3131
import org.apache.lucene.util.Bits;
3232
import org.elasticsearch.core.IOUtils;
3333
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
@@ -212,7 +212,7 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti
212212
}
213213

214214
@Override
215-
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
215+
public final void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
216216
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
217217
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
218218
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
@@ -223,11 +223,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
223223
"vector query dimension: " + target.length + " differs from field dimension: " + fieldInfo.getVectorDimension()
224224
);
225225
}
226-
float percentFiltered = 1f;
227-
if (acceptDocs instanceof BitSet bitSet) {
228-
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
229-
}
230226
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
227+
float percentFiltered = Math.max(0f, Math.min(1f, (float) acceptDocs.cost() / numVectors));
231228
float visitRatio = DYNAMIC_VISIT_RATIO;
232229
// Search strategy may be null if this is being called from checkIndex (e.g. from a test)
233230
if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfSearchStrategy) {
@@ -255,7 +252,8 @@ public final void search(String field, float[] target, KnnCollector knnCollector
255252
target,
256253
postListSlice
257254
);
258-
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs);
255+
Bits acceptDocsBits = acceptDocs.bits();
256+
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocsBits);
259257
long expectedDocs = 0;
260258
long actualDocs = 0;
261259
// initially we visit only the "centroids to search"
@@ -271,7 +269,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
271269
expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
272270
actualDocs += scorer.visit(knnCollector);
273271
}
274-
if (acceptDocs != null) {
272+
if (acceptDocsBits != null) {
275273
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
276274
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
277275
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
@@ -284,7 +282,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
284282
}
285283

286284
@Override
287-
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
285+
public final void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
288286
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
289287
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
290288
for (int i = 0; i < values.size(); i++) {

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,11 +24,11 @@
2424
import org.apache.lucene.index.VectorEncoding;
2525
import org.apache.lucene.index.VectorSimilarityFunction;
2626
import org.apache.lucene.search.DocIdSetIterator;
27+
import org.apache.lucene.store.DataAccessHint;
2728
import org.apache.lucene.store.IOContext;
2829
import org.apache.lucene.store.IndexInput;
2930
import org.apache.lucene.store.IndexOutput;
3031
import org.apache.lucene.store.RandomAccessInput;
31-
import org.apache.lucene.store.ReadAdvice;
3232
import org.apache.lucene.util.LongValues;
3333
import org.apache.lucene.util.VectorUtil;
3434
import org.elasticsearch.core.IOUtils;
@@ -302,11 +302,11 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
302302
try (
303303
IndexInput vectors = mergeState.segmentInfo.dir.openInput(
304304
tempRawVectorsFileName,
305-
IOContext.DEFAULT.withReadAdvice(ReadAdvice.SEQUENTIAL)
305+
IOContext.DEFAULT.withHints(DataAccessHint.SEQUENTIAL)
306306
);
307307
IndexInput docs = docsFileName == null
308308
? null
309-
: mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT.withReadAdvice(ReadAdvice.SEQUENTIAL))
309+
: mergeState.segmentInfo.dir.openInput(docsFileName, IOContext.DEFAULT.withHints(DataAccessHint.SEQUENTIAL))
310310
) {
311311
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, docs, vectors, numVectors);
312312

server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
import org.apache.lucene.index.SegmentReadState;
3333
import org.apache.lucene.index.VectorEncoding;
3434
import org.apache.lucene.index.VectorSimilarityFunction;
35+
import org.apache.lucene.search.AcceptDocs;
3536
import org.apache.lucene.search.KnnCollector;
3637
import org.apache.lucene.search.VectorScorer;
3738
import org.apache.lucene.store.ChecksumIndexInput;
@@ -226,17 +227,17 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
226227
}
227228

228229
@Override
229-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
230+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
230231
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
231232
}
232233

233234
@Override
234-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
235+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
235236
if (knnCollector.k() == 0) return;
236237
final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
237238
if (scorer == null) return;
238239
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
239-
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
240+
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
240241
for (int i = 0; i < scorer.maxOrd(); i++) {
241242
if (acceptedOrds == null || acceptedOrds.get(i)) {
242243
collector.collect(i, scorer.score(i));

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
import org.apache.lucene.index.SegmentReadState;
3333
import org.apache.lucene.index.VectorEncoding;
3434
import org.apache.lucene.index.VectorSimilarityFunction;
35+
import org.apache.lucene.search.AcceptDocs;
3536
import org.apache.lucene.search.KnnCollector;
3637
import org.apache.lucene.search.VectorScorer;
3738
import org.apache.lucene.store.ChecksumIndexInput;
@@ -240,17 +241,17 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
240241
}
241242

242243
@Override
243-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
244+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
244245
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
245246
}
246247

247248
@Override
248-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
249+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
249250
if (knnCollector.k() == 0) return;
250251
final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
251252
if (scorer == null) return;
252253
OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
253-
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
254+
Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs.bits());
254255
for (int i = 0; i < scorer.maxOrd(); i++) {
255256
if (acceptedOrds == null || acceptedOrds.get(i)) {
256257
collector.collect(i, scorer.score(i));

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,9 @@
1313
import org.apache.lucene.index.ByteVectorValues;
1414
import org.apache.lucene.index.FieldInfo;
1515
import org.apache.lucene.index.FloatVectorValues;
16+
import org.apache.lucene.search.AcceptDocs;
1617
import org.apache.lucene.search.KnnCollector;
1718
import org.apache.lucene.util.Accountable;
18-
import org.apache.lucene.util.Bits;
1919
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2020
import org.elasticsearch.core.IOUtils;
2121

@@ -60,12 +60,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
6060
}
6161

6262
@Override
63-
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
63+
public void search(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
6464
mainReader.search(field, target, knnCollector, acceptDocs);
6565
}
6666

6767
@Override
68-
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
68+
public void search(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) throws IOException {
6969
mainReader.search(field, target, knnCollector, acceptDocs);
7070
}
7171

server/src/main/java/org/elasticsearch/index/engine/TranslogDirectoryReader.java

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
import org.apache.lucene.index.TermsEnum;
4141
import org.apache.lucene.index.VectorEncoding;
4242
import org.apache.lucene.index.VectorSimilarityFunction;
43+
import org.apache.lucene.search.AcceptDocs;
4344
import org.apache.lucene.search.DocIdSetIterator;
4445
import org.apache.lucene.search.KnnCollector;
4546
import org.apache.lucene.store.ByteBuffersDirectory;
@@ -447,12 +448,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
447448
}
448449

449450
@Override
450-
public void searchNearestVectors(String field, float[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
451+
public void searchNearestVectors(String field, float[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
451452
getDelegate().searchNearestVectors(field, target, collector, acceptDocs);
452453
}
453454

454455
@Override
455-
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
456+
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
456457
getDelegate().searchNearestVectors(field, target, collector, acceptDocs);
457458
}
458459

server/src/main/java/org/elasticsearch/index/mapper/DocumentLeafReader.java

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
import org.apache.lucene.index.VectorEncoding;
3535
import org.apache.lucene.index.VectorSimilarityFunction;
3636
import org.apache.lucene.index.memory.MemoryIndex;
37+
import org.apache.lucene.search.AcceptDocs;
3738
import org.apache.lucene.search.DocIdSetIterator;
3839
import org.apache.lucene.search.KnnCollector;
3940
import org.apache.lucene.util.Bits;
@@ -210,7 +211,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
210211
}
211212

212213
@Override
213-
public void searchNearestVectors(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {
214+
public void searchNearestVectors(String field, float[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
214215
throw new UnsupportedOperationException();
215216
}
216217

@@ -255,7 +256,7 @@ public ByteVectorValues getByteVectorValues(String field) {
255256
}
256257

257258
@Override
258-
public void searchNearestVectors(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {
259+
public void searchNearestVectors(String field, byte[] target, KnnCollector knnCollector, AcceptDocs acceptDocs) {
259260
throw new UnsupportedOperationException();
260261
}
261262

server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,11 +23,11 @@
2323
import org.apache.lucene.index.QueryTimeout;
2424
import org.apache.lucene.index.Terms;
2525
import org.apache.lucene.index.TermsEnum;
26+
import org.apache.lucene.search.AcceptDocs;
2627
import org.apache.lucene.search.DocIdSetIterator;
2728
import org.apache.lucene.search.KnnCollector;
2829
import org.apache.lucene.search.VectorScorer;
2930
import org.apache.lucene.search.suggest.document.CompletionTerms;
30-
import org.apache.lucene.util.Bits;
3131
import org.apache.lucene.util.BytesRef;
3232
import org.apache.lucene.util.automaton.CompiledAutomaton;
3333
import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
@@ -141,7 +141,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
141141
}
142142

143143
@Override
144-
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
144+
public void searchNearestVectors(String field, byte[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
145145
if (queryCancellation.isEnabled() == false) {
146146
in.searchNearestVectors(field, target, collector, acceptDocs);
147147
return;
@@ -159,7 +159,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
159159
}
160160

161161
@Override
162-
public void searchNearestVectors(String field, float[] target, KnnCollector collector, Bits acceptDocs) throws IOException {
162+
public void searchNearestVectors(String field, float[] target, KnnCollector collector, AcceptDocs acceptDocs) throws IOException {
163163
if (queryCancellation.isEnabled() == false) {
164164
in.searchNearestVectors(field, target, collector, acceptDocs);
165165
return;

0 commit comments

Comments
 (0)