@@ -67,9 +67,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
6767 public static final ParseField NUM_CANDS_FIELD = new ParseField ("num_candidates" );
6868 public static final ParseField QUERY_VECTOR_FIELD = new ParseField ("query_vector" );
6969 public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField ("similarity" );
70- public static final ParseField RESCORE_VECTOR_OVERSAMPLE = new ParseField ("rescore_vector_oversample" );
7170 public static final ParseField FILTER_FIELD = new ParseField ("filter" );
7271 public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField ("query_vector_builder" );
72+ public static final ParseField RESCORE_FIELD = new ParseField ("rescore" );
7373
7474 public static final ConstructingObjectParser <KnnVectorQueryBuilder , Void > PARSER = new ConstructingObjectParser <>(
7575 "knn" ,
@@ -80,8 +80,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
8080 null ,
8181 (Integer ) args [2 ],
8282 (Integer ) args [3 ],
83- (Float ) args [4 ],
84- (Float ) args [6 ]
83+ (RescoreVectorBuilder ) args [6 ],
84+ (Float ) args [4 ]
8585 )
8686 );
8787
@@ -101,7 +101,12 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
101101 (p , c , n ) -> p .namedObject (QueryVectorBuilder .class , n , c ),
102102 QUERY_VECTOR_BUILDER_FIELD
103103 );
104- PARSER .declareFloat (optionalConstructorArg (), RESCORE_VECTOR_OVERSAMPLE );
104+ PARSER .declareField (
105+ optionalConstructorArg (),
106+ (p , c ) -> RescoreVectorBuilder .fromXContent (p ),
107+ RESCORE_FIELD ,
108+ ObjectParser .ValueType .OBJECT_OR_NULL
109+ );
105110 PARSER .declareFieldArray (
106111 KnnVectorQueryBuilder ::addFilterQueries ,
107112 (p , c ) -> AbstractQueryBuilder .parseTopLevelQuery (p ),
@@ -123,10 +128,17 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
123128 private final Float vectorSimilarity ;
124129 private final QueryVectorBuilder queryVectorBuilder ;
125130 private final Supplier <float []> queryVectorSupplier ;
126- private final Float rescoreOversample ;
131+ private final RescoreVectorBuilder rescoreVectorBuilder ;
127132
128- public KnnVectorQueryBuilder (String fieldName , float [] queryVector , Integer k , Integer numCands , Float vectorSimilarity ) {
129- this (fieldName , VectorData .fromFloats (queryVector ), null , null , k , numCands , vectorSimilarity , null );
133+ public KnnVectorQueryBuilder (
134+ String fieldName ,
135+ float [] queryVector ,
136+ Integer k ,
137+ Integer numCands ,
138+ RescoreVectorBuilder rescoreVectorBuilder ,
139+ Float vectorSimilarity
140+ ) {
141+ this (fieldName , VectorData .fromFloats (queryVector ), null , null , k , numCands , rescoreVectorBuilder , vectorSimilarity );
130142 }
131143
132144 public KnnVectorQueryBuilder (
@@ -136,27 +148,29 @@ public KnnVectorQueryBuilder(
136148 Integer numCands ,
137149 Float vectorSimilarity
138150 ) {
139- this (fieldName , null , queryVectorBuilder , null , k , numCands , vectorSimilarity , null );
151+ this (fieldName , null , queryVectorBuilder , null , k , numCands , null , vectorSimilarity );
140152 }
141153
142- public KnnVectorQueryBuilder (String fieldName , byte [] queryVector , Integer k , Integer numCands , Float vectorSimilarity ) {
143- this (fieldName , VectorData .fromBytes (queryVector ), null , null , k , numCands , vectorSimilarity );
144- }
145-
146- public KnnVectorQueryBuilder (String fieldName , VectorData queryVector , Integer k , Integer numCands , Float vectorSimilarity ) {
147- this (fieldName , queryVector , null , null , k , numCands , vectorSimilarity );
154+ public KnnVectorQueryBuilder (
155+ String fieldName ,
156+ byte [] queryVector ,
157+ Integer k ,
158+ Integer numCands ,
159+ RescoreVectorBuilder rescoreVectorBuilder ,
160+ Float vectorSimilarity
161+ ) {
162+ this (fieldName , VectorData .fromBytes (queryVector ), null , null , k , numCands , rescoreVectorBuilder , vectorSimilarity );
148163 }
149164
150- private KnnVectorQueryBuilder (
165+ public KnnVectorQueryBuilder (
151166 String fieldName ,
152167 VectorData queryVector ,
153- QueryVectorBuilder queryVectorBuilder ,
154- Supplier <float []> queryVectorSupplier ,
155168 Integer k ,
156169 Integer numCands ,
170+ RescoreVectorBuilder rescoreVectorBuilder ,
157171 Float vectorSimilarity
158172 ) {
159- this (fieldName , queryVector , null , null , k , numCands , vectorSimilarity , null );
173+ this (fieldName , queryVector , null , null , k , numCands , rescoreVectorBuilder , vectorSimilarity );
160174 }
161175
162176 private KnnVectorQueryBuilder (
@@ -166,8 +180,8 @@ private KnnVectorQueryBuilder(
166180 Supplier <float []> queryVectorSupplier ,
167181 Integer k ,
168182 Integer numCands ,
169- Float vectorSimilarity ,
170- Float rescoreOversample
183+ RescoreVectorBuilder rescoreVectorBuilder ,
184+ Float vectorSimilarity
171185 ) {
172186 if (k != null && k < 1 ) {
173187 throw new IllegalArgumentException ("[" + K_FIELD .getPreferredName () + "] must be greater than 0" );
@@ -204,7 +218,7 @@ private KnnVectorQueryBuilder(
204218 this .vectorSimilarity = vectorSimilarity ;
205219 this .queryVectorBuilder = queryVectorBuilder ;
206220 this .queryVectorSupplier = queryVectorSupplier ;
207- this .rescoreOversample = rescoreOversample ;
221+ this .rescoreVectorBuilder = rescoreVectorBuilder ;
208222 }
209223
210224 public KnnVectorQueryBuilder (StreamInput in ) throws IOException {
@@ -246,9 +260,9 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
246260 this .queryVectorBuilder = null ;
247261 }
248262 if (in .getTransportVersion ().onOrAfter (KNN_QUERY_RESCORE_OVERSAMPLE )) {
249- this .rescoreOversample = in .readOptionalFloat ( );
263+ this .rescoreVectorBuilder = in .readOptional ( RescoreVectorBuilder :: new );
250264 } else {
251- this .rescoreOversample = null ;
265+ this .rescoreVectorBuilder = null ;
252266 }
253267
254268 this .queryVectorSupplier = null ;
@@ -276,10 +290,6 @@ public Integer numCands() {
276290 return numCands ;
277291 }
278292
279- public Float rescoreOversample () {
280- return rescoreOversample ;
281- }
282-
283293 public List <QueryBuilder > filterQueries () {
284294 return filterQueries ;
285295 }
@@ -289,6 +299,10 @@ public QueryVectorBuilder queryVectorBuilder() {
289299 return queryVectorBuilder ;
290300 }
291301
302+ public RescoreVectorBuilder rescoreVectorBuilder () {
303+ return rescoreVectorBuilder ;
304+ }
305+
292306 public KnnVectorQueryBuilder addFilterQuery (QueryBuilder filterQuery ) {
293307 Objects .requireNonNull (filterQuery );
294308 this .filterQueries .add (filterQuery );
@@ -356,7 +370,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
356370 out .writeOptionalNamedWriteable (queryVectorBuilder );
357371 }
358372 if (out .getTransportVersion ().onOrAfter (KNN_QUERY_RESCORE_OVERSAMPLE )) {
359- out .writeOptionalFloat ( rescoreOversample );
373+ out .writeOptionalWriteable ( rescoreVectorBuilder );
360374 }
361375 }
362376
@@ -391,6 +405,11 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep
391405 }
392406 builder .endArray ();
393407 }
408+ if (rescoreVectorBuilder != null ) {
409+ builder .startObject (RESCORE_FIELD .getPreferredName ());
410+ rescoreVectorBuilder .toXContent (builder , params );
411+ builder .endObject ();
412+ }
394413 boostAndQueryNameToXContent (builder );
395414 builder .endObject ();
396415 }
@@ -406,7 +425,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
406425 if (queryVectorSupplier .get () == null ) {
407426 return this ;
408427 }
409- return new KnnVectorQueryBuilder (fieldName , queryVectorSupplier .get (), k , numCands , vectorSimilarity ).boost (boost )
428+ return new KnnVectorQueryBuilder (fieldName , queryVectorSupplier .get (), k , numCands , rescoreVectorBuilder , vectorSimilarity )
429+ .boost (boost )
410430 .queryName (queryName )
411431 .addFilterQueries (filterQueries );
412432 }
@@ -428,9 +448,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
428448 }
429449 ll .onResponse (null );
430450 })));
431- return new KnnVectorQueryBuilder (fieldName , queryVector , queryVectorBuilder , toSet ::get , k , numCands , vectorSimilarity ).boost (
432- boost
433- ).queryName (queryName ).addFilterQueries (filterQueries );
451+ return new KnnVectorQueryBuilder (
452+ fieldName ,
453+ queryVector ,
454+ queryVectorBuilder ,
455+ toSet ::get ,
456+ k ,
457+ numCands ,
458+ rescoreVectorBuilder ,
459+ vectorSimilarity
460+ ).boost (boost ).queryName (queryName ).addFilterQueries (filterQueries );
434461 }
435462 if (ctx .convertToInnerHitsRewriteContext () != null ) {
436463 return new ExactKnnQueryBuilder (queryVector , fieldName , vectorSimilarity ).boost (boost ).queryName (queryName );
@@ -448,10 +475,16 @@ protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
448475 rewrittenQueries .add (rewrittenQuery );
449476 }
450477 if (changed ) {
451- return new KnnVectorQueryBuilder (fieldName , queryVector , queryVectorBuilder , queryVectorSupplier , k , numCands , vectorSimilarity )
452- .boost (boost )
453- .queryName (queryName )
454- .addFilterQueries (rewrittenQueries );
478+ return new KnnVectorQueryBuilder (
479+ fieldName ,
480+ queryVector ,
481+ queryVectorBuilder ,
482+ queryVectorSupplier ,
483+ k ,
484+ numCands ,
485+ rescoreVectorBuilder ,
486+ vectorSimilarity
487+ ).boost (boost ).queryName (queryName ).addFilterQueries (rewrittenQueries );
455488 }
456489 return this ;
457490 }
@@ -495,6 +528,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
495528
496529 DenseVectorFieldType vectorFieldType = (DenseVectorFieldType ) fieldType ;
497530 String parentPath = context .nestedLookup ().getNestedParent (fieldName );
531+ Float rescoreOversample = rescoreVectorBuilder () == null ? null : rescoreVectorBuilder .oversample ();
498532
499533 if (parentPath != null ) {
500534 final BitSetProducer parentBitSet ;
@@ -550,7 +584,7 @@ protected int doHashCode() {
550584 filterQueries ,
551585 vectorSimilarity ,
552586 queryVectorBuilder ,
553- rescoreOversample
587+ rescoreVectorBuilder
554588 );
555589 }
556590
@@ -563,7 +597,7 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
563597 && Objects .equals (filterQueries , other .filterQueries )
564598 && Objects .equals (vectorSimilarity , other .vectorSimilarity )
565599 && Objects .equals (queryVectorBuilder , other .queryVectorBuilder )
566- && Objects .equals (rescoreOversample , other .rescoreOversample );
600+ && Objects .equals (rescoreVectorBuilder , other .rescoreVectorBuilder );
567601 }
568602
569603 @ Override
0 commit comments