|
| 1 | +package org.apache.lucene.search; |
| 2 | + |
| 3 | +import org.apache.lucene.codecs.Codec; |
| 4 | +import org.apache.lucene.codecs.KnnVectorsFormat; |
| 5 | +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; |
| 6 | +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; |
| 7 | +import org.apache.lucene.document.Document; |
| 8 | +import org.apache.lucene.document.Field; |
| 9 | +import org.apache.lucene.document.IntField; |
| 10 | +import org.apache.lucene.document.KnnFloatVectorField; |
| 11 | +import org.apache.lucene.index.DirectoryReader; |
| 12 | +import org.apache.lucene.index.IndexReader; |
| 13 | +import org.apache.lucene.index.IndexWriter; |
| 14 | +import org.apache.lucene.index.LeafReaderContext; |
| 15 | +import org.apache.lucene.index.StoredFields; |
| 16 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 17 | +import org.apache.lucene.store.Directory; |
| 18 | +import org.apache.lucene.tests.util.LuceneTestCase; |
| 19 | +import org.apache.lucene.tests.util.TestUtil; |
| 20 | +import org.apache.lucene.util.TestVectorUtil; |
| 21 | +import org.junit.Before; |
| 22 | +import org.junit.Test; |
| 23 | + |
| 24 | +import java.util.ArrayList; |
| 25 | +import java.util.List; |
| 26 | + |
| 27 | +public class TestQuantizedVectorSimilarityValueSource extends LuceneTestCase { |
| 28 | + |
| 29 | + private Codec savedCodec; |
| 30 | + |
| 31 | + private static final String KNN_FIELD = "knnField"; |
| 32 | + private static final int NUM_VECTORS = 1000; |
| 33 | + private static final int VECTOR_DIMENSION = 128; |
| 34 | + |
| 35 | + KnnVectorsFormat format; |
| 36 | + Float confidenceInterval; |
| 37 | + int bits; |
| 38 | + |
| 39 | + @Before |
| 40 | + @Override |
| 41 | + public void setUp() throws Exception { |
| 42 | + super.setUp(); |
| 43 | + bits = random().nextBoolean() ? 4 : 7; |
| 44 | + confidenceInterval = random().nextBoolean() ? random().nextFloat(0.90f, 1.0f) : null; |
| 45 | + if (random().nextBoolean()) { |
| 46 | + confidenceInterval = 0f; |
| 47 | + } |
| 48 | + format = getKnnFormat(bits); |
| 49 | + savedCodec = Codec.getDefault(); |
| 50 | + Codec.setDefault(getCodec()); |
| 51 | + } |
| 52 | + |
| 53 | + @Override |
| 54 | + public void tearDown() throws Exception { |
| 55 | + Codec.setDefault(savedCodec); // restore |
| 56 | + super.tearDown(); |
| 57 | + } |
| 58 | + |
| 59 | + protected Codec getCodec() { |
| 60 | + return TestUtil.alwaysKnnVectorsFormat(format); |
| 61 | + } |
| 62 | + |
| 63 | + private final KnnVectorsFormat getKnnFormat(int bits) { |
| 64 | + return new Lucene99HnswScalarQuantizedVectorsFormat( |
| 65 | + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, |
| 66 | + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, |
| 67 | + 1, |
| 68 | + bits, |
| 69 | + bits == 4 ? random().nextBoolean() : false, |
| 70 | + confidenceInterval, |
| 71 | + null); |
| 72 | + } |
| 73 | + |
| 74 | + @Test |
| 75 | + public void testFullPrecisionVectorSimilarityDVS() throws Exception { |
| 76 | + List<float[]> vectors = new ArrayList<>(); |
| 77 | + int numVectors = atLeast(NUM_VECTORS); |
| 78 | + int numSegments = random().nextInt(2, 10); |
| 79 | + final VectorSimilarityFunction vectorSimilarityFunction = |
| 80 | + VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)]; |
| 81 | + |
| 82 | + try (Directory dir = newDirectory()) { |
| 83 | + int id = 0; |
| 84 | + |
| 85 | + // index some 4 bit quantized vectors |
| 86 | + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(4))))) { |
| 87 | + for (int j = 0; j < numSegments; j++) { |
| 88 | + for (int i = 0; i < numVectors; i++) { |
| 89 | + Document doc = new Document(); |
| 90 | + if (random().nextInt(100) < 30) { |
| 91 | + // skip vector for some docs to create sparse vector field |
| 92 | + doc.add(new IntField("has_vector", 0, Field.Store.YES)); |
| 93 | + } else { |
| 94 | + float[] vector = TestVectorUtil.randomVector(VECTOR_DIMENSION); |
| 95 | + vectors.add(vector); |
| 96 | + doc.add(new IntField("id", id++, Field.Store.YES)); |
| 97 | + doc.add(new KnnFloatVectorField(KNN_FIELD, vector, vectorSimilarityFunction)); |
| 98 | + doc.add(new IntField("has_vector", 1, Field.Store.YES)); |
| 99 | + } |
| 100 | + w.addDocument(doc); |
| 101 | + w.flush(); |
| 102 | + } |
| 103 | + } |
| 104 | + // add a segment with no vectors |
| 105 | + for (int i = 0; i < 100; i++) { |
| 106 | + Document doc = new Document(); |
| 107 | + doc.add(new IntField("has_vector", 0, Field.Store.YES)); |
| 108 | + w.addDocument(doc); |
| 109 | + } |
| 110 | + w.flush(); |
| 111 | + } |
| 112 | + |
| 113 | + // index some 7 bit quantized vectors |
| 114 | + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(getKnnFormat(7))))) { |
| 115 | + for (int j = 0; j < numSegments; j++) { |
| 116 | + for (int i = 0; i < numVectors; i++) { |
| 117 | + Document doc = new Document(); |
| 118 | + if (random().nextInt(100) < 30) { |
| 119 | + // skip vector for some docs to create sparse vector field |
| 120 | + doc.add(new IntField("has_vector", 0, Field.Store.YES)); |
| 121 | + } else { |
| 122 | + float[] vector = TestVectorUtil.randomVector(VECTOR_DIMENSION); |
| 123 | + vectors.add(vector); |
| 124 | + doc.add(new IntField("id", id++, Field.Store.YES)); |
| 125 | + doc.add(new KnnFloatVectorField(KNN_FIELD, vector, vectorSimilarityFunction)); |
| 126 | + doc.add(new IntField("has_vector", 1, Field.Store.YES)); |
| 127 | + } |
| 128 | + w.addDocument(doc); |
| 129 | + w.flush(); |
| 130 | + } |
| 131 | + } |
| 132 | + // add a segment with no vectors |
| 133 | + for (int i = 0; i < 100; i++) { |
| 134 | + Document doc = new Document(); |
| 135 | + doc.add(new IntField("has_vector", 0, Field.Store.YES)); |
| 136 | + w.addDocument(doc); |
| 137 | + } |
| 138 | + w.flush(); |
| 139 | + } |
| 140 | + |
| 141 | + float[] queryVector = TestVectorUtil.randomVector(VECTOR_DIMENSION); |
| 142 | + FloatVectorSimilarityValuesSource fpSimValueSource = new FloatVectorSimilarityValuesSource(queryVector, KNN_FIELD, true); |
| 143 | + FloatVectorSimilarityValuesSource quantizedSimValueSource = new FloatVectorSimilarityValuesSource(queryVector, KNN_FIELD); |
| 144 | + |
| 145 | + try (IndexReader reader = DirectoryReader.open(dir)) { |
| 146 | + FieldExistsQuery query = new FieldExistsQuery(KNN_FIELD); |
| 147 | + for (LeafReaderContext ctx: reader.leaves()) { |
| 148 | + DoubleValues fpSimValues = fpSimValueSource.getValues(ctx, null); |
| 149 | + DoubleValues quantizedSimValues = quantizedSimValueSource.getValues(ctx, null); |
| 150 | + // validate when segment has no vectors |
| 151 | + if (fpSimValues == DoubleValues.EMPTY || quantizedSimValues == DoubleValues.EMPTY) { |
| 152 | + assertEquals(fpSimValues, quantizedSimValues); |
| 153 | + assertNull(ctx.reader().getFloatVectorValues(KNN_FIELD)); |
| 154 | + continue; |
| 155 | + } |
| 156 | + StoredFields storedFields = ctx.reader().storedFields(); |
| 157 | + VectorScorer quantizedScorer = ctx.reader().getFloatVectorValues(KNN_FIELD).scorer(queryVector); |
| 158 | + DocIdSetIterator disi = quantizedScorer.iterator(); |
| 159 | + while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { |
| 160 | + int doc = disi.docID(); |
| 161 | + fpSimValues.advanceExact(doc); |
| 162 | + quantizedSimValues.advanceExact(doc); |
| 163 | + int idValue = Integer.parseInt(storedFields.document(doc).get("id")); |
| 164 | + float[] docVector = vectors.get(idValue); |
| 165 | + assert docVector != null : "Vector for id " + idValue + " not found"; |
| 166 | + // validate full precision vector scores |
| 167 | + double expectedFpScore = vectorSimilarityFunction.compare(queryVector, docVector); |
| 168 | + double actualFpScore = fpSimValues.doubleValue(); |
| 169 | + assertEquals(expectedFpScore, actualFpScore, 1e-5); |
| 170 | + // validate quantized vector scores |
| 171 | + double expectedQScore = quantizedScorer.score(); |
| 172 | + double actualQScore = quantizedSimValues.doubleValue(); |
| 173 | + assertEquals(expectedQScore, actualQScore, 1e-5); |
| 174 | + } |
| 175 | + } |
| 176 | + } |
| 177 | + } |
| 178 | + } |
| 179 | +} |
0 commit comments