Skip to content

Commit af3a0bc

Browse files
zacharymorn authored and jtibshirani committed
LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues (#534)
1 parent c5c082e commit af3a0bc

File tree

8 files changed

+133
-57
lines changed

8 files changed

+133
-57
lines changed

lucene/CHANGES.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,9 @@ API Changes
1818
org.apache.lucene.* to org.apache.lucene.tests.* to avoid package name conflicts with the
1919
core module. (Dawid Weiss)
2020

21+
* LUCENE-10183: KnnVectorsWriter#writeField to take KnnVectorsReader instead of VectorValues.
22+
(Zach Chen, Michael Sokolov, Julie Tibshirani, Adrien Grand)
23+
2124
* LUCENE-10335: Deprecate helper methods for resource loading in IOUtils and StopwordAnalyzerBase
2225
that are not compatible with module system (Class#getResourceAsStream() and Class#getResource()
2326
are caller sensitive in Java 11). Instead add utility method IOUtils#requireResourceNonNull(T)

lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import java.util.ArrayList;
2424
import java.util.Arrays;
2525
import java.util.List;
26+
import org.apache.lucene.codecs.KnnVectorsReader;
2627
import org.apache.lucene.codecs.KnnVectorsWriter;
2728
import org.apache.lucene.index.FieldInfo;
2829
import org.apache.lucene.index.IndexFileNames;
@@ -74,7 +75,9 @@ public class SimpleTextKnnVectorsWriter extends KnnVectorsWriter {
7475
}
7576

7677
@Override
77-
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
78+
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
79+
throws IOException {
80+
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
7881
long vectorDataOffset = vectorData.getFilePointer();
7982
List<Integer> docIds = new ArrayList<>();
8083
int docV;

lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java

Lines changed: 73 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@
3131
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
3232
import org.apache.lucene.index.VectorSimilarityFunction;
3333
import org.apache.lucene.index.VectorValues;
34+
import org.apache.lucene.search.TopDocs;
35+
import org.apache.lucene.util.Bits;
3436
import org.apache.lucene.util.BytesRef;
3537

3638
/** Writes vectors to an index. */
@@ -40,7 +42,8 @@ public abstract class KnnVectorsWriter implements Closeable {
4042
protected KnnVectorsWriter() {}
4143

4244
/** Write all values contained in the provided reader */
43-
public abstract void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException;
45+
public abstract void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
46+
throws IOException;
4447

4548
/** Called once at the end before close */
4649
public abstract void finish() throws IOException;
@@ -67,47 +70,77 @@ private void mergeVectors(FieldInfo mergeFieldInfo, final MergeState mergeState)
6770
if (mergeState.infoStream.isEnabled("VV")) {
6871
mergeState.infoStream.message("VV", "merging " + mergeState.segmentInfo);
6972
}
70-
List<VectorValuesSub> subs = new ArrayList<>();
71-
int dimension = -1;
72-
VectorSimilarityFunction similarityFunction = null;
73-
int nonEmptySegmentIndex = 0;
74-
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
75-
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
76-
if (knnVectorsReader != null) {
77-
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
78-
int segmentDimension = mergeFieldInfo.getVectorDimension();
79-
VectorSimilarityFunction segmentSimilarityFunction =
80-
mergeFieldInfo.getVectorSimilarityFunction();
81-
if (dimension == -1) {
82-
dimension = segmentDimension;
83-
similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
84-
} else if (dimension != segmentDimension) {
85-
throw new IllegalStateException(
86-
"Varying dimensions for vector-valued field "
87-
+ mergeFieldInfo.name
88-
+ ": "
89-
+ dimension
90-
+ "!="
91-
+ segmentDimension);
92-
} else if (similarityFunction != segmentSimilarityFunction) {
93-
throw new IllegalStateException(
94-
"Varying similarity functions for vector-valued field "
95-
+ mergeFieldInfo.name
96-
+ ": "
97-
+ similarityFunction
98-
+ "!="
99-
+ segmentSimilarityFunction);
100-
}
101-
VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
102-
if (values != null) {
103-
subs.add(new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
104-
}
105-
}
106-
}
107-
}
10873
// Create a new VectorValues by iterating over the sub vectors, mapping the resulting
10974
// docids using docMaps in the mergeState.
110-
writeField(mergeFieldInfo, new VectorValuesMerger(subs, mergeState));
75+
writeField(
76+
mergeFieldInfo,
77+
new KnnVectorsReader() {
78+
@Override
79+
public long ramBytesUsed() {
80+
return 0;
81+
}
82+
83+
@Override
84+
public void close() throws IOException {
85+
throw new UnsupportedOperationException();
86+
}
87+
88+
@Override
89+
public void checkIntegrity() throws IOException {
90+
throw new UnsupportedOperationException();
91+
}
92+
93+
@Override
94+
public VectorValues getVectorValues(String field) throws IOException {
95+
List<VectorValuesSub> subs = new ArrayList<>();
96+
int dimension = -1;
97+
VectorSimilarityFunction similarityFunction = null;
98+
int nonEmptySegmentIndex = 0;
99+
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
100+
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
101+
if (knnVectorsReader != null) {
102+
if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) {
103+
int segmentDimension = mergeFieldInfo.getVectorDimension();
104+
VectorSimilarityFunction segmentSimilarityFunction =
105+
mergeFieldInfo.getVectorSimilarityFunction();
106+
if (dimension == -1) {
107+
dimension = segmentDimension;
108+
similarityFunction = mergeFieldInfo.getVectorSimilarityFunction();
109+
} else if (dimension != segmentDimension) {
110+
throw new IllegalStateException(
111+
"Varying dimensions for vector-valued field "
112+
+ mergeFieldInfo.name
113+
+ ": "
114+
+ dimension
115+
+ "!="
116+
+ segmentDimension);
117+
} else if (similarityFunction != segmentSimilarityFunction) {
118+
throw new IllegalStateException(
119+
"Varying similarity functions for vector-valued field "
120+
+ mergeFieldInfo.name
121+
+ ": "
122+
+ similarityFunction
123+
+ "!="
124+
+ segmentSimilarityFunction);
125+
}
126+
VectorValues values = knnVectorsReader.getVectorValues(mergeFieldInfo.name);
127+
if (values != null) {
128+
subs.add(
129+
new VectorValuesSub(nonEmptySegmentIndex++, mergeState.docMaps[i], values));
130+
}
131+
}
132+
}
133+
}
134+
return new VectorValuesMerger(subs, mergeState);
135+
}
136+
137+
@Override
138+
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
139+
throws IOException {
140+
throw new UnsupportedOperationException();
141+
}
142+
});
143+
111144
if (mergeState.infoStream.isEnabled("VV")) {
112145
mergeState.infoStream.message("VV", "merge done " + mergeState.segmentInfo);
113146
}

lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsWriter.java

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import java.io.IOException;
2323
import java.util.Arrays;
2424
import org.apache.lucene.codecs.CodecUtil;
25+
import org.apache.lucene.codecs.KnnVectorsReader;
2526
import org.apache.lucene.codecs.KnnVectorsWriter;
2627
import org.apache.lucene.index.FieldInfo;
2728
import org.apache.lucene.index.IndexFileNames;
@@ -107,7 +108,9 @@ public final class Lucene90HnswVectorsWriter extends KnnVectorsWriter {
107108
}
108109

109110
@Override
110-
public void writeField(FieldInfo fieldInfo, VectorValues vectors) throws IOException {
111+
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
112+
throws IOException {
113+
VectorValues vectors = knnVectorsReader.getVectorValues(fieldInfo.name);
111114
long pos = vectorData.getFilePointer();
112115
// write floats aligned at 4 bytes. This will not survive CFS, but it shows a small benefit when
113116
// CFS is not used, eg for larger indexes

lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,8 +98,9 @@ private class FieldsWriter extends KnnVectorsWriter {
9898
}
9999

100100
@Override
101-
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
102-
getInstance(fieldInfo).writeField(fieldInfo, values);
101+
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
102+
throws IOException {
103+
getInstance(fieldInfo).writeField(fieldInfo, knnVectorsReader);
103104
}
104105

105106
@Override

lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java

Lines changed: 35 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,12 @@
2222
import java.nio.ByteOrder;
2323
import java.util.ArrayList;
2424
import java.util.List;
25+
import org.apache.lucene.codecs.KnnVectorsReader;
2526
import org.apache.lucene.codecs.KnnVectorsWriter;
2627
import org.apache.lucene.search.DocIdSetIterator;
28+
import org.apache.lucene.search.TopDocs;
2729
import org.apache.lucene.util.ArrayUtil;
30+
import org.apache.lucene.util.Bits;
2831
import org.apache.lucene.util.BytesRef;
2932
import org.apache.lucene.util.Counter;
3033
import org.apache.lucene.util.RamUsageEstimator;
@@ -107,13 +110,38 @@ private void updateBytesUsed() {
107110
* @throws IOException if there is an error writing the field and its values
108111
*/
109112
public void flush(Sorter.DocMap sortMap, KnnVectorsWriter knnVectorsWriter) throws IOException {
110-
VectorValues vectorValues =
111-
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
112-
if (sortMap != null) {
113-
knnVectorsWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap));
114-
} else {
115-
knnVectorsWriter.writeField(fieldInfo, vectorValues);
116-
}
113+
KnnVectorsReader knnVectorsReader =
114+
new KnnVectorsReader() {
115+
@Override
116+
public long ramBytesUsed() {
117+
return 0;
118+
}
119+
120+
@Override
121+
public void close() throws IOException {
122+
throw new UnsupportedOperationException();
123+
}
124+
125+
@Override
126+
public void checkIntegrity() throws IOException {
127+
throw new UnsupportedOperationException();
128+
}
129+
130+
@Override
131+
public VectorValues getVectorValues(String field) throws IOException {
132+
VectorValues vectorValues =
133+
new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension());
134+
return sortMap != null ? new SortingVectorValues(vectorValues, sortMap) : vectorValues;
135+
}
136+
137+
@Override
138+
public TopDocs search(String field, float[] target, int k, Bits acceptDocs)
139+
throws IOException {
140+
throw new UnsupportedOperationException();
141+
}
142+
};
143+
144+
knnVectorsWriter.writeField(fieldInfo, knnVectorsReader);
117145
}
118146

119147
static class SortingVectorValues extends VectorValues

lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@
3939
import org.apache.lucene.index.NoMergePolicy;
4040
import org.apache.lucene.index.SegmentReadState;
4141
import org.apache.lucene.index.SegmentWriteState;
42-
import org.apache.lucene.index.VectorValues;
4342
import org.apache.lucene.search.TopDocs;
4443
import org.apache.lucene.store.Directory;
4544
import org.apache.lucene.tests.analysis.MockAnalyzer;
@@ -172,9 +171,10 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
172171
KnnVectorsWriter writer = delegate.fieldsWriter(state);
173172
return new KnnVectorsWriter() {
174173
@Override
175-
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
174+
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
175+
throws IOException {
176176
fieldsWritten.add(fieldInfo.name);
177-
writer.writeField(fieldInfo, values);
177+
writer.writeField(fieldInfo, knnVectorsReader);
178178
}
179179

180180
@Override

lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -58,10 +58,15 @@ static class AssertingKnnVectorsWriter extends KnnVectorsWriter {
5858
}
5959

6060
@Override
61-
public void writeField(FieldInfo fieldInfo, VectorValues values) throws IOException {
61+
public void writeField(FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader)
62+
throws IOException {
6263
assert fieldInfo != null;
63-
assert values != null;
64-
delegate.writeField(fieldInfo, values);
64+
assert knnVectorsReader != null;
65+
// assert that knnVectorsReader#getVectorValues returns different instances upon repeated
66+
// calls
67+
assert knnVectorsReader.getVectorValues(fieldInfo.name)
68+
!= knnVectorsReader.getVectorValues(fieldInfo.name);
69+
delegate.writeField(fieldInfo, knnVectorsReader);
6570
}
6671

6772
@Override

0 commit comments

Comments
 (0)