77
88package org .elasticsearch .xpack .esql .expression .function .vector ;
99
10+ import org .elasticsearch .TransportVersions ;
1011import org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
1112import org .elasticsearch .common .io .stream .StreamInput ;
1213import org .elasticsearch .common .io .stream .StreamOutput ;
4142import java .util .Objects ;
4243
4344import static java .util .Map .entry ;
45+ import static org .elasticsearch .TransportVersions .ESQL_KNN_K_PARAM_MANDATORY ;
4446import static org .elasticsearch .index .query .AbstractQueryBuilder .BOOST_FIELD ;
4547import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .K_FIELD ;
4648import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .NUM_CANDS_FIELD ;
@@ -62,10 +64,10 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
6264 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (Expression .class , "Knn" , Knn ::readFrom );
6365
6466 private final Expression field ;
67+ private final Expression k ;
6568 private final Expression options ;
6669
6770 public static final Map <String , DataType > ALLOWED_OPTIONS = Map .ofEntries (
68- entry (K_FIELD .getPreferredName (), INTEGER ),
6971 entry (NUM_CANDS_FIELD .getPreferredName (), INTEGER ),
7072 entry (VECTOR_SIMILARITY_FIELD .getPreferredName (), FLOAT ),
7173 entry (BOOST_FIELD .getPreferredName (), FLOAT ),
@@ -90,6 +92,13 @@ public Knn(
9092 type = { "dense_vector" },
9193 description = "Vector value to find top nearest neighbours for."
9294 ) Expression query ,
95+ @ Param (
96+ name = "k" ,
97+ type = { "integer" },
98+ description = "The number of nearest neighbors to return from each shard. "
99+ + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
100+ + "This value must be less than or equal to num_candidates."
101+ ) Expression k ,
93102 @ MapParam (
94103 name = "options" ,
95104 params = {
@@ -100,14 +109,6 @@ public Knn(
100109 description = "Floating point number used to decrease or increase the relevance scores of the query."
101110 + "Defaults to 1.0."
102111 ),
103- @ MapParam .MapParamEntry (
104- name = "k" ,
105- type = "integer" ,
106- valueHint = { "10" },
107- description = "The number of nearest neighbors to return from each shard. "
108- + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
109- + "This value must be less than or equal to num_candidates. Defaults to 10."
110- ),
111112 @ MapParam .MapParamEntry (
112113 name = "num_candidates" ,
113114 type = "integer" ,
@@ -136,19 +137,24 @@ public Knn(
136137 optional = true
137138 ) Expression options
138139 ) {
139- this (source , field , query , options , null );
140+ this (source , field , query , k , options , null );
140141 }
141142
142- private Knn (Source source , Expression field , Expression query , Expression options , QueryBuilder queryBuilder ) {
143- super (source , query , options == null ? List .of (field , query ) : List .of (field , query , options ), queryBuilder );
143+ private Knn (Source source , Expression field , Expression query , Expression k , Expression options , QueryBuilder queryBuilder ) {
144+ super (source , query , options == null ? List .of (field , query , k ) : List .of (field , query , k , options ), queryBuilder );
144145 this .field = field ;
146+ this .k = k ;
145147 this .options = options ;
146148 }
147149
148150 public Expression field () {
149151 return field ;
150152 }
151153
154+ public Expression k () {
155+ return k ;
156+ }
157+
152158 public Expression options () {
153159 return options ;
154160 }
@@ -160,7 +166,7 @@ public DataType dataType() {
160166
161167 @ Override
162168 protected TypeResolution resolveParams () {
163- return resolveField ().and (resolveQuery ()).and (resolveOptions ());
169+ return resolveField ().and (resolveQuery ()).and (resolveK ()). and ( resolveOptions ());
164170 }
165171
166172 private TypeResolution resolveField () {
@@ -173,14 +179,19 @@ private TypeResolution resolveQuery() {
173179 );
174180 }
175181
182+ private TypeResolution resolveK () {
183+ return isNotNull (k (), sourceText (), TypeResolutions .ParamOrdinal .THIRD )
184+ .and (isType (k (), dt -> dt == INTEGER , sourceText (), TypeResolutions .ParamOrdinal .THIRD , "integer" ));
185+ }
186+
176187 private TypeResolution resolveOptions () {
177188 if (options () != null ) {
178- TypeResolution resolution = isNotNull (options (), sourceText (), THIRD );
189+ TypeResolution resolution = isNotNull (options (), sourceText (), TypeResolutions . ParamOrdinal . FOURTH );
179190 if (resolution .unresolved ()) {
180191 return resolution ;
181192 }
182193 // MapExpression does not have a DataType associated with it
183- resolution = isMapExpression (options (), sourceText (), THIRD );
194+ resolution = isMapExpression (options (), sourceText (), TypeResolutions . ParamOrdinal . FOURTH );
184195 if (resolution .unresolved ()) {
185196 return resolution ;
186197 }
@@ -200,7 +211,7 @@ private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
200211 }
201212
202213 Map <String , Object > matchOptions = new HashMap <>();
203- populateOptionsMap ((MapExpression ) options (), matchOptions , THIRD , sourceText (), ALLOWED_OPTIONS );
214+ populateOptionsMap ((MapExpression ) options (), matchOptions , TypeResolutions . ParamOrdinal . FOURTH , sourceText (), ALLOWED_OPTIONS );
204215 return matchOptions ;
205216 }
206217
@@ -216,22 +227,24 @@ protected Query translate(TranslatorHandler handler) {
216227 for (int i = 0 ; i < queryFolded .size (); i ++) {
217228 queryAsFloats [i ] = queryFolded .get (i ).floatValue ();
218229 }
230+ int kValue = ((Number ) k ().fold (FoldContext .small ())).intValue ();
231+
232+ Map <String , Object > opts = queryOptions ();
233+ opts .put (K_FIELD .getPreferredName (), kValue );
219234
220- return new KnnQuery (source (), fieldName , queryAsFloats , queryOptions () );
235+ return new KnnQuery (source (), fieldName , queryAsFloats , opts );
221236 }
222237
223238 @ Override
224239 public Expression replaceQueryBuilder (QueryBuilder queryBuilder ) {
225- return new Knn (source (), field (), query (), options (), queryBuilder );
240+ return new Knn (source (), field (), query (), k (), options (), queryBuilder );
226241 }
227242
228243 private Map <String , Object > queryOptions () throws InvalidArgumentException {
229- if (options () == null ) {
230- return Map .of ();
231- }
232-
233244 Map <String , Object > options = new HashMap <>();
234- populateOptionsMap ((MapExpression ) options (), options , THIRD , sourceText (), ALLOWED_OPTIONS );
245+ if (options () != null ) {
246+ populateOptionsMap ((MapExpression ) options (), options , TypeResolutions .ParamOrdinal .FOURTH , sourceText (), ALLOWED_OPTIONS );
247+ }
235248 return options ;
236249 }
237250
@@ -241,14 +254,15 @@ public Expression replaceChildren(List<Expression> newChildren) {
241254 source (),
242255 newChildren .get (0 ),
243256 newChildren .get (1 ),
244- newChildren .size () > 2 ? newChildren .get (2 ) : null ,
257+ newChildren .get (2 ),
258+ newChildren .size () > 3 ? newChildren .get (3 ) : null ,
245259 queryBuilder ()
246260 );
247261 }
248262
249263 @ Override
250264 protected NodeInfo <? extends Expression > info () {
251- return NodeInfo .create (this , Knn ::new , field (), query (), options ());
265+ return NodeInfo .create (this , Knn ::new , field (), query (), k (), options ());
252266 }
253267
254268 @ Override
@@ -261,8 +275,11 @@ private static Knn readFrom(StreamInput in) throws IOException {
261275 Expression field = in .readNamedWriteable (Expression .class );
262276 Expression query = in .readNamedWriteable (Expression .class );
263277 QueryBuilder queryBuilder = in .readOptionalNamedWriteable (QueryBuilder .class );
264-
265- return new Knn (source , field , query , null , queryBuilder );
278+ Expression k = null ;
279+ if (in .getTransportVersion ().onOrAfter (ESQL_KNN_K_PARAM_MANDATORY )) {
280+ k = in .readNamedWriteable (Expression .class );
281+ }
282+ return new Knn (source , field , query , k , null , queryBuilder );
266283 }
267284
268285 @ Override
@@ -271,6 +288,9 @@ public void writeTo(StreamOutput out) throws IOException {
271288 out .writeNamedWriteable (field ());
272289 out .writeNamedWriteable (query ());
273290 out .writeOptionalNamedWriteable (queryBuilder ());
291+ if (out .getTransportVersion ().onOrAfter (ESQL_KNN_K_PARAM_MANDATORY )) {
292+ out .writeNamedWriteable (k ());
293+ }
274294 }
275295
276296 @ Override
@@ -281,12 +301,13 @@ public boolean equals(Object o) {
281301 Knn knn = (Knn ) o ;
282302 return Objects .equals (field (), knn .field ())
283303 && Objects .equals (query (), knn .query ())
304+ && Objects .equals (k (), knn .k ())
284305 && Objects .equals (queryBuilder (), knn .queryBuilder ());
285306 }
286307
287308 @ Override
288309 public int hashCode () {
289- return Objects .hash (field (), query (), queryBuilder ());
310+ return Objects .hash (field (), query (), k (), queryBuilder ());
290311 }
291312
292313}
0 commit comments