2525import org .apache .lucene .index .VectorSimilarityFunction ;
2626import org .apache .lucene .search .IndexSearcher ;
2727import org .apache .lucene .search .MatchAllDocsQuery ;
28+ import org .apache .lucene .search .Query ;
29+ import org .apache .lucene .search .QueryVisitor ;
30+ import org .apache .lucene .search .ScoreMode ;
2831import org .apache .lucene .search .TopDocs ;
32+ import org .apache .lucene .search .Weight ;
2933import org .apache .lucene .store .Directory ;
34+ import org .elasticsearch .search .profile .query .QueryProfiler ;
3035import org .elasticsearch .test .ESTestCase ;
3136
3237import java .io .IOException ;
4146
4247import static org .apache .lucene .search .DocIdSetIterator .NO_MORE_DOCS ;
4348import static org .hamcrest .Matchers .equalTo ;
49+ import static org .hamcrest .Matchers .greaterThan ;
4450
4551public class RescoreKnnVectorQueryTests extends ESTestCase {
4652
@@ -51,7 +57,8 @@ public class RescoreKnnVectorQueryTests extends ESTestCase {
5157
5258 public RescoreKnnVectorQueryTests (VectorProvider vectorProvider , boolean useK ) {
5359 this .vectorProvider = vectorProvider ;
54- this .numDocs = randomIntBetween (10 , 100 );;
60+ this .numDocs = randomIntBetween (10 , 100 );
61+ ;
5562 this .k = useK ? randomIntBetween (1 , numDocs - 1 ) : null ;
5663 }
5764
@@ -71,7 +78,11 @@ public void testRescoreDocs() throws Exception {
7178 // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
7279 // and thus we're rescoring the top k docs.
7380 VectorData queryVector = vectorProvider .randomVector (numDims );
74- RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider .createRescoreQuery (queryVector , adjustedK );
81+ RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider .createRescoreQuery (
82+ queryVector ,
83+ adjustedK ,
84+ new MatchAllDocsQuery ()
85+ );
7586
7687 IndexSearcher searcher = newSearcher (reader , true , false );
7788 TopDocs docs = searcher .search (rescoreKnnVectorQuery , numDocs );
@@ -115,10 +126,90 @@ public void testRescoreDocs() throws Exception {
115126 }
116127 }
117128
129+ public void testProfiling () throws Exception {
130+ int numDims = randomIntBetween (5 , 100 );
131+
132+ try (Directory d = newDirectory ()) {
133+ addRandomDocuments (numDocs , d , numDims , vectorProvider );
134+
135+ try (IndexReader reader = DirectoryReader .open (d )) {
136+ VectorData queryVector = vectorProvider .randomVector (numDims );
137+
138+ checkProfiling (queryVector , reader , new MatchAllDocsQuery ());
139+ checkProfiling (queryVector , reader , new MockProfilingQuery (randomIntBetween (1 , 100 )));
140+ }
141+ }
142+ }
143+
144+ private void checkProfiling (VectorData queryVector , IndexReader reader , Query innerQuery ) throws IOException {
145+ RescoreKnnVectorQuery rescoreKnnVectorQuery = vectorProvider .createRescoreQuery (queryVector , k , innerQuery );
146+ IndexSearcher searcher = newSearcher (reader , true , false );
147+ searcher .search (rescoreKnnVectorQuery , numDocs );
148+
149+ QueryProfiler queryProfiler = new QueryProfiler ();
150+ rescoreKnnVectorQuery .profile (queryProfiler );
151+
152+ long expectedVectorOpsCount = 0 ;
153+ if (k != null ) {
154+ expectedVectorOpsCount += k ;
155+ }
156+ if (innerQuery instanceof ProfilingQuery profilingQuery ) {
157+ QueryProfiler anotherProfiler = new QueryProfiler ();
158+ profilingQuery .profile (anotherProfiler );
159+ assertThat (anotherProfiler .getVectorOpsCount (), greaterThan (0L ));
160+ expectedVectorOpsCount += anotherProfiler .getVectorOpsCount ();
161+ }
162+
163+ assertThat (queryProfiler .getVectorOpsCount (), equalTo (expectedVectorOpsCount ));
164+ }
165+
166+ /**
167+ * A mock query that is used to test profiling
168+ */
169+ private static class MockProfilingQuery extends Query implements ProfilingQuery {
170+
171+ private final long vectorOpsCount ;
172+
173+ private MockProfilingQuery (long vectorOpsCount ) {
174+ this .vectorOpsCount = vectorOpsCount ;
175+ }
176+
177+ @ Override
178+ public String toString (String field ) {
179+ return "" ;
180+ }
181+
182+ @ Override
183+ public Weight createWeight (IndexSearcher searcher , ScoreMode scoreMode , float boost ) throws IOException {
184+ return new MatchAllDocsQuery ().createWeight (searcher , scoreMode , boost );
185+ }
186+
187+ @ Override
188+ public void visit (QueryVisitor visitor ) {}
189+
190+ @ Override
191+ public boolean equals (Object obj ) {
192+ return obj instanceof MockProfilingQuery ;
193+ }
194+
195+ @ Override
196+ public int hashCode () {
197+ return 0 ;
198+ }
199+
200+ @ Override
201+ public void profile (QueryProfiler queryProfiler ) {
202+ queryProfiler .addVectorOpsCount (vectorOpsCount );
203+ }
204+ }
205+
206+ /**
207+ * Vector operations depend on the type of vector field used. This interface abstracts the operations needed to perform the tests
208+ */
118209 private interface VectorProvider {
119210 VectorData randomVector (int numDimensions );
120211
121- RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k );
212+ RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k , Query innerQuery );
122213
123214 KnnVectorValues vectorValues (LeafReader leafReader ) throws IOException ;
124215
@@ -140,14 +231,8 @@ public VectorData randomVector(int numDimensions) {
140231 }
141232
142233 @ Override
143- public RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k ) {
144- return new RescoreKnnVectorQuery (
145- FIELD_NAME ,
146- queryVector .floatVector (),
147- VectorSimilarityFunction .COSINE ,
148- k ,
149- new MatchAllDocsQuery ()
150- );
234+ public RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k , Query innerQuery ) {
235+ return new RescoreKnnVectorQuery (FIELD_NAME , queryVector .floatVector (), VectorSimilarityFunction .COSINE , k , innerQuery );
151236 }
152237
153238 @ Override
@@ -163,7 +248,7 @@ public void addVectorField(Document document, VectorData vector) {
163248
164249 @ Override
165250 public VectorData dataVectorForDoc (KnnVectorValues vectorValues , int docId ) throws IOException {
166- return VectorData .fromFloats (((FloatVectorValues )vectorValues ).vectorValue (docId ));
251+ return VectorData .fromFloats (((FloatVectorValues ) vectorValues ).vectorValue (docId ));
167252 }
168253
169254 @ Override
@@ -183,14 +268,8 @@ public VectorData randomVector(int numDimensions) {
183268 }
184269
185270 @ Override
186- public RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k ) {
187- return new RescoreKnnVectorQuery (
188- FIELD_NAME ,
189- queryVector .byteVector (),
190- VectorSimilarityFunction .COSINE ,
191- k ,
192- new MatchAllDocsQuery ()
193- );
271+ public RescoreKnnVectorQuery createRescoreQuery (VectorData queryVector , Integer k , Query innerQuery ) {
272+ return new RescoreKnnVectorQuery (FIELD_NAME , queryVector .byteVector (), VectorSimilarityFunction .COSINE , k , innerQuery );
194273 }
195274
196275 @ Override
@@ -206,7 +285,7 @@ public void addVectorField(Document document, VectorData vector) {
206285
207286 @ Override
208287 public VectorData dataVectorForDoc (KnnVectorValues vectorValues , int docId ) throws IOException {
209- return VectorData .fromBytes (((ByteVectorValues )vectorValues ).vectorValue (docId ));
288+ return VectorData .fromBytes (((ByteVectorValues ) vectorValues ).vectorValue (docId ));
210289 }
211290
212291 @ Override
@@ -230,39 +309,12 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims, Ve
230309
231310 @ ParametersFactory
232311 public static Iterable <Object []> parameters () {
233-
234312 List <Object []> params = new ArrayList <>();
235- params .add (new Object [] {new FloatVectorProvider (), true });
236- params .add (new Object [] {new FloatVectorProvider (), false });
237- params .add (new Object [] {new ByteVectorProvider (), true });
238- params .add (new Object [] {new ByteVectorProvider (), false });
313+ params .add (new Object [] { new FloatVectorProvider (), true });
314+ params .add (new Object [] { new FloatVectorProvider (), false });
315+ params .add (new Object [] { new ByteVectorProvider (), true });
316+ params .add (new Object [] { new ByteVectorProvider (), false });
239317
240318 return params ;
241319 }
242-
243- // public void testProfiling() throws Exception {
244- // int numDocs = randomIntBetween(10, 100);
245- // int numDims = randomIntBetween(5, 100);
246- //
247- // try (Directory d = newDirectory()) {
248- // addRandomDocuments(numDocs, d, numDims, vectorProvider);
249- //
250- // try (IndexReader reader = DirectoryReader.open(d)) {
251- // float[] queryVector = randomVector(numDims);
252- //
253- // RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
254- // FIELD_NAME,
255- // queryVector,
256- // VectorSimilarityFunction.COSINE,
257- // randomIntBetween(5, numDocs - 1),
258- // new MatchAllDocsQuery()
259- // );
260- //
261- // IndexSearcher searcher = newSearcher(reader, true, false);
262- // QueryProfiler queryProfiler = new QueryProfiler();
263- // rescoreKnnVectorQuery.profile(queryProfiler);
264- // }
265- // }
266- // }
267-
268320}
0 commit comments