99
1010package org .elasticsearch .search .vectors ;
1111
12+ import com .carrotsearch .randomizedtesting .annotations .ParametersFactory ;
13+
1214import org .apache .lucene .document .Document ;
15+ import org .apache .lucene .document .KnnByteVectorField ;
1316import org .apache .lucene .document .KnnFloatVectorField ;
17+ import org .apache .lucene .index .ByteVectorValues ;
1418import org .apache .lucene .index .DirectoryReader ;
1519import org .apache .lucene .index .FloatVectorValues ;
1620import org .apache .lucene .index .IndexReader ;
1721import org .apache .lucene .index .IndexWriter ;
1822import org .apache .lucene .index .KnnVectorValues ;
23+ import org .apache .lucene .index .LeafReader ;
1924import org .apache .lucene .index .LeafReaderContext ;
2025import org .apache .lucene .index .VectorSimilarityFunction ;
2126import org .apache .lucene .search .IndexSearcher ;
2429import org .apache .lucene .store .Directory ;
2530import org .elasticsearch .test .ESTestCase ;
2631
32+ import java .io .IOException ;
2733import java .util .ArrayList ;
2834import java .util .Arrays ;
2935import java .util .Collection ;
36+ import java .util .HashSet ;
37+ import java .util .List ;
3038import java .util .Map ;
3139import java .util .PriorityQueue ;
3240import java .util .stream .Collectors ;
3745public class RescoreKnnVectorQueryTests extends ESTestCase {
3846
3947 public static final String FIELD_NAME = "float_vector" ;
48+ private final int numDocs ;
49+ private final VectorProvider vectorProvider ;
50+ private final Integer k ;
4051
41- public void testRescoresTopK () throws Exception {
42- int numDocs = randomIntBetween (10 , 100 );
43- testRescoreDocs (numDocs , randomIntBetween (5 , numDocs - 1 ));
44- }
45-
46- public void testRescoresNoKParameter () throws Exception {
47- testRescoreDocs (randomIntBetween (10 , 100 ), null );
/**
 * Builds one parameterized instance of the test.
 *
 * @param vectorProvider strategy that creates, indexes, reads and scores vectors of a single element type
 * @param useK           when {@code true}, {@code k} is drawn at random from {@code [1, numDocs - 1]};
 *                       when {@code false}, {@code k} stays {@code null} and {@code testRescoreDocs}
 *                       substitutes {@code numDocs} for it
 */
public RescoreKnnVectorQueryTests(VectorProvider vectorProvider, boolean useK) {
    this.vectorProvider = vectorProvider;
    // numDocs must be assigned before k, because the upper bound of k is derived from it.
    this.numDocs = randomIntBetween(10, 100);
    this.k = useK ? randomIntBetween(1, numDocs - 1) : null;
}
4957
50- private void testRescoreDocs (int numDocs , Integer k ) throws Exception {
58+ public void testRescoreDocs () throws Exception {
5159 int numDims = randomIntBetween (5 , 100 );
5260
61+ Integer adjustedK = k ;
5362 if (k == null ) {
54- k = numDocs ;
63+ adjustedK = numDocs ;
5564 }
5665
5766 try (Directory d = newDirectory ()) {
58- try (IndexWriter w = new IndexWriter (d , newIndexWriterConfig ())) {
59- for (int i = 0 ; i < numDocs ; i ++) {
60- Document document = new Document ();
61- float [] vector = randomVector (numDims );
62- KnnFloatVectorField vectorField = new KnnFloatVectorField (FIELD_NAME , vector );
63- document .add (vectorField );
64- w .addDocument (document );
65- }
66- w .commit ();
67- w .forceMerge (1 );
68- }
67+ addRandomDocuments (numDocs , d , numDims , vectorProvider );
6968
7069 try (IndexReader reader = DirectoryReader .open (d )) {
71- float [] queryVector = randomVector (numDims );
7270
73- RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery (
74- FIELD_NAME ,
75- queryVector ,
76- VectorSimilarityFunction .COSINE ,
77- k ,
78- new MatchAllDocsQuery ()
79- );
71+ // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
72+ // and thus we're rescoring the top k docs.
73+ VectorData queryVector = vectorProvider .randomVector (numDims );
74+ RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider .createRescoreQuery (queryVector , adjustedK );
8075
8176 IndexSearcher searcher = newSearcher (reader , true , false );
8277 TopDocs docs = searcher .search (rescoreKnnVectorQuery , numDocs );
8378 Map <Integer , Float > rescoredDocs = Arrays .stream (docs .scoreDocs )
8479 .collect (Collectors .toMap (scoreDoc -> scoreDoc .doc , scoreDoc -> scoreDoc .score ));
8580
86- assertThat (rescoredDocs .size (), equalTo (k ));
81+ assertThat (rescoredDocs .size (), equalTo (adjustedK ));
82+
83+ Collection <Float > rescoredScores = new HashSet <>(rescoredDocs .values ());
8784
88- Collection < Float > rescoredScores = new ArrayList <>( rescoredDocs . values ());
85+ // Collect all docs sequentially, and score them using the similarity function to get the top K scores
8986 PriorityQueue <Float > topK = new PriorityQueue <>((o1 , o2 ) -> Float .compare (o2 , o1 ));
9087
9188 for (LeafReaderContext leafReaderContext : reader .leaves ()) {
92- FloatVectorValues floatVectorValues = leafReaderContext .reader (). getFloatVectorValues ( FIELD_NAME );
93- KnnVectorValues .DocIndexIterator iterator = floatVectorValues .iterator ();
89+ KnnVectorValues vectorValues = vectorProvider . vectorValues ( leafReaderContext .reader ());
90+ KnnVectorValues .DocIndexIterator iterator = vectorValues .iterator ();
9491 while (iterator .nextDoc () != NO_MORE_DOCS ) {
95- float [] vector = floatVectorValues . vectorValue ( iterator .index ());
96- float score = VectorSimilarityFunction . COSINE . compare (queryVector , vector );
92+ VectorData vectorData = vectorProvider . dataVectorForDoc ( vectorValues , iterator .docID ());
93+ float score = vectorProvider . score (queryVector , vectorData );
9794 topK .add (score );
9895 int docId = iterator .docID ();
96+ // If the doc has been retrieved from the RescoreKnnVectorQuery, check the score is the same and remove it
97+ // to ensure we found them all
9998 if (rescoredDocs .containsKey (docId )) {
10099 assertThat (rescoredDocs .get (docId ), equalTo (score ));
101100 rescoredDocs .remove (docId );
@@ -106,7 +105,7 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception {
106105 assertThat (rescoredDocs .size (), equalTo (0 ));
107106
108107 // Check top scoring docs are contained in rescored docs
109- for (int i = 0 ; i < k ; i ++) {
108+ for (int i = 0 ; i < adjustedK ; i ++) {
110109 Float topScore = topK .poll ();
111110 if (rescoredScores .contains (topScore ) == false ) {
112111 fail ("Top score " + topScore + " not contained in rescored doc scores " + rescoredScores );
@@ -116,12 +115,154 @@ private void testRescoreDocs(int numDocs, Integer k) throws Exception {
116115 }
117116 }
118117
119- private static float [] randomVector (int numDims ) {
120- float [] vector = new float [numDims ];
121- for (int j = 0 ; j < numDims ; j ++) {
122- vector [j ] = randomFloatBetween (0 , 1 , true );
118+ private interface VectorProvider {
119+ VectorData randomVector (int numDimensions );
120+
121+ RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k );
122+
123+ KnnVectorValues vectorValues (LeafReader leafReader ) throws IOException ;
124+
125+ void addVectorField (Document document , VectorData vector );
126+
127+ VectorData dataVectorForDoc (KnnVectorValues vectorValues , int docId ) throws IOException ;
128+
129+ float score (VectorData queryVector , VectorData dataVector );
130+ }
131+
132+ private static class FloatVectorProvider implements VectorProvider {
133+ @ Override
134+ public VectorData randomVector (int numDimensions ) {
135+ float [] vector = new float [numDimensions ];
136+ for (int j = 0 ; j < numDimensions ; j ++) {
137+ vector [j ] = randomFloatBetween (0 , 1 , true );
138+ }
139+ return VectorData .fromFloats (vector );
140+ }
141+
142+ @ Override
143+ public RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k ) {
144+ return new RescoreKnnVectorQuery (
145+ FIELD_NAME ,
146+ queryVector .floatVector (),
147+ VectorSimilarityFunction .COSINE ,
148+ k ,
149+ new MatchAllDocsQuery ()
150+ );
151+ }
152+
153+ @ Override
154+ public KnnVectorValues vectorValues (LeafReader leafReader ) throws IOException {
155+ return leafReader .getFloatVectorValues (FIELD_NAME );
156+ }
157+
158+ @ Override
159+ public void addVectorField (Document document , VectorData vector ) {
160+ KnnFloatVectorField vectorField = new KnnFloatVectorField (FIELD_NAME , vector .floatVector ());
161+ document .add (vectorField );
162+ }
163+
164+ @ Override
165+ public VectorData dataVectorForDoc (KnnVectorValues vectorValues , int docId ) throws IOException {
166+ return VectorData .fromFloats (((FloatVectorValues )vectorValues ).vectorValue (docId ));
167+ }
168+
169+ @ Override
170+ public float score (VectorData queryVector , VectorData dataVector ) {
171+ return VectorSimilarityFunction .COSINE .compare (queryVector .floatVector (), dataVector .floatVector ());
123172 }
124- return vector ;
125173 }
126174
175+ private static class ByteVectorProvider implements VectorProvider {
176+ @ Override
177+ public VectorData randomVector (int numDimensions ) {
178+ byte [] vector = new byte [numDimensions ];
179+ for (int j = 0 ; j < numDimensions ; j ++) {
180+ vector [j ] = randomByte ();
181+ }
182+ return VectorData .fromBytes (vector );
183+ }
184+
185+ @ Override
186+ public RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k ) {
187+ return new RescoreKnnVectorQuery (
188+ FIELD_NAME ,
189+ queryVector .byteVector (),
190+ VectorSimilarityFunction .COSINE ,
191+ k ,
192+ new MatchAllDocsQuery ()
193+ );
194+ }
195+
196+ @ Override
197+ public KnnVectorValues vectorValues (LeafReader leafReader ) throws IOException {
198+ return leafReader .getByteVectorValues (FIELD_NAME );
199+ }
200+
201+ @ Override
202+ public void addVectorField (Document document , VectorData vector ) {
203+ KnnByteVectorField vectorField = new KnnByteVectorField (FIELD_NAME , vector .byteVector ());
204+ document .add (vectorField );
205+ }
206+
207+ @ Override
208+ public VectorData dataVectorForDoc (KnnVectorValues vectorValues , int docId ) throws IOException {
209+ return VectorData .fromBytes (((ByteVectorValues )vectorValues ).vectorValue (docId ));
210+ }
211+
212+ @ Override
213+ public float score (VectorData queryVector , VectorData dataVector ) {
214+ return VectorSimilarityFunction .COSINE .compare (queryVector .byteVector (), dataVector .byteVector ());
215+ }
216+ }
217+
218+ private static void addRandomDocuments (int numDocs , Directory d , int numDims , VectorProvider vectorProvider ) throws IOException {
219+ try (IndexWriter w = new IndexWriter (d , newIndexWriterConfig ())) {
220+ for (int i = 0 ; i < numDocs ; i ++) {
221+ Document document = new Document ();
222+ VectorData vector = vectorProvider .randomVector (numDims );
223+ vectorProvider .addVectorField (document , vector );
224+ w .addDocument (document );
225+ }
226+ w .commit ();
227+ w .forceMerge (1 );
228+ }
229+ }
230+
231+ @ ParametersFactory
232+ public static Iterable <Object []> parameters () {
233+
234+ List <Object []> params = new ArrayList <>();
235+ params .add (new Object [] {new FloatVectorProvider (), true });
236+ params .add (new Object [] {new FloatVectorProvider (), false });
237+ params .add (new Object [] {new ByteVectorProvider (), true });
238+ params .add (new Object [] {new ByteVectorProvider (), false });
239+
240+ return params ;
241+ }
242+
243+ // public void testProfiling() throws Exception {
244+ // int numDocs = randomIntBetween(10, 100);
245+ // int numDims = randomIntBetween(5, 100);
246+ //
247+ // try (Directory d = newDirectory()) {
248+ // addRandomDocuments(numDocs, d, numDims, vectorProvider);
249+ //
250+ // try (IndexReader reader = DirectoryReader.open(d)) {
251+ // float[] queryVector = randomVector(numDims);
252+ //
253+ // RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
254+ // FIELD_NAME,
255+ // queryVector,
256+ // VectorSimilarityFunction.COSINE,
257+ // randomIntBetween(5, numDocs - 1),
258+ // new MatchAllDocsQuery()
259+ // );
260+ //
261+ // IndexSearcher searcher = newSearcher(reader, true, false);
262+ // QueryProfiler queryProfiler = new QueryProfiler();
263+ // rescoreKnnVectorQuery.profile(queryProfiler);
264+ // }
265+ // }
266+ // }
267+
127268}
0 commit comments