Skip to content

Commit 3115f85

Browse files
authored
LUCENE-9908: Move VectorValues#search to LeafReader (#104)
This PR removes `VectorValues#search` in favor of exposing NN search through `VectorReader#search` and `LeafReader#searchNearestVectors`. It also marks the vector methods on `LeafReader` as experimental.
1 parent 6b386e7 commit 3115f85

File tree

24 files changed

+205
-119
lines changed

24 files changed

+205
-119
lines changed

lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorReader.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ public VectorValues getVectorValues(String field) throws IOException {
141141
return new SimpleTextVectorValues(fieldEntry, bytesSlice);
142142
}
143143

144+
@Override
145+
public TopDocs search(String field, float[] target, int k, int fanout) throws IOException {
146+
throw new UnsupportedOperationException();
147+
}
148+
144149
@Override
145150
public void checkIntegrity() throws IOException {
146151
IndexInput clone = dataIn.clone();
@@ -334,11 +339,6 @@ public float[] vectorValue(int targetOrd) throws IOException {
334339
public BytesRef binaryValue(int targetOrd) throws IOException {
335340
throw new UnsupportedOperationException();
336341
}
337-
338-
@Override
339-
public TopDocs search(float[] target, int k, int fanout) throws IOException {
340-
throw new UnsupportedOperationException();
341-
}
342342
}
343343

344344
private int readInt(IndexInput in, BytesRef field) throws IOException {

lucene/core/src/java/org/apache/lucene/codecs/VectorFormat.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.apache.lucene.index.SegmentReadState;
2222
import org.apache.lucene.index.SegmentWriteState;
2323
import org.apache.lucene.index.VectorValues;
24+
import org.apache.lucene.search.TopDocs;
25+
import org.apache.lucene.search.TopDocsCollector;
2426

2527
/**
2628
* Encodes/decodes per-document vector and any associated indexing structures required to support
@@ -61,7 +63,12 @@ public VectorValues getVectorValues(String field) {
6163
}
6264

6365
@Override
64-
public void close() throws IOException {}
66+
public TopDocs search(String field, float[] target, int k, int fanout) {
67+
return TopDocsCollector.EMPTY_TOPDOCS;
68+
}
69+
70+
@Override
71+
public void close() {}
6572

6673
@Override
6774
public long ramBytesUsed() {

lucene/core/src/java/org/apache/lucene/codecs/VectorReader.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.io.Closeable;
2121
import java.io.IOException;
2222
import org.apache.lucene.index.VectorValues;
23+
import org.apache.lucene.search.TopDocs;
2324
import org.apache.lucene.util.Accountable;
2425

2526
/** Reads vectors from an index. */
@@ -41,6 +42,22 @@ protected VectorReader() {}
4142
/** Returns the {@link VectorValues} for the given {@code field} */
4243
public abstract VectorValues getVectorValues(String field) throws IOException;
4344

45+
/**
46+
* Return the k nearest neighbor documents as determined by comparison of their vector values for
47+
* this field, to the given vector, by the field's search strategy. If the search strategy is
48+
* reversed, lower scores indicate nearer vectors; otherwise higher scores indicate nearer
49+
* vectors. Unlike relevance scores, vector scores may be negative.
50+
*
51+
* @param field the vector field to search
52+
* @param target the vector-valued query
53+
* @param k the number of docs to return
54+
* @param fanout control the accuracy/speed tradeoff - larger values give better recall at higher
55+
* cost
56+
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
57+
*/
58+
public abstract TopDocs search(String field, float[] target, int k, int fanout)
59+
throws IOException;
60+
4461
/**
4562
* Returns an instance optimized for merging. This instance may only be consumed in the thread
4663
* that called {@link #getMergeInstance()}.

lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import org.apache.lucene.index.RandomAccessVectorValues;
3131
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
3232
import org.apache.lucene.index.VectorValues;
33-
import org.apache.lucene.search.TopDocs;
3433
import org.apache.lucene.util.BytesRef;
3534

3635
/** Writes vectors to an index. */
@@ -246,11 +245,6 @@ public SearchStrategy searchStrategy() {
246245
return subs.get(0).values.searchStrategy();
247246
}
248247

249-
@Override
250-
public TopDocs search(float[] target, int k, int fanout) throws IOException {
251-
throw new UnsupportedOperationException();
252-
}
253-
254248
class MergerRandomAccess implements RandomAccessVectorValues {
255249

256250
private final List<RandomAccessVectorValues> raSubs;

lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90VectorReader.java

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,36 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
154154
if (info == null) {
155155
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
156156
}
157-
fields.put(info.name, readField(meta));
157+
158+
FieldEntry fieldEntry = readField(meta);
159+
validateFieldEntry(info, fieldEntry);
160+
fields.put(info.name, fieldEntry);
161+
}
162+
}
163+
164+
private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
165+
int dimension = info.getVectorDimension();
166+
if (dimension != fieldEntry.dimension) {
167+
throw new IllegalStateException(
168+
"Inconsistent vector dimension for field=\""
169+
+ info.name
170+
+ "\"; "
171+
+ dimension
172+
+ " != "
173+
+ fieldEntry.dimension);
174+
}
175+
176+
long numBytes = (long) fieldEntry.size() * dimension * Float.BYTES;
177+
if (numBytes != fieldEntry.vectorDataLength) {
178+
throw new IllegalStateException(
179+
"Vector data length "
180+
+ fieldEntry.vectorDataLength
181+
+ " not matching size="
182+
+ fieldEntry.size()
183+
+ " * dim="
184+
+ dimension
185+
+ " * 4 = "
186+
+ numBytes);
158187
}
159188
}
160189

@@ -199,40 +228,47 @@ public void checkIntegrity() throws IOException {
199228

200229
@Override
201230
public VectorValues getVectorValues(String field) throws IOException {
202-
FieldInfo info = fieldInfos.fieldInfo(field);
203-
if (info == null) {
231+
FieldEntry fieldEntry = fields.get(field);
232+
if (fieldEntry == null || fieldEntry.dimension == 0) {
204233
return null;
205234
}
206-
int dimension = info.getVectorDimension();
207-
if (dimension == 0) {
208-
return VectorValues.EMPTY;
209-
}
235+
236+
return getOffHeapVectorValues(fieldEntry);
237+
}
238+
239+
@Override
240+
public TopDocs search(String field, float[] target, int k, int fanout) throws IOException {
210241
FieldEntry fieldEntry = fields.get(field);
211-
if (fieldEntry == null) {
212-
// There is a FieldInfo, but no vectors. Should we have deleted the FieldInfo?
242+
if (fieldEntry == null || fieldEntry.dimension == 0) {
213243
return null;
214244
}
215-
if (dimension != fieldEntry.dimension) {
216-
throw new IllegalStateException(
217-
"Inconsistent vector dimension for field=\""
218-
+ field
219-
+ "\"; "
220-
+ dimension
221-
+ " != "
222-
+ fieldEntry.dimension);
223-
}
224-
long numBytes = (long) fieldEntry.size() * dimension * Float.BYTES;
225-
if (numBytes != fieldEntry.vectorDataLength) {
226-
throw new IllegalStateException(
227-
"Vector data length "
228-
+ fieldEntry.vectorDataLength
229-
+ " not matching size="
230-
+ fieldEntry.size()
231-
+ " * dim="
232-
+ dimension
233-
+ " * 4 = "
234-
+ numBytes);
245+
246+
OffHeapVectorValues vectorValues = getOffHeapVectorValues(fieldEntry);
247+
248+
// use a seed that is fixed for the index so we get reproducible results for the same query
249+
final Random random = new Random(checksumSeed);
250+
NeighborQueue results =
251+
HnswGraph.search(target, k, k + fanout, vectorValues, getGraphValues(fieldEntry), random);
252+
int i = 0;
253+
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)];
254+
boolean reversed = fieldEntry.searchStrategy.reversed;
255+
while (results.size() > 0) {
256+
int node = results.topNode();
257+
float score = results.topScore();
258+
results.pop();
259+
if (reversed) {
260+
score = (float) Math.exp(-score / target.length);
261+
}
262+
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
235263
}
264+
// always report the hit count as a lower bound (>=): the only case where we could assert ==
265+
// is when there are fewer than k vectors in the index
266+
return new TopDocs(
267+
new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
268+
scoreDocs);
269+
}
270+
271+
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
236272
IndexInput bytesSlice =
237273
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
238274
return new OffHeapVectorValues(fieldEntry, bytesSlice);
@@ -408,32 +444,6 @@ public RandomAccessVectorValues randomAccess() {
408444
return new OffHeapVectorValues(fieldEntry, dataIn.clone());
409445
}
410446

411-
@Override
412-
public TopDocs search(float[] vector, int topK, int fanout) throws IOException {
413-
// use a seed that is fixed for the index so we get reproducible results for the same query
414-
final Random random = new Random(checksumSeed);
415-
NeighborQueue results =
416-
HnswGraph.search(
417-
vector, topK, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random);
418-
int i = 0;
419-
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)];
420-
boolean reversed = searchStrategy().reversed;
421-
while (results.size() > 0) {
422-
int node = results.topNode();
423-
float score = results.topScore();
424-
results.pop();
425-
if (reversed) {
426-
score = (float) Math.exp(-score / vector.length);
427-
}
428-
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
429-
}
430-
// always return >= the case where we can assert == is only when there are fewer than topK
431-
// vectors in the index
432-
return new TopDocs(
433-
new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
434-
scoreDocs);
435-
}
436-
437447
@Override
438448
public float[] vectorValue(int targetOrd) throws IOException {
439449
dataIn.seek((long) targetOrd * byteSize);

lucene/core/src/java/org/apache/lucene/index/CodecReader.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.codecs.StoredFieldsReader;
2626
import org.apache.lucene.codecs.TermVectorsReader;
2727
import org.apache.lucene.codecs.VectorReader;
28+
import org.apache.lucene.search.TopDocs;
2829

2930
/** LeafReader implemented by codec APIs. */
3031
public abstract class CodecReader extends LeafReader {
@@ -218,6 +219,19 @@ public final VectorValues getVectorValues(String field) throws IOException {
218219
return getVectorReader().getVectorValues(field);
219220
}
220221

222+
@Override
223+
public final TopDocs searchNearestVectors(String field, float[] target, int k, int fanout)
224+
throws IOException {
225+
ensureOpen();
226+
FieldInfo fi = getFieldInfos().fieldInfo(field);
227+
if (fi == null || fi.getVectorDimension() == 0) {
228+
// Field does not exist or does not index vectors
229+
return null;
230+
}
231+
232+
return getVectorReader().search(field, target, k, fanout);
233+
}
234+
221235
@Override
222236
protected void doClose() throws IOException {}
223237

lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.lucene.index;
1919

2020
import java.io.IOException;
21+
import org.apache.lucene.search.TopDocs;
2122
import org.apache.lucene.util.Bits;
2223

2324
abstract class DocValuesLeafReader extends LeafReader {
@@ -51,6 +52,12 @@ public final VectorValues getVectorValues(String field) throws IOException {
5152
throw new UnsupportedOperationException();
5253
}
5354

55+
@Override
56+
public TopDocs searchNearestVectors(String field, float[] target, int k, int fanout)
57+
throws IOException {
58+
throw new UnsupportedOperationException();
59+
}
60+
5461
@Override
5562
public final void checkIntegrity() throws IOException {
5663
throw new UnsupportedOperationException();

lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.Iterator;
21+
import org.apache.lucene.search.TopDocs;
2122
import org.apache.lucene.util.AttributeSource;
2223
import org.apache.lucene.util.Bits;
2324
import org.apache.lucene.util.BytesRef;
@@ -343,6 +344,12 @@ public VectorValues getVectorValues(String field) throws IOException {
343344
return in.getVectorValues(field);
344345
}
345346

347+
@Override
348+
public TopDocs searchNearestVectors(String field, float[] target, int k, int fanout)
349+
throws IOException {
350+
return in.searchNearestVectors(field, target, k, fanout);
351+
}
352+
346353
@Override
347354
public Fields getTermVectors(int docID) throws IOException {
348355
ensureOpen();

lucene/core/src/java/org/apache/lucene/index/LeafReader.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.lucene.index;
1818

1919
import java.io.IOException;
20+
import org.apache.lucene.search.TopDocs;
2021
import org.apache.lucene.util.Bits;
2122

2223
/**
@@ -207,9 +208,28 @@ public final PostingsEnum postings(Term term) throws IOException {
207208
/**
208209
* Returns {@link VectorValues} for this field, or null if no {@link VectorValues} were indexed.
209210
* The returned instance should only be used by a single thread.
211+
*
212+
* @lucene.experimental
210213
*/
211214
public abstract VectorValues getVectorValues(String field) throws IOException;
212215

216+
/**
217+
* Return the k nearest neighbor documents as determined by comparison of their vector values for
218+
* this field, to the given vector, by the field's search strategy. If the search strategy is
219+
* reversed, lower scores indicate nearer vectors; otherwise higher scores indicate nearer
220+
* vectors. Unlike relevance scores, vector scores may be negative.
221+
*
222+
* @param field the vector field to search
223+
* @param target the vector-valued query
224+
* @param k the number of docs to return
225+
* @param fanout control the accuracy/speed tradeoff - larger values give better recall at higher
226+
* cost
227+
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
228+
* @lucene.experimental
229+
*/
230+
public abstract TopDocs searchNearestVectors(String field, float[] target, int k, int fanout)
231+
throws IOException;
232+
213233
/**
214234
* Get the {@link FieldInfos} describing all fields in this reader.
215235
*

lucene/core/src/java/org/apache/lucene/index/MergeReaderWrapper.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.codecs.NormsProducer;
2525
import org.apache.lucene.codecs.StoredFieldsReader;
2626
import org.apache.lucene.codecs.TermVectorsReader;
27+
import org.apache.lucene.search.TopDocs;
2728
import org.apache.lucene.util.Bits;
2829

2930
/**
@@ -202,6 +203,12 @@ public VectorValues getVectorValues(String fieldName) throws IOException {
202203
return in.getVectorValues(fieldName);
203204
}
204205

206+
@Override
207+
public TopDocs searchNearestVectors(String field, float[] target, int k, int fanout)
208+
throws IOException {
209+
return in.searchNearestVectors(field, target, k, fanout);
210+
}
211+
205212
@Override
206213
public int numDocs() {
207214
return in.numDocs();

0 commit comments

Comments
 (0)