Skip to content

Commit 5edbd2b

Browse files
committed
add full precision byte vector sim values source
1 parent 83816c5 commit 5edbd2b

File tree

4 files changed

+182
-10
lines changed

4 files changed

+182
-10
lines changed

lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,51 @@
2121
import java.util.Arrays;
2222
import java.util.Objects;
2323
import org.apache.lucene.index.ByteVectorValues;
24+
import org.apache.lucene.index.FieldInfo;
25+
import org.apache.lucene.index.KnnVectorValues;
2426
import org.apache.lucene.index.LeafReaderContext;
27+
import org.apache.lucene.index.VectorEncoding;
28+
import org.apache.lucene.index.VectorSimilarityFunction;
2529

2630
/**
2731
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
2832
* and the {@link org.apache.lucene.document.KnnByteVectorField} for documents.
2933
*/
3034
class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
35+
36+
/** Returns {@link DoubleValues} of similarity scores computed against full precision vector values. */
37+
public static DoubleValues fullPrecisionScores(
38+
LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
39+
return new ByteVectorSimilarityValuesSource(queryVector, vectorField, true).getValues(ctx, null);
40+
}
41+
3142
private final byte[] queryVector;
43+
private final boolean useFullPrecision;
3244

45+
/**
46+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
47+
* query vector and field for documents. Uses the scorer exposed by configured vectors reader.
48+
*
49+
* @param vector the query vector
50+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnByteVectorField}
51+
*/
3352
public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
53+
this(vector, fieldName, false);
54+
}
55+
56+
/**
57+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
58+
* query vector and field for documents.
59+
*
60+
* @param vector the query vector
61+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnByteVectorField}
62+
* @param useFullPrecision uses full precision raw vectors for similarity computation if true,
63+
* otherwise the configured vectors reader is used, which may be quantized or full precision.
64+
*/
65+
public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName, boolean useFullPrecision) {
3466
super(fieldName);
3567
this.queryVector = vector;
68+
this.useFullPrecision = useFullPrecision;
3669
}
3770

3871
@Override
@@ -42,7 +75,43 @@ public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
4275
ByteVectorValues.checkField(ctx.reader(), fieldName);
4376
return null;
4477
}
45-
return vectorValues.scorer(queryVector);
78+
79+
final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName);
80+
if (fi.getVectorEncoding() != VectorEncoding.BYTE) {
81+
throw new IllegalArgumentException(
82+
"Field "
83+
+ fieldName
84+
+ " does not have the expected vector encoding: "
85+
+ VectorEncoding.BYTE);
86+
}
87+
if (fi.getVectorDimension() != queryVector.length) {
88+
throw new IllegalArgumentException(
89+
"Query vector dimension does not match field dimension: "
90+
+ queryVector.length
91+
+ " != "
92+
+ fi.getVectorDimension());
93+
}
94+
95+
// default: delegate to the scorer exposed by the configured vectors reader
96+
if (useFullPrecision == false) {
97+
return vectorValues.scorer(queryVector);
98+
}
99+
100+
final VectorSimilarityFunction vectorSimilarityFunction = fi.getVectorSimilarityFunction();
101+
return new VectorScorer() {
102+
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
103+
104+
@Override
105+
public float score() throws IOException {
106+
return vectorSimilarityFunction.compare(
107+
queryVector, vectorValues.vectorValue(iterator.index()));
108+
}
109+
110+
@Override
111+
public DocIdSetIterator iterator() {
112+
return iterator;
113+
}
114+
};
46115
}
47116

48117
@Override

lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,6 @@ public LongValuesSource rewrite(IndexSearcher searcher) throws IOException {
250250
*/
251251
public static DoubleValues similarityToQueryVector(
252252
LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
253-
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
254-
!= VectorEncoding.BYTE) {
255-
throw new IllegalArgumentException(
256-
"Field "
257-
+ vectorField
258-
+ " does not have the expected vector encoding: "
259-
+ VectorEncoding.BYTE);
260-
}
261253
return new ByteVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
262254
}
263255

lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
*/
3434
class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
3535

36+
/** Returns {@link DoubleValues} of similarity scores computed against full precision vector values. */
3637
public static DoubleValues fullPrecisionScores(
3738
LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException {
3839
return new FloatVectorSimilarityValuesSource(queryVector, vectorField, true).getValues(ctx, null);

lucene/core/src/test/org/apache/lucene/search/TestQuantizedVectorSimilarityValueSource.java

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.apache.lucene.document.Document;
1010
import org.apache.lucene.document.Field;
1111
import org.apache.lucene.document.IntField;
12+
import org.apache.lucene.document.KnnByteVectorField;
1213
import org.apache.lucene.document.KnnFloatVectorField;
1314
import org.apache.lucene.index.DirectoryReader;
1415
import org.apache.lucene.index.IndexReader;
@@ -146,7 +147,6 @@ public void testFullPrecisionVectorSimilarityDVS() throws Exception {
146147

147148
float[] queryVector = TestVectorUtil.randomVector(VECTOR_DIMENSION);
148149
try (IndexReader reader = DirectoryReader.open(dir)) {
149-
FieldExistsQuery query = new FieldExistsQuery(KNN_FIELD);
150150
for (LeafReaderContext ctx : reader.leaves()) {
151151
DoubleValues fpSimValues = FloatVectorSimilarityValuesSource.fullPrecisionScores(ctx, queryVector, KNN_FIELD);
152152
DoubleValues quantizedSimValues = DoubleValuesSource.similarityToQueryVector(ctx, queryVector, KNN_FIELD);
@@ -180,4 +180,114 @@ public void testFullPrecisionVectorSimilarityDVS() throws Exception {
180180
}
181181
}
182182
}
183+
184+
@Test
185+
public void testFullPrecisionByteVectorSimilarityDVS() throws Exception {
186+
List<byte[]> vectors = new ArrayList<>();
187+
int numVectors = atLeast(NUM_VECTORS);
188+
int numSegments = random().nextInt(2, 10);
189+
final VectorSimilarityFunction vectorSimilarityFunction =
190+
VectorSimilarityFunction.values()[
191+
random().nextInt(VectorSimilarityFunction.values().length)];
192+
193+
try (Directory dir = newDirectory()) {
194+
int id = 0;
195+
196+
// index some byte vectors using the 4 bit quantized KNN format
197+
try (IndexWriter w =
198+
new IndexWriter(
199+
dir,
200+
newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(4))))) {
201+
for (int j = 0; j < numSegments; j++) {
202+
for (int i = 0; i < numVectors; i++) {
203+
Document doc = new Document();
204+
if (random().nextInt(100) < 30) {
205+
// skip vector for some docs to create sparse vector field
206+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
207+
} else {
208+
byte[] vector = TestVectorUtil.randomVectorBytes(VECTOR_DIMENSION);
209+
vectors.add(vector);
210+
doc.add(new IntField("id", id++, Field.Store.YES));
211+
doc.add(new KnnByteVectorField(KNN_FIELD, vector, vectorSimilarityFunction));
212+
doc.add(new IntField("has_vector", 1, Field.Store.YES));
213+
}
214+
w.addDocument(doc);
215+
w.flush();
216+
}
217+
}
218+
// add a segment with no vectors
219+
for (int i = 0; i < 100; i++) {
220+
Document doc = new Document();
221+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
222+
w.addDocument(doc);
223+
}
224+
w.flush();
225+
}
226+
227+
// index some byte vectors using the 7 bit quantized KNN format
228+
try (IndexWriter w =
229+
new IndexWriter(
230+
dir,
231+
newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(7))))) {
232+
for (int j = 0; j < numSegments; j++) {
233+
for (int i = 0; i < numVectors; i++) {
234+
Document doc = new Document();
235+
if (random().nextInt(100) < 30) {
236+
// skip vector for some docs to create sparse vector field
237+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
238+
} else {
239+
byte[] vector = TestVectorUtil.randomVectorBytes(VECTOR_DIMENSION);
240+
vectors.add(vector);
241+
doc.add(new IntField("id", id++, Field.Store.YES));
242+
doc.add(new KnnByteVectorField(KNN_FIELD, vector, vectorSimilarityFunction));
243+
doc.add(new IntField("has_vector", 1, Field.Store.YES));
244+
}
245+
w.addDocument(doc);
246+
w.flush();
247+
}
248+
}
249+
// add a segment with no vectors
250+
for (int i = 0; i < 100; i++) {
251+
Document doc = new Document();
252+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
253+
w.addDocument(doc);
254+
}
255+
w.flush();
256+
}
257+
258+
byte[] queryVector = TestVectorUtil.randomVectorBytes(VECTOR_DIMENSION);
259+
try (IndexReader reader = DirectoryReader.open(dir)) {
260+
for (LeafReaderContext ctx : reader.leaves()) {
261+
DoubleValues fpSimValues = ByteVectorSimilarityValuesSource.fullPrecisionScores(ctx, queryVector, KNN_FIELD);
262+
DoubleValues quantizedSimValues = DoubleValuesSource.similarityToQueryVector(ctx, queryVector, KNN_FIELD);
263+
// validate when segment has no vectors
264+
if (fpSimValues == DoubleValues.EMPTY || quantizedSimValues == DoubleValues.EMPTY) {
265+
assertEquals(fpSimValues, quantizedSimValues);
266+
assertNull(ctx.reader().getByteVectorValues(KNN_FIELD));
267+
continue;
268+
}
269+
StoredFields storedFields = ctx.reader().storedFields();
270+
VectorScorer quantizedScorer =
271+
ctx.reader().getByteVectorValues(KNN_FIELD).scorer(queryVector);
272+
DocIdSetIterator disi = quantizedScorer.iterator();
273+
while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
274+
int doc = disi.docID();
275+
fpSimValues.advanceExact(doc);
276+
quantizedSimValues.advanceExact(doc);
277+
int idValue = Integer.parseInt(storedFields.document(doc).get("id"));
278+
byte[] docVector = vectors.get(idValue);
279+
assert docVector != null : "Vector for id " + idValue + " not found";
280+
// validate full precision vector scores
281+
double expectedFpScore = vectorSimilarityFunction.compare(queryVector, docVector);
282+
double actualFpScore = fpSimValues.doubleValue();
283+
assertEquals(expectedFpScore, actualFpScore, 1e-5);
284+
// validate quantized vector scores
285+
double expectedQScore = quantizedScorer.score();
286+
double actualQScore = quantizedSimValues.doubleValue();
287+
assertEquals(expectedQScore, actualQScore, 1e-5);
288+
}
289+
}
290+
}
291+
}
292+
}
183293
}

0 commit comments

Comments
 (0)