4242import java .util .ArrayList ;
4343import java .util .List ;
4444
45+ import static org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .OVERSAMPLE_LIMIT ;
4546import static org .elasticsearch .search .SearchService .DEFAULT_SIZE ;
4647import static org .hamcrest .Matchers .containsString ;
4748import static org .hamcrest .Matchers .equalTo ;
@@ -56,7 +57,13 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
5657
5758 abstract DenseVectorFieldMapper .ElementType elementType ();
5859
59- abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder (String fieldName , Integer k , int numCands , Float similarity );
60+ abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder (
61+ String fieldName ,
62+ Integer k ,
63+ int numCands ,
64+ RescoreVectorBuilder rescoreVectorBuilder ,
65+ Float similarity
66+ );
6067
6168 @ Override
6269 protected void initializeAdditionalMappings (MapperService mapperService ) throws IOException {
@@ -88,7 +95,13 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() {
8895 String fieldName = randomBoolean () ? VECTOR_FIELD : VECTOR_ALIAS_FIELD ;
8996 Integer k = randomBoolean () ? null : randomIntBetween (1 , 100 );
9097 int numCands = randomIntBetween (k == null ? DEFAULT_SIZE : k + 20 , 1000 );
91- KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder (fieldName , k , numCands , randomFloat ());
98+ KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder (
99+ fieldName ,
100+ k ,
101+ numCands ,
102+ randomRescoreVectorBuilder (),
103+ randomFloat ()
104+ );
92105
93106 if (randomBoolean ()) {
94107 List <QueryBuilder > filters = new ArrayList <>();
@@ -99,11 +112,24 @@ protected KnnVectorQueryBuilder doCreateTestQueryBuilder() {
99112 }
100113 queryBuilder .addFilterQueries (filters );
101114 }
115+
102116 return queryBuilder ;
103117 }
104118
119+ protected RescoreVectorBuilder randomRescoreVectorBuilder () {
120+ if (randomBoolean ()) {
121+ return null ;
122+ }
123+
124+ return new RescoreVectorBuilder (randomFloatBetween (1.0f , 10.0f , false ));
125+ }
126+
105127 @ Override
106128 protected void doAssertLuceneQuery (KnnVectorQueryBuilder queryBuilder , Query query , SearchExecutionContext context ) throws IOException {
129+ if (queryBuilder .rescoreVectorBuilder () != null ) {
130+ assertTrue (query instanceof org .apache .lucene .queries .function .FunctionScoreQuery );
131+ query = ((org .apache .lucene .queries .function .FunctionScoreQuery ) query ).getWrappedQuery ();
132+ }
107133 if (queryBuilder .getVectorSimilarity () != null ) {
108134 assertTrue (query instanceof VectorSimilarityQuery );
109135 Query knnQuery = ((VectorSimilarityQuery ) query ).getInnerKnnQuery ();
@@ -126,21 +152,17 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
126152 BooleanQuery booleanQuery = builder .build ();
127153 Query filterQuery = booleanQuery .clauses ().isEmpty () ? null : booleanQuery ;
128154 // The field should always be resolved to the concrete field
155+ Integer k = queryBuilder .k ();
156+ Integer numCands = queryBuilder .numCands ();
157+ if (queryBuilder .rescoreVectorBuilder () != null ) {
158+ Float rescoreOversample = queryBuilder .rescoreVectorBuilder ().oversample ();
159+ k = k == null ? null : Integer .valueOf (Math .min (OVERSAMPLE_LIMIT , (int ) Math .ceil (k * rescoreOversample )));
160+ numCands = numCands == null ? null : Math .max (k == null ? 0 : k , numCands );
161+ }
162+
129163 Query knnVectorQueryBuilt = switch (elementType ()) {
130- case BYTE , BIT -> new ESKnnByteVectorQuery (
131- VECTOR_FIELD ,
132- queryBuilder .queryVector ().asByteVector (),
133- queryBuilder .k (),
134- queryBuilder .numCands (),
135- filterQuery
136- );
137- case FLOAT -> new ESKnnFloatVectorQuery (
138- VECTOR_FIELD ,
139- queryBuilder .queryVector ().asFloatVector (),
140- queryBuilder .k (),
141- queryBuilder .numCands (),
142- filterQuery
143- );
164+ case BYTE , BIT -> new ESKnnByteVectorQuery (VECTOR_FIELD , queryBuilder .queryVector ().asByteVector (), k , numCands , filterQuery );
165+ case FLOAT -> new ESKnnFloatVectorQuery (VECTOR_FIELD , queryBuilder .queryVector ().asFloatVector (), k , numCands , filterQuery );
144166 };
145167 if (query instanceof VectorSimilarityQuery vectorSimilarityQuery ) {
146168 query = vectorSimilarityQuery .getInnerKnnQuery ();
@@ -150,7 +172,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
150172
151173 public void testWrongDimension () {
152174 SearchExecutionContext context = createSearchExecutionContext ();
153- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 2.0f }, 5 , 10 , null );
175+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 2.0f }, 5 , 10 , null , null );
154176 IllegalArgumentException e = expectThrows (IllegalArgumentException .class , () -> query .doToQuery (context ));
155177 assertThat (
156178 e .getMessage (),
@@ -160,15 +182,15 @@ public void testWrongDimension() {
160182
161183 public void testNonexistentField () {
162184 SearchExecutionContext context = createSearchExecutionContext ();
163- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder ("nonexistent" , new float [] { 1.0f , 1.0f , 1.0f }, 5 , 10 , null );
185+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder ("nonexistent" , new float [] { 1.0f , 1.0f , 1.0f }, 5 , 10 , null , null );
164186 context .setAllowUnmappedFields (false );
165187 QueryShardException e = expectThrows (QueryShardException .class , () -> query .doToQuery (context ));
166188 assertThat (e .getMessage (), containsString ("No field mapping can be found for the field with name [nonexistent]" ));
167189 }
168190
169191 public void testNonexistentFieldReturnEmpty () throws IOException {
170192 SearchExecutionContext context = createSearchExecutionContext ();
171- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder ("nonexistent" , new float [] { 1.0f , 1.0f , 1.0f }, 5 , 10 , null );
193+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder ("nonexistent" , new float [] { 1.0f , 1.0f , 1.0f }, 5 , 10 , null , null );
172194 Query queryNone = query .doToQuery (context );
173195 assertThat (queryNone , instanceOf (MatchNoDocsQuery .class ));
174196 }
@@ -180,6 +202,7 @@ public void testWrongFieldType() {
180202 new float [] { 1.0f , 1.0f , 1.0f },
181203 5 ,
182204 10 ,
205+ null ,
183206 null
184207 );
185208 IllegalArgumentException e = expectThrows (IllegalArgumentException .class , () -> query .doToQuery (context ));
@@ -191,14 +214,14 @@ public void testNumCandsLessThanK() {
191214 int numCands = 3 ;
192215 IllegalArgumentException e = expectThrows (
193216 IllegalArgumentException .class ,
194- () -> new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 1.0f , 1.0f }, k , numCands , null )
217+ () -> new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 1.0f , 1.0f }, k , numCands , null , null )
195218 );
196219 assertThat (e .getMessage (), containsString ("[num_candidates] cannot be less than [k]" ));
197220 }
198221
199222 @ Override
200223 public void testValidOutput () {
201- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 2.0f , 3.0f }, null , 10 , null );
224+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 2.0f , 3.0f }, null , 10 , null , null );
202225 String expected = """
203226 {
204227 "knn" : {
@@ -213,7 +236,7 @@ public void testValidOutput() {
213236 }""" ;
214237 assertEquals (expected , query .toString ());
215238
216- KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 2.0f , 3.0f }, 5 , 10 , null );
239+ KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder (VECTOR_FIELD , new float [] { 1.0f , 2.0f , 3.0f }, 5 , 10 , null , null );
217240 String expected2 = """
218241 {
219242 "knn" : {
@@ -240,6 +263,7 @@ public void testMustRewrite() throws IOException {
240263 new float [] { 1.0f , 2.0f , 3.0f },
241264 VECTOR_DIMENSION ,
242265 null ,
266+ null ,
243267 null
244268 );
245269 query .addFilterQuery (termQuery );
@@ -254,9 +278,14 @@ public void testMustRewrite() throws IOException {
254278 public void testBWCVersionSerializationFilters () throws IOException {
255279 KnnVectorQueryBuilder query = createTestQueryBuilder ();
256280 VectorData vectorData = VectorData .fromFloats (query .queryVector ().asFloatVector ());
257- KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder (query .getFieldName (), vectorData , null , query .numCands (), null )
258- .queryName (query .queryName ())
259- .boost (query .boost ());
281+ KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder (
282+ query .getFieldName (),
283+ vectorData ,
284+ null ,
285+ query .numCands (),
286+ null ,
287+ null
288+ ).queryName (query .queryName ()).boost (query .boost ());
260289 TransportVersion beforeFilterVersion = TransportVersionUtils .randomVersionBetween (
261290 random (),
262291 TransportVersions .V_8_0_0 ,
@@ -268,10 +297,14 @@ public void testBWCVersionSerializationFilters() throws IOException {
268297 public void testBWCVersionSerializationSimilarity () throws IOException {
269298 KnnVectorQueryBuilder query = createTestQueryBuilder ();
270299 VectorData vectorData = VectorData .fromFloats (query .queryVector ().asFloatVector ());
271- KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder (query .getFieldName (), vectorData , null , query .numCands (), null )
272- .queryName (query .queryName ())
273- .boost (query .boost ())
274- .addFilterQueries (query .filterQueries ());
300+ KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder (
301+ query .getFieldName (),
302+ vectorData ,
303+ null ,
304+ query .numCands (),
305+ null ,
306+ null
307+ ).queryName (query .queryName ()).boost (query .boost ()).addFilterQueries (query .filterQueries ());
275308 assertBWCSerialization (query , queryNoSimilarity , TransportVersions .V_8_7_0 );
276309 }
277310
@@ -289,11 +322,29 @@ public void testBWCVersionSerializationQuery() throws IOException {
289322 vectorData ,
290323 null ,
291324 query .numCands (),
325+ null ,
292326 similarity
293327 ).queryName (query .queryName ()).boost (query .boost ()).addFilterQueries (query .filterQueries ());
294328 assertBWCSerialization (query , queryOlderVersion , differentQueryVersion );
295329 }
296330
331+ public void testBWCVersionSerializationRescoreVector () throws IOException {
332+ KnnVectorQueryBuilder query = createTestQueryBuilder ();
333+ KnnVectorQueryBuilder queryNoRescoreVector = new KnnVectorQueryBuilder (
334+ query .getFieldName (),
335+ query .queryVector (),
336+ query .k (),
337+ query .numCands (),
338+ null ,
339+ query .getVectorSimilarity ()
340+ ).queryName (query .queryName ()).boost (query .boost ()).addFilterQueries (query .filterQueries ());
341+ assertBWCSerialization (
342+ query ,
343+ queryNoRescoreVector ,
344+ TransportVersionUtils .randomVersionBetween (random (), TransportVersions .V_8_8_0 , TransportVersions .KNN_QUERY_RESCORE_OVERSAMPLE )
345+ );
346+ }
347+
297348 private void assertBWCSerialization (QueryBuilder newQuery , QueryBuilder bwcQuery , TransportVersion version ) throws IOException {
298349 assertSerialization (bwcQuery , version );
299350 try (BytesStreamOutput output = new BytesStreamOutput ()) {
0 commit comments