3535import org .elasticsearch .xpack .esql .querydsl .query .KnnQuery ;
3636
3737import java .io .IOException ;
38+ import java .util .ArrayList ;
3839import java .util .HashMap ;
3940import java .util .List ;
4041import java .util .Map ;
4142import java .util .Objects ;
4243
4344import static java .util .Map .entry ;
44- import static org .elasticsearch .TransportVersions .ESQL_KNN_K_PARAM_MANDATORY ;
4545import static org .elasticsearch .index .query .AbstractQueryBuilder .BOOST_FIELD ;
4646import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .K_FIELD ;
4747import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .NUM_CANDS_FIELD ;
@@ -64,7 +64,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
6464 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (Expression .class , "Knn" , Knn ::readFrom );
6565
6666 private final Expression field ;
67- private final Expression k ;
67+ // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
68+ private final transient Expression k ;
6869 private final Expression options ;
6970
7071 public static final Map <String , DataType > ALLOWED_OPTIONS = Map .ofEntries (
@@ -79,9 +80,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
7980 preview = true ,
8081 description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. "
8182 + "knn function finds nearest vectors through approximate search on indexed dense_vectors." ,
82- examples = {
83- @ Example (file = "knn-function" , tag = "knn-function" ),
84- @ Example (file = "knn-function" , tag = "knn-function-options" ), },
83+ examples = { @ Example (file = "knn-function" , tag = "knn-function" ) },
8584 appliesTo = { @ FunctionAppliesTo (lifeCycle = FunctionAppliesToLifecycle .DEVELOPMENT ) }
8685 )
8786 public Knn (
@@ -141,12 +140,25 @@ public Knn(
141140 }
142141
143142 private Knn (Source source , Expression field , Expression query , Expression k , Expression options , QueryBuilder queryBuilder ) {
144- super (source , query , options == null ? List . of ( field , query , k ) : List . of (field , query , k , options ), queryBuilder );
143+ super (source , query , expressionList (field , query , k , options ), queryBuilder );
145144 this .field = field ;
146145 this .k = k ;
147146 this .options = options ;
148147 }
149148
149+ private static List <Expression > expressionList (Expression field , Expression query , Expression k , Expression options ) {
150+ List <Expression > result = new ArrayList <>();
151+ result .add (field );
152+ result .add (query );
153+ if (k != null ) {
154+ result .add (k );
155+ }
156+ if (options != null ) {
157+ result .add (options );
158+ }
159+ return result ;
160+ }
161+
150162 public Expression field () {
151163 return field ;
152164 }
@@ -275,11 +287,7 @@ private static Knn readFrom(StreamInput in) throws IOException {
275287 Expression field = in .readNamedWriteable (Expression .class );
276288 Expression query = in .readNamedWriteable (Expression .class );
277289 QueryBuilder queryBuilder = in .readOptionalNamedWriteable (QueryBuilder .class );
278- Expression k = null ;
279- if (in .getTransportVersion ().onOrAfter (ESQL_KNN_K_PARAM_MANDATORY )) {
280- k = in .readNamedWriteable (Expression .class );
281- }
282- return new Knn (source , field , query , k , null , queryBuilder );
290+ return new Knn (source , field , query , null , null , queryBuilder );
283291 }
284292
285293 @ Override
@@ -288,9 +296,6 @@ public void writeTo(StreamOutput out) throws IOException {
288296 out .writeNamedWriteable (field ());
289297 out .writeNamedWriteable (query ());
290298 out .writeOptionalNamedWriteable (queryBuilder ());
291- if (out .getTransportVersion ().onOrAfter (ESQL_KNN_K_PARAM_MANDATORY )) {
292- out .writeNamedWriteable (k ());
293- }
294299 }
295300
296301 @ Override
@@ -301,13 +306,12 @@ public boolean equals(Object o) {
301306 Knn knn = (Knn ) o ;
302307 return Objects .equals (field (), knn .field ())
303308 && Objects .equals (query (), knn .query ())
304- && Objects .equals (k (), knn .k ())
305309 && Objects .equals (queryBuilder (), knn .queryBuilder ());
306310 }
307311
308312 @ Override
309313 public int hashCode () {
310- return Objects .hash (field (), query (), k (), queryBuilder ());
314+ return Objects .hash (field (), query (), queryBuilder ());
311315 }
312316
313317}
0 commit comments