@@ -56,6 +56,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
5656 public static final ParseField NAME_FIELD = AbstractQueryBuilder .NAME_FIELD ;
5757 public static final ParseField BOOST_FIELD = AbstractQueryBuilder .BOOST_FIELD ;
5858 public static final ParseField INNER_HITS_FIELD = new ParseField ("inner_hits" );
59+ public static final ParseField RESCORE_FIELD = new ParseField ("rescore" );
5960
6061 @ SuppressWarnings ("unchecked" )
6162 private static final ConstructingObjectParser <KnnSearchBuilder .Builder , Void > PARSER = new ConstructingObjectParser <>("knn" , args -> {
@@ -65,7 +66,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
6566 .queryVectorBuilder ((QueryVectorBuilder ) args [4 ])
6667 .k ((Integer ) args [2 ])
6768 .numCandidates ((Integer ) args [3 ])
68- .similarity ((Float ) args [5 ]);
69+ .similarity ((Float ) args [5 ])
70+ .rescoreVectorBuilder ((RescoreVectorBuilder ) args [6 ]);
6971 });
7072
7173 static {
@@ -78,13 +80,18 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
7880 );
7981 PARSER .declareInt (optionalConstructorArg (), K_FIELD );
8082 PARSER .declareInt (optionalConstructorArg (), NUM_CANDS_FIELD );
81-
8283 PARSER .declareNamedObject (
8384 optionalConstructorArg (),
8485 (p , c , n ) -> p .namedObject (QueryVectorBuilder .class , n , c ),
8586 QUERY_VECTOR_BUILDER_FIELD
8687 );
8788 PARSER .declareFloat (optionalConstructorArg (), VECTOR_SIMILARITY );
89+ PARSER .declareField (
90+ optionalConstructorArg (),
91+ (p , c ) -> RescoreVectorBuilder .fromXContent (p ),
92+ RESCORE_FIELD ,
93+ ObjectParser .ValueType .OBJECT_OR_NULL
94+ );
8895 PARSER .declareFieldArray (
8996 KnnSearchBuilder .Builder ::addFilterQueries ,
9097 (p , c ) -> AbstractQueryBuilder .parseTopLevelQuery (p ),
@@ -116,6 +123,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw
116123 String queryName ;
117124 float boost = DEFAULT_BOOST ;
118125 InnerHitBuilder innerHitBuilder ;
126+ final RescoreVectorBuilder rescoreVectorBuilder ;
119127
120128 /**
121129 * Defines a kNN search.
@@ -124,14 +132,23 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw
124132 * @param queryVector the query vector
125133 * @param k the final number of nearest neighbors to return as top hits
126134 * @param numCands the number of nearest neighbor candidates to consider per shard
135+ * @param rescoreVectorBuilder rescore vector information
127136 */
128- public KnnSearchBuilder (String field , float [] queryVector , int k , int numCands , Float similarity ) {
137+ public KnnSearchBuilder (
138+ String field ,
139+ float [] queryVector ,
140+ int k ,
141+ int numCands ,
142+ RescoreVectorBuilder rescoreVectorBuilder ,
143+ Float similarity
144+ ) {
129145 this (
130146 field ,
131147 Objects .requireNonNull (VectorData .fromFloats (queryVector ), format ("[%s] cannot be null" , QUERY_VECTOR_FIELD )),
132148 null ,
133149 k ,
134150 numCands ,
151+ rescoreVectorBuilder ,
135152 similarity
136153 );
137154 }
@@ -144,8 +161,15 @@ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands,
144161 * @param k the final number of nearest neighbors to return as top hits
145162 * @param numCands the number of nearest neighbor candidates to consider per shard
146163 */
147- public KnnSearchBuilder (String field , VectorData queryVector , int k , int numCands , Float similarity ) {
148- this (field , queryVector , null , k , numCands , similarity );
164+ public KnnSearchBuilder (
165+ String field ,
166+ VectorData queryVector ,
167+ int k ,
168+ int numCands ,
169+ RescoreVectorBuilder rescoreVectorBuilder ,
170+ Float similarity
171+ ) {
172+ this (field , queryVector , null , k , numCands , rescoreVectorBuilder , similarity );
149173 }
150174
151175 /**
@@ -156,13 +180,21 @@ public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCand
156180 * @param k the final number of nearest neighbors to return as top hits
157181 * @param numCands the number of nearest neighbor candidates to consider per shard
158182 */
159- public KnnSearchBuilder (String field , QueryVectorBuilder queryVectorBuilder , int k , int numCands , Float similarity ) {
183+ public KnnSearchBuilder (
184+ String field ,
185+ QueryVectorBuilder queryVectorBuilder ,
186+ int k ,
187+ int numCands ,
188+ RescoreVectorBuilder rescoreVectorBuilder ,
189+ Float similarity
190+ ) {
160191 this (
161192 field ,
162193 null ,
163194 Objects .requireNonNull (queryVectorBuilder , format ("[%s] cannot be null" , QUERY_VECTOR_BUILDER_FIELD .getPreferredName ())),
164195 k ,
165196 numCands ,
197+ rescoreVectorBuilder ,
166198 similarity
167199 );
168200 }
@@ -173,16 +205,30 @@ public KnnSearchBuilder(
173205 QueryVectorBuilder queryVectorBuilder ,
174206 int k ,
175207 int numCands ,
208+ RescoreVectorBuilder rescoreVectorBuilder ,
176209 Float similarity
177210 ) {
178- this (field , queryVectorBuilder , queryVector , new ArrayList <>(), k , numCands , similarity , null , null , DEFAULT_BOOST );
211+ this (
212+ field ,
213+ queryVectorBuilder ,
214+ queryVector ,
215+ new ArrayList <>(),
216+ k ,
217+ numCands ,
218+ rescoreVectorBuilder ,
219+ similarity ,
220+ null ,
221+ null ,
222+ DEFAULT_BOOST
223+ );
179224 }
180225
181226 private KnnSearchBuilder (
182227 String field ,
183228 Supplier <float []> querySupplier ,
184229 Integer k ,
185230 Integer numCands ,
231+ RescoreVectorBuilder rescoreVectorBuilder ,
186232 List <QueryBuilder > filterQueries ,
187233 Float similarity
188234 ) {
@@ -194,6 +240,7 @@ private KnnSearchBuilder(
194240 this .filterQueries = filterQueries ;
195241 this .querySupplier = querySupplier ;
196242 this .similarity = similarity ;
243+ this .rescoreVectorBuilder = rescoreVectorBuilder ;
197244 }
198245
199246 private KnnSearchBuilder (
@@ -203,6 +250,7 @@ private KnnSearchBuilder(
203250 List <QueryBuilder > filterQueries ,
204251 int k ,
205252 int numCandidates ,
253+ RescoreVectorBuilder rescoreVectorBuilder ,
206254 Float similarity ,
207255 InnerHitBuilder innerHitBuilder ,
208256 String queryName ,
@@ -242,6 +290,7 @@ private KnnSearchBuilder(
242290 this .queryVectorBuilder = queryVectorBuilder ;
243291 this .k = k ;
244292 this .numCands = numCandidates ;
293+ this .rescoreVectorBuilder = rescoreVectorBuilder ;
245294 this .innerHitBuilder = innerHitBuilder ;
246295 this .similarity = similarity ;
247296 this .queryName = queryName ;
@@ -280,6 +329,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException {
280329 if (in .getTransportVersion ().onOrAfter (V_8_11_X )) {
281330 this .innerHitBuilder = in .readOptionalWriteable (InnerHitBuilder ::new );
282331 }
332+ if (in .getTransportVersion ().onOrAfter (TransportVersions .KNN_QUERY_RESCORE_OVERSAMPLE )) {
333+ this .rescoreVectorBuilder = in .readOptional (RescoreVectorBuilder ::new );
334+ } else {
335+ this .rescoreVectorBuilder = null ;
336+ }
283337 }
284338
285339 public int k () {
@@ -290,6 +344,10 @@ public int getNumCands() {
290344 return numCands ;
291345 }
292346
347+ public RescoreVectorBuilder getRescoreVectorBuilder () {
348+ return rescoreVectorBuilder ;
349+ }
350+
293351 public QueryVectorBuilder getQueryVectorBuilder () {
294352 return queryVectorBuilder ;
295353 }
@@ -358,7 +416,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
358416 if (querySupplier .get () == null ) {
359417 return this ;
360418 }
361- return new KnnSearchBuilder (field , querySupplier .get (), k , numCands , similarity ).boost (boost )
419+ return new KnnSearchBuilder (field , querySupplier .get (), k , numCands , rescoreVectorBuilder , similarity ).boost (boost )
362420 .queryName (queryName )
363421 .addFilterQueries (filterQueries )
364422 .innerHit (innerHitBuilder );
@@ -381,7 +439,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
381439 }
382440 ll .onResponse (null );
383441 })));
384- return new KnnSearchBuilder (field , toSet ::get , k , numCands , filterQueries , similarity ).boost (boost )
442+ return new KnnSearchBuilder (field , toSet ::get , k , numCands , rescoreVectorBuilder , filterQueries , similarity ).boost (boost )
385443 .queryName (queryName )
386444 .innerHit (innerHitBuilder );
387445 }
@@ -395,7 +453,7 @@ public KnnSearchBuilder rewrite(QueryRewriteContext ctx) throws IOException {
395453 rewrittenQueries .add (rewrittenQuery );
396454 }
397455 if (changed ) {
398- return new KnnSearchBuilder (field , queryVector , k , numCands , similarity ).boost (boost )
456+ return new KnnSearchBuilder (field , queryVector , k , numCands , rescoreVectorBuilder , similarity ).boost (boost )
399457 .queryName (queryName )
400458 .addFilterQueries (rewrittenQueries )
401459 .innerHit (innerHitBuilder );
@@ -407,7 +465,7 @@ public KnnVectorQueryBuilder toQueryBuilder() {
407465 if (queryVectorBuilder != null ) {
408466 throw new IllegalArgumentException ("missing rewrite" );
409467 }
410- return new KnnVectorQueryBuilder (field , queryVector , null , numCands , null , similarity ).boost (boost )
468+ return new KnnVectorQueryBuilder (field , queryVector , null , numCands , rescoreVectorBuilder , similarity ).boost (boost )
411469 .queryName (queryName )
412470 .addFilterQueries (filterQueries );
413471 }
@@ -423,6 +481,7 @@ public boolean equals(Object o) {
423481 KnnSearchBuilder that = (KnnSearchBuilder ) o ;
424482 return k == that .k
425483 && numCands == that .numCands
484+ && Objects .equals (rescoreVectorBuilder , that .rescoreVectorBuilder )
426485 && Objects .equals (field , that .field )
427486 && Objects .equals (queryVector , that .queryVector )
428487 && Objects .equals (queryVectorBuilder , that .queryVectorBuilder )
@@ -442,6 +501,7 @@ public int hashCode() {
442501 numCands ,
443502 querySupplier ,
444503 queryVectorBuilder ,
504+ rescoreVectorBuilder ,
445505 similarity ,
446506 Objects .hashCode (queryVector ),
447507 Objects .hashCode (filterQueries ),
@@ -486,6 +546,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
486546 if (queryName != null ) {
487547 builder .field (NAME_FIELD .getPreferredName (), queryName );
488548 }
549+ if (rescoreVectorBuilder != null ) {
550+ builder .startObject (RESCORE_FIELD .getPreferredName ());
551+ rescoreVectorBuilder .toXContent (builder , params );
552+ builder .endObject ();
553+ }
489554
490555 return builder ;
491556 }
@@ -526,6 +591,9 @@ public void writeTo(StreamOutput out) throws IOException {
526591 if (out .getTransportVersion ().onOrAfter (V_8_11_X )) {
527592 out .writeOptionalWriteable (innerHitBuilder );
528593 }
594+ if (out .getTransportVersion ().onOrAfter (TransportVersions .KNN_QUERY_RESCORE_OVERSAMPLE )) {
595+ out .writeOptionalWriteable (rescoreVectorBuilder );
596+ }
529597 }
530598
531599 public static class Builder {
@@ -540,6 +608,7 @@ public static class Builder {
540608 private String queryName ;
541609 private float boost = DEFAULT_BOOST ;
542610 private InnerHitBuilder innerHitBuilder ;
611+ private RescoreVectorBuilder rescoreVectorBuilder ;
543612
544613 public Builder addFilterQueries (List <QueryBuilder > filterQueries ) {
545614 Objects .requireNonNull (filterQueries );
@@ -592,6 +661,11 @@ public Builder similarity(Float similarity) {
592661 return this ;
593662 }
594663
664+ public Builder rescoreVectorBuilder (RescoreVectorBuilder rescoreVectorBuilder ) {
665+ this .rescoreVectorBuilder = rescoreVectorBuilder ;
666+ return this ;
667+ }
668+
595669 public KnnSearchBuilder build (int size ) {
596670 int requestSize = size < 0 ? DEFAULT_SIZE : size ;
597671 int adjustedK = k == null ? requestSize : k ;
@@ -605,6 +679,7 @@ public KnnSearchBuilder build(int size) {
605679 filterQueries ,
606680 adjustedK ,
607681 adjustedNumCandidates ,
682+ rescoreVectorBuilder ,
608683 similarity ,
609684 innerHitBuilder ,
610685 queryName ,
0 commit comments