|
1 | 1 | package org.apache.lucene.search; |
2 | 2 |
|
| 3 | +import java.util.HashMap; |
| 4 | +import java.util.Map; |
| 5 | +import java.util.Random; |
3 | 6 | import org.apache.lucene.codecs.FilterCodec; |
4 | 7 | import org.apache.lucene.codecs.KnnVectorsFormat; |
5 | 8 | import org.apache.lucene.codecs.lucene100.Lucene100Codec; |
|
19 | 22 | import org.junit.Before; |
20 | 23 | import org.junit.Test; |
21 | 24 |
|
22 | | -import java.util.HashMap; |
23 | | -import java.util.Map; |
24 | | -import java.util.Random; |
25 | | - |
26 | 25 | public class TestTwoPhaseKnnVectorQuery { |
27 | 26 |
|
28 | | - private static final String FIELD = "vector"; |
29 | | - public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = VectorSimilarityFunction.COSINE; |
30 | | - private Directory directory; |
31 | | - private IndexWriterConfig config; |
32 | | - private static final int NUM_VECTORS = 1000; |
33 | | - private static final int VECTOR_DIMENSION = 128; |
| 27 | + private static final String FIELD = "vector"; |
| 28 | + public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION = |
| 29 | + VectorSimilarityFunction.COSINE; |
| 30 | + private Directory directory; |
| 31 | + private IndexWriterConfig config; |
| 32 | + private static final int NUM_VECTORS = 1000; |
| 33 | + private static final int VECTOR_DIMENSION = 128; |
34 | 34 |
|
35 | | - @Before |
36 | | - public void setUp() throws Exception { |
37 | | - directory = new ByteBuffersDirectory(); |
| 35 | + @Before |
| 36 | + public void setUp() throws Exception { |
| 37 | + directory = new ByteBuffersDirectory(); |
38 | 38 |
|
39 | | - // Set up the IndexWriterConfig to use quantized vector storage |
40 | | - config = new IndexWriterConfig(); |
41 | | - config.setCodec(new QuantizedCodec()); |
42 | | - } |
| 39 | + // Set up the IndexWriterConfig to use quantized vector storage |
| 40 | + config = new IndexWriterConfig(); |
| 41 | + config.setCodec(new QuantizedCodec()); |
| 42 | + } |
43 | 43 |
|
44 | | - @Test |
45 | | - public void testTwoPhaseKnnVectorQuery() throws Exception { |
46 | | - Map<Integer, float[]> vectors = new HashMap<>(); |
| 44 | + @Test |
| 45 | + public void testTwoPhaseKnnVectorQuery() throws Exception { |
| 46 | + Map<Integer, float[]> vectors = new HashMap<>(); |
47 | 47 |
|
48 | | - // Step 1: Index random vectors in quantized format |
49 | | - try (IndexWriter writer = new IndexWriter(directory, config)) { |
50 | | - Random random = new Random(); |
51 | | - for (int i = 0; i < NUM_VECTORS; i++) { |
52 | | - float[] vector = randomFloatVector(VECTOR_DIMENSION, random); |
53 | | - Document doc = new Document(); |
54 | | - doc.add(new IntField("id", i, Field.Store.YES)); |
55 | | - doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); |
56 | | - writer.addDocument(doc); |
57 | | - vectors.put(i, vector); |
58 | | - } |
59 | | - } |
| 48 | + // Step 1: Index random vectors in quantized format |
| 49 | + try (IndexWriter writer = new IndexWriter(directory, config)) { |
| 50 | + Random random = new Random(); |
| 51 | + for (int i = 0; i < NUM_VECTORS; i++) { |
| 52 | + float[] vector = randomFloatVector(VECTOR_DIMENSION, random); |
| 53 | + Document doc = new Document(); |
| 54 | + doc.add(new IntField("id", i, Field.Store.YES)); |
| 55 | + doc.add(new KnnFloatVectorField(FIELD, vector, VECTOR_SIMILARITY_FUNCTION)); |
| 56 | + writer.addDocument(doc); |
| 57 | + vectors.put(i, vector); |
| 58 | + } |
| 59 | + } |
60 | 60 |
|
61 | | - // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector |
62 | | - try (IndexReader reader = DirectoryReader.open(directory)) { |
63 | | - IndexSearcher searcher = new IndexSearcher(reader); |
64 | | - float[] targetVector = randomFloatVector(VECTOR_DIMENSION, new Random()); |
65 | | - int k = 10; |
66 | | - double oversample = 1.0; |
| 61 | + // Step 2: Run TwoPhaseKnnVectorQuery with a random target vector |
| 62 | + try (IndexReader reader = DirectoryReader.open(directory)) { |
| 63 | + IndexSearcher searcher = new IndexSearcher(reader); |
| 64 | + float[] targetVector = randomFloatVector(VECTOR_DIMENSION, new Random()); |
| 65 | + int k = 10; |
| 66 | + double oversample = 1.0; |
67 | 67 |
|
68 | | - TwoPhaseKnnVectorQuery query = new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null); |
69 | | - TopDocs topDocs = searcher.search(query, k); |
| 68 | + TwoPhaseKnnVectorQuery query = |
| 69 | + new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null); |
| 70 | + TopDocs topDocs = searcher.search(query, k); |
70 | 71 |
|
71 | | - // Step 3: Verify that TopDocs scores match similarity with unquantized vectors |
72 | | - for (ScoreDoc scoreDoc : topDocs.scoreDocs) { |
73 | | - Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); |
74 | | - float[] docVector = vectors.get(retrievedDoc.getField("id").numericValue().intValue()); |
75 | | - float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); |
76 | | - Assert.assertEquals( |
77 | | - "Score does not match expected similarity for docId: " + scoreDoc.doc, |
78 | | - expectedScore, scoreDoc.score, 1e-5); |
79 | | - } |
80 | | - } |
| 72 | + // Step 3: Verify that TopDocs scores match similarity with unquantized vectors |
| 73 | + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { |
| 74 | + Document retrievedDoc = searcher.storedFields().document(scoreDoc.doc); |
| 75 | + float[] docVector = vectors.get(retrievedDoc.getField("id").numericValue().intValue()); |
| 76 | + float expectedScore = VECTOR_SIMILARITY_FUNCTION.compare(targetVector, docVector); |
| 77 | + Assert.assertEquals( |
| 78 | + "Score does not match expected similarity for docId: " + scoreDoc.doc, |
| 79 | + expectedScore, |
| 80 | + scoreDoc.score, |
| 81 | + 1e-5); |
| 82 | + } |
81 | 83 | } |
| 84 | + } |
82 | 85 |
|
83 | | - private float[] randomFloatVector(int dimension, Random random) { |
84 | | - float[] vector = new float[dimension]; |
85 | | - for (int i = 0; i < dimension; i++) { |
86 | | - vector[i] = random.nextFloat(); |
87 | | - } |
88 | | - return vector; |
| 86 | + private float[] randomFloatVector(int dimension, Random random) { |
| 87 | + float[] vector = new float[dimension]; |
| 88 | + for (int i = 0; i < dimension; i++) { |
| 89 | + vector[i] = random.nextFloat(); |
89 | 90 | } |
| 91 | + return vector; |
| 92 | + } |
90 | 93 |
|
91 | | - public static class QuantizedCodec extends FilterCodec { |
| 94 | + public static class QuantizedCodec extends FilterCodec { |
92 | 95 |
|
93 | | - public QuantizedCodec() { |
94 | | - super("QuantizedCodec", new Lucene100Codec()); |
95 | | - } |
| 96 | + public QuantizedCodec() { |
| 97 | + super("QuantizedCodec", new Lucene100Codec()); |
| 98 | + } |
96 | 99 |
|
97 | | - @Override |
98 | | - public KnnVectorsFormat knnVectorsFormat() { |
99 | | - return new Lucene99HnswScalarQuantizedVectorsFormat(); |
100 | | - } |
| 100 | + @Override |
| 101 | + public KnnVectorsFormat knnVectorsFormat() { |
| 102 | + return new Lucene99HnswScalarQuantizedVectorsFormat(); |
101 | 103 | } |
| 104 | + } |
102 | 105 | } |
0 commit comments