Skip to content

Commit 928fdca

Browse files
committed
fp float vector values source with tests
1 parent 41abd7a commit 928fdca

File tree

3 files changed

+234
-2
lines changed

3 files changed

+234
-2
lines changed

lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ public static DoubleValues similarityToQueryVector(
281281
+ " does not have the expected vector encoding: "
282282
+ VectorEncoding.FLOAT32);
283283
}
284-
return new FloatVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
284+
return new FloatVectorSimilarityValuesSource(queryVector, vectorField, true).getValues(ctx, null);
285285
}
286286

287287
/**

lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@
2020
import java.io.IOException;
2121
import java.util.Arrays;
2222
import java.util.Objects;
23+
24+
import org.apache.lucene.index.FieldInfo;
2325
import org.apache.lucene.index.FloatVectorValues;
26+
import org.apache.lucene.index.KnnVectorValues;
2427
import org.apache.lucene.index.LeafReaderContext;
28+
import org.apache.lucene.index.VectorSimilarityFunction;
2529

2630
/**
2731
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
@@ -30,10 +34,31 @@
3034
class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
3135

3236
private final float[] queryVector;
37+
private final boolean useFullPrecision;
3338

39+
/**
40+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
41+
* query vector and field for documents. Uses the scorer exposed by configured vectors reader.
42+
* @param vector the query vector
43+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
44+
*/
3445
public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
46+
this(vector, fieldName, false);
47+
}
48+
49+
/**
50+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
51+
* query vector and field for documents.
52+
*
53+
* @param vector the query vector
54+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
55+
* @param useFullPrecision uses full precision raw vectors for similarity computation if true, otherwise
56+
* the configured vectors reader is used, which may be quantized or full precision.
57+
*/
58+
public FloatVectorSimilarityValuesSource(float[] vector, String fieldName, boolean useFullPrecision) {
3559
super(fieldName);
3660
this.queryVector = vector;
61+
this.useFullPrecision = useFullPrecision;
3762
}
3863

3964
@Override
@@ -43,7 +68,35 @@ public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
4368
FloatVectorValues.checkField(ctx.reader(), fieldName);
4469
return null;
4570
}
46-
return vectorValues.scorer(queryVector);
71+
final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName);
72+
final VectorSimilarityFunction vectorSimilarityFunction = fi.getVectorSimilarityFunction();
73+
if (fi.getVectorDimension() != queryVector.length) {
74+
throw new IllegalArgumentException(
75+
"Query vector dimension does not match field dimension: "
76+
+ queryVector.length
77+
+ " != "
78+
+ fi.getVectorDimension());
79+
}
80+
81+
if (useFullPrecision == false) {
82+
// use default VectorScorer for configured reader
83+
return vectorValues.scorer(queryVector);
84+
}
85+
86+
// return a full precision vector scorer
87+
return new VectorScorer() {
88+
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
89+
90+
@Override
91+
public float score() throws IOException {
92+
return vectorSimilarityFunction.compare(queryVector, vectorValues.vectorValue(iterator.index()));
93+
}
94+
95+
@Override
96+
public DocIdSetIterator iterator() {
97+
return iterator;
98+
}
99+
};
47100
}
48101

49102
@Override
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
package org.apache.lucene.search;
2+
3+
import org.apache.lucene.codecs.Codec;
4+
import org.apache.lucene.codecs.KnnVectorsFormat;
5+
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
6+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
7+
import org.apache.lucene.document.Document;
8+
import org.apache.lucene.document.Field;
9+
import org.apache.lucene.document.IntField;
10+
import org.apache.lucene.document.KnnFloatVectorField;
11+
import org.apache.lucene.index.DirectoryReader;
12+
import org.apache.lucene.index.IndexReader;
13+
import org.apache.lucene.index.IndexWriter;
14+
import org.apache.lucene.index.LeafReaderContext;
15+
import org.apache.lucene.index.StoredFields;
16+
import org.apache.lucene.index.VectorSimilarityFunction;
17+
import org.apache.lucene.store.Directory;
18+
import org.apache.lucene.tests.util.LuceneTestCase;
19+
import org.apache.lucene.tests.util.TestUtil;
20+
import org.apache.lucene.util.TestVectorUtil;
21+
import org.junit.Before;
22+
import org.junit.Test;
23+
24+
import java.util.ArrayList;
25+
import java.util.List;
26+
27+
public class TestQuantizedVectorSimilarityValueSource extends LuceneTestCase {
28+
29+
private Codec savedCodec;
30+
31+
private static final String KNN_FIELD = "knnField";
32+
private static final int NUM_VECTORS = 1000;
33+
private static final int VECTOR_DIMENSION = 128;
34+
35+
KnnVectorsFormat format;
36+
Float confidenceInterval;
37+
int bits;
38+
39+
@Before
40+
@Override
41+
public void setUp() throws Exception {
42+
super.setUp();
43+
bits = random().nextBoolean() ? 4 : 7;
44+
confidenceInterval = random().nextBoolean() ? random().nextFloat(0.90f, 1.0f) : null;
45+
if (random().nextBoolean()) {
46+
confidenceInterval = 0f;
47+
}
48+
format = getKnnFormat(bits);
49+
savedCodec = Codec.getDefault();
50+
Codec.setDefault(getCodec());
51+
}
52+
53+
@Override
54+
public void tearDown() throws Exception {
55+
Codec.setDefault(savedCodec); // restore
56+
super.tearDown();
57+
}
58+
59+
protected Codec getCodec() {
60+
return TestUtil.alwaysKnnVectorsFormat(format);
61+
}
62+
63+
private final KnnVectorsFormat getKnnFormat(int bits) {
64+
return new Lucene99HnswScalarQuantizedVectorsFormat(
65+
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
66+
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
67+
1,
68+
bits,
69+
bits == 4 ? random().nextBoolean() : false,
70+
confidenceInterval,
71+
null);
72+
}
73+
74+
@Test
75+
public void testFullPrecisionVectorSimilarityDVS() throws Exception {
76+
List<float[]> vectors = new ArrayList<>();
77+
int numVectors = atLeast(NUM_VECTORS);
78+
int numSegments = random().nextInt(2, 10);
79+
final VectorSimilarityFunction vectorSimilarityFunction =
80+
VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)];
81+
82+
try (Directory dir = newDirectory()) {
83+
int id = 0;
84+
85+
// index some 4 bit quantized vectors
86+
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(4))))) {
87+
for (int j = 0; j < numSegments; j++) {
88+
for (int i = 0; i < numVectors; i++) {
89+
Document doc = new Document();
90+
if (random().nextInt(100) < 30) {
91+
// skip vector for some docs to create sparse vector field
92+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
93+
} else {
94+
float[] vector = TestVectorUtil.randomVector(VECTOR_DIMENSION);
95+
vectors.add(vector);
96+
doc.add(new IntField("id", id++, Field.Store.YES));
97+
doc.add(new KnnFloatVectorField(KNN_FIELD, vector, vectorSimilarityFunction));
98+
doc.add(new IntField("has_vector", 1, Field.Store.YES));
99+
}
100+
w.addDocument(doc);
101+
w.flush();
102+
}
103+
}
104+
// add a segment with no vectors
105+
for (int i = 0; i < 100; i++) {
106+
Document doc = new Document();
107+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
108+
w.addDocument(doc);
109+
}
110+
w.flush();
111+
}
112+
113+
// index some 7 bit quantized vectors
114+
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(7))))) {
115+
for (int j = 0; j < numSegments; j++) {
116+
for (int i = 0; i < numVectors; i++) {
117+
Document doc = new Document();
118+
if (random().nextInt(100) < 30) {
119+
// skip vector for some docs to create sparse vector field
120+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
121+
} else {
122+
float[] vector = TestVectorUtil.randomVector(VECTOR_DIMENSION);
123+
vectors.add(vector);
124+
doc.add(new IntField("id", id++, Field.Store.YES));
125+
doc.add(new KnnFloatVectorField(KNN_FIELD, vector, vectorSimilarityFunction));
126+
doc.add(new IntField("has_vector", 1, Field.Store.YES));
127+
}
128+
w.addDocument(doc);
129+
w.flush();
130+
}
131+
}
132+
// add a segment with no vectors
133+
for (int i = 0; i < 100; i++) {
134+
Document doc = new Document();
135+
doc.add(new IntField("has_vector", 0, Field.Store.YES));
136+
w.addDocument(doc);
137+
}
138+
w.flush();
139+
}
140+
141+
float[] queryVector = TestVectorUtil.randomVector(VECTOR_DIMENSION);
142+
FloatVectorSimilarityValuesSource fpSimValueSource = new FloatVectorSimilarityValuesSource(queryVector, KNN_FIELD, true);
143+
FloatVectorSimilarityValuesSource quantizedSimValueSource = new FloatVectorSimilarityValuesSource(queryVector, KNN_FIELD);
144+
145+
try (IndexReader reader = DirectoryReader.open(dir)) {
146+
FieldExistsQuery query = new FieldExistsQuery(KNN_FIELD);
147+
for (LeafReaderContext ctx: reader.leaves()) {
148+
DoubleValues fpSimValues = fpSimValueSource.getValues(ctx, null);
149+
DoubleValues quantizedSimValues = quantizedSimValueSource.getValues(ctx, null);
150+
// validate when segment has no vectors
151+
if (fpSimValues == DoubleValues.EMPTY || quantizedSimValues == DoubleValues.EMPTY) {
152+
assertEquals(fpSimValues, quantizedSimValues);
153+
assertNull(ctx.reader().getFloatVectorValues(KNN_FIELD));
154+
continue;
155+
}
156+
StoredFields storedFields = ctx.reader().storedFields();
157+
VectorScorer quantizedScorer = ctx.reader().getFloatVectorValues(KNN_FIELD).scorer(queryVector);
158+
DocIdSetIterator disi = quantizedScorer.iterator();
159+
while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
160+
int doc = disi.docID();
161+
fpSimValues.advanceExact(doc);
162+
quantizedSimValues.advanceExact(doc);
163+
int idValue = Integer.parseInt(storedFields.document(doc).get("id"));
164+
float[] docVector = vectors.get(idValue);
165+
assert docVector != null : "Vector for id " + idValue + " not found";
166+
// validate full precision vector scores
167+
double expectedFpScore = vectorSimilarityFunction.compare(queryVector, docVector);
168+
double actualFpScore = fpSimValues.doubleValue();
169+
assertEquals(expectedFpScore, actualFpScore, 1e-5);
170+
// validate quantized vector scores
171+
double expectedQScore = quantizedScorer.score();
172+
double actualQScore = quantizedSimValues.doubleValue();
173+
assertEquals(expectedQScore, actualQScore, 1e-5);
174+
}
175+
}
176+
}
177+
}
178+
}
179+
}

0 commit comments

Comments
 (0)