4545import java .util .Objects ;
4646import java .util .function .Supplier ;
4747
48+ import static org .elasticsearch .TransportVersions .KNN_QUERY_RESCORE_OVERSAMPLE ;
4849import static org .elasticsearch .common .Strings .format ;
4950import static org .elasticsearch .search .SearchService .DEFAULT_SIZE ;
5051import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
@@ -66,6 +67,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
6667 public static final ParseField NUM_CANDS_FIELD = new ParseField ("num_candidates" );
6768 public static final ParseField QUERY_VECTOR_FIELD = new ParseField ("query_vector" );
6869 public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField ("similarity" );
70+ public static final ParseField RESCORE_VECTOR_OVERSAMPLE = new ParseField ("rescore_vector_oversample" );
6971 public static final ParseField FILTER_FIELD = new ParseField ("filter" );
7072 public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField ("query_vector_builder" );
7173
@@ -79,7 +81,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
7981 null ,
8082 (Integer ) args [2 ],
8183 (Integer ) args [3 ],
82- (Float ) args [4 ]
84+ (Float ) args [4 ],
85+ (Float ) args [5 ]
8386 )
8487 );
8588
@@ -106,6 +109,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
106109 ObjectParser .ValueType .OBJECT_ARRAY
107110 );
108111 declareStandardFields (PARSER );
112+ PARSER .declareFloat (optionalConstructorArg (), RESCORE_VECTOR_OVERSAMPLE );
109113 }
110114
111115 public static KnnVectorQueryBuilder fromXContent (XContentParser parser ) {
@@ -120,6 +124,7 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
120124 private final Float vectorSimilarity ;
121125 private final QueryVectorBuilder queryVectorBuilder ;
122126 private final Supplier <float []> queryVectorSupplier ;
127+ private final Float rescoreOversample ;
123128
124129 public KnnVectorQueryBuilder (String fieldName , float [] queryVector , Integer k , Integer numCands , Float vectorSimilarity ) {
125130 this (fieldName , VectorData .fromFloats (queryVector ), null , null , k , numCands , vectorSimilarity );
@@ -151,6 +156,19 @@ private KnnVectorQueryBuilder(
151156 Integer k ,
152157 Integer numCands ,
153158 Float vectorSimilarity
159+ ) {
160+ this (fieldName , queryVector , null , null , k , numCands , vectorSimilarity , 0F );
161+ }
162+
163+ private KnnVectorQueryBuilder (
164+ String fieldName ,
165+ VectorData queryVector ,
166+ QueryVectorBuilder queryVectorBuilder ,
167+ Supplier <float []> queryVectorSupplier ,
168+ Integer k ,
169+ Integer numCands ,
170+ Float vectorSimilarity ,
171+ Float rescoreOversample
154172 ) {
155173 if (k != null && k < 1 ) {
156174 throw new IllegalArgumentException ("[" + K_FIELD .getPreferredName () + "] must be greater than 0" );
@@ -187,6 +205,7 @@ private KnnVectorQueryBuilder(
187205 this .vectorSimilarity = vectorSimilarity ;
188206 this .queryVectorBuilder = queryVectorBuilder ;
189207 this .queryVectorSupplier = queryVectorSupplier ;
208+ this .rescoreOversample = rescoreOversample ;
190209 }
191210
192211 public KnnVectorQueryBuilder (StreamInput in ) throws IOException {
@@ -227,6 +246,12 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException {
227246 } else {
228247 this .queryVectorBuilder = null ;
229248 }
249+ if (in .getTransportVersion ().onOrAfter (KNN_QUERY_RESCORE_OVERSAMPLE )) {
250+ this .rescoreOversample = in .readOptionalFloat ();
251+ } else {
252+ this .rescoreOversample = null ;
253+ }
254+
230255 this .queryVectorSupplier = null ;
231256 }
232257
@@ -252,6 +277,10 @@ public Integer numCands() {
252277 return numCands ;
253278 }
254279
280+ public Float rescoreOversample () {
281+ return rescoreOversample ;
282+ }
283+
255284 public List <QueryBuilder > filterQueries () {
256285 return filterQueries ;
257286 }
@@ -327,6 +356,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
327356 if (out .getTransportVersion ().onOrAfter (TransportVersions .V_8_14_0 )) {
328357 out .writeOptionalNamedWriteable (queryVectorBuilder );
329358 }
359+ if (out .getTransportVersion ().onOrAfter (KNN_QUERY_RESCORE_OVERSAMPLE )) {
360+ out .writeOptionalFloat (rescoreOversample );
361+ }
330362 }
331363
332364 @ Override
@@ -491,14 +523,31 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
491523 // Now join the filterQuery & parentFilter to provide the matching blocks of children
492524 filterQuery = new ToChildBlockJoinQuery (filterQuery , parentBitSet );
493525 }
494- return vectorFieldType .createKnnQuery (queryVector , k , adjustedNumCands , filterQuery , vectorSimilarity , parentBitSet );
526+ return vectorFieldType .createKnnQuery (
527+ queryVector ,
528+ k ,
529+ adjustedNumCands ,
530+ rescoreOversample ,
531+ filterQuery ,
532+ vectorSimilarity ,
533+ parentBitSet
534+ );
495535 }
496- return vectorFieldType .createKnnQuery (queryVector , k , adjustedNumCands , filterQuery , vectorSimilarity , null );
536+ return vectorFieldType .createKnnQuery (queryVector , k , adjustedNumCands , rescoreOversample , filterQuery , vectorSimilarity , null );
497537 }
498538
499539 @ Override
500540 protected int doHashCode () {
501- return Objects .hash (fieldName , Objects .hashCode (queryVector ), k , numCands , filterQueries , vectorSimilarity , queryVectorBuilder );
541+ return Objects .hash (
542+ fieldName ,
543+ Objects .hashCode (queryVector ),
544+ k ,
545+ numCands ,
546+ filterQueries ,
547+ vectorSimilarity ,
548+ queryVectorBuilder ,
549+ rescoreOversample
550+ );
502551 }
503552
504553 @ Override
@@ -509,7 +558,8 @@ protected boolean doEquals(KnnVectorQueryBuilder other) {
509558 && Objects .equals (numCands , other .numCands )
510559 && Objects .equals (filterQueries , other .filterQueries )
511560 && Objects .equals (vectorSimilarity , other .vectorSimilarity )
512- && Objects .equals (queryVectorBuilder , other .queryVectorBuilder );
561+ && Objects .equals (queryVectorBuilder , other .queryVectorBuilder )
562+ && Objects .equals (rescoreOversample , other .rescoreOversample );
513563 }
514564
515565 @ Override
0 commit comments