77
88package org .elasticsearch .xpack .esql .expression .function .vector ;
99
10- import org .apache .logging .log4j .LogManager ;
11- import org .apache .logging .log4j .Logger ;
1210import org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
1311import org .elasticsearch .common .io .stream .StreamInput ;
1412import org .elasticsearch .common .io .stream .StreamOutput ;
5654import static java .util .Map .entry ;
5755import static org .elasticsearch .common .logging .LoggerMessageFormat .format ;
5856import static org .elasticsearch .index .query .AbstractQueryBuilder .BOOST_FIELD ;
59- import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .K_FIELD ;
6057import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .NUM_CANDS_FIELD ;
6158import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .VECTOR_SIMILARITY_FIELD ;
6259import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
6360import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FOURTH ;
6461import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .SECOND ;
6562import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .THIRD ;
66- import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isFoldable ;
6763import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isNotNull ;
6864import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isType ;
6965import static org .elasticsearch .xpack .esql .core .type .DataType .DENSE_VECTOR ;
7369import static org .elasticsearch .xpack .esql .expression .function .FunctionUtils .resolveTypeQuery ;
7470
7571public class Knn extends FullTextFunction implements OptionalArgument , VectorFunction , PostAnalysisPlanVerificationAware {
76- private final Logger log = LogManager .getLogger (getClass ());
7772
7873 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (Expression .class , "Knn" , Knn ::readFrom );
7974
8075 private final Expression field ;
8176 // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
82- private final transient Expression k ;
77+ private final transient Integer k ;
8378 private final Expression options ;
8479 // Expressions to be used as prefilters in knn query
8580 private final List <Expression > filterExpressions ;
@@ -107,13 +102,6 @@ public Knn(
107102 type = { "dense_vector" },
108103 description = "Vector value to find top nearest neighbours for."
109104 ) Expression query ,
110- @ Param (
111- name = "k" ,
112- type = { "integer" },
113- description = "The number of nearest neighbors to return from each shard. "
114- + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
115- + "This value must be less than or equal to num_candidates."
116- ) Expression k ,
117105 @ MapParam (
118106 name = "options" ,
119107 params = {
@@ -125,12 +113,13 @@ public Knn(
125113 + "Defaults to 1.0."
126114 ),
127115 @ MapParam .MapParamEntry (
128- name = "num_candidates " ,
116+ name = "min_candidates " ,
129117 type = "integer" ,
130118 valueHint = { "10" },
131- description = "The number of nearest neighbor candidates to consider per shard while doing knn search. "
132- + "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. "
133- + "Defaults to 1.5 * k"
119+ description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. " +
120+ " KNN may use a higher number of candidates in case the query can't use a approximate results. "
121+ + "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
122+ + "Defaults to 1.5 * LIMIT used for the query."
134123 ),
135124 @ MapParam .MapParamEntry (
136125 name = "similarity" ,
@@ -152,32 +141,29 @@ public Knn(
152141 optional = true
153142 ) Expression options
154143 ) {
155- this (source , field , query , k , options , null , List .of ());
144+ this (source , field , query , options , null , null , List .of ());
156145 }
157146
158147 public Knn (
159148 Source source ,
160149 Expression field ,
161150 Expression query ,
162- Expression k ,
163151 Expression options ,
152+ Integer k ,
164153 QueryBuilder queryBuilder ,
165154 List <Expression > filterExpressions
166155 ) {
167- super (source , query , expressionList (field , query , k , options ), queryBuilder );
156+ super (source , query , expressionList (field , query , options ), queryBuilder );
168157 this .field = field ;
169158 this .k = k ;
170159 this .options = options ;
171160 this .filterExpressions = filterExpressions ;
172161 }
173162
174- private static List <Expression > expressionList (Expression field , Expression query , Expression k , Expression options ) {
163+ private static List <Expression > expressionList (Expression field , Expression query , Expression options ) {
175164 List <Expression > result = new ArrayList <>();
176165 result .add (field );
177166 result .add (query );
178- if (k != null ) {
179- result .add (k );
180- }
181167 if (options != null ) {
182168 result .add (options );
183169 }
@@ -188,7 +174,7 @@ public Expression field() {
188174 return field ;
189175 }
190176
191- public Expression k () {
177+ public Integer k () {
192178 return k ;
193179 }
194180
@@ -207,7 +193,7 @@ public DataType dataType() {
207193
208194 @ Override
209195 protected TypeResolution resolveParams () {
210- return resolveField ().and (resolveQuery ()).and (resolveK ()). and ( Options .resolve (options (), source (), FOURTH , ALLOWED_OPTIONS ));
196+ return resolveField ().and (resolveQuery ()).and (Options .resolve (options (), source (), THIRD , ALLOWED_OPTIONS ));
211197 }
212198
213199 private TypeResolution resolveField () {
@@ -227,14 +213,9 @@ private TypeResolution resolveQuery() {
227213 return TypeResolution .TYPE_RESOLVED ;
228214 }
229215
230- private TypeResolution resolveK () {
231- if (k == null ) {
232- // Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed
233- return TypeResolution .TYPE_RESOLVED ;
234- }
235-
236- return isType (k (), dt -> dt == INTEGER , sourceText (), THIRD , "integer" ).and (isFoldable (k (), sourceText (), THIRD ))
237- .and (isNotNull (k (), sourceText (), THIRD ));
216+ public Knn replaceK (Integer k ) {
217+ Check .notNull (k , "k must not be null" );
218+ return new Knn (source (), field (), query (), options (), k , queryBuilder (), filterExpressions ());
238219 }
239220
240221 public List <Number > queryAsObject () {
@@ -248,16 +229,9 @@ public List<Number> queryAsObject() {
248229 throw new EsqlIllegalArgumentException (format (null , "Query value must be a list of numbers in [{}], found [{}]" , source (), query ));
249230 }
250231
251- int getKIntValue () {
252- if (k () instanceof Literal literal ) {
253- return (int ) (Number ) literal .value ();
254- }
255- throw new EsqlIllegalArgumentException (format (null , "K value must be a constant integer in [{}], found [{}]" , source (), k ()));
256- }
257-
258232 @ Override
259233 public Expression replaceQueryBuilder (QueryBuilder queryBuilder ) {
260- return new Knn (source (), field (), query (), k (), options (), queryBuilder , filterExpressions ());
234+ return new Knn (source (), field (), query (), options (), k (), queryBuilder , filterExpressions ());
261235 }
262236
263237 @ Override
@@ -273,15 +247,12 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
273247
274248 @ Override
275249 protected Query translate (LucenePushdownPredicates pushdownPredicates , TranslatorHandler handler ) {
250+ assert k () != null : "Knn function must have a k value set before translation" ;
276251 var fieldAttribute = Match .fieldAsFieldAttribute (field ());
277252
278253 Check .notNull (fieldAttribute , "Knn must have a field attribute as the first argument" );
279254 String fieldName = getNameFromFieldAttribute (fieldAttribute );
280255 float [] queryAsFloats = queryAsFloats ();
281- int kValue = getKIntValue ();
282-
283- Map <String , Object > opts = queryOptions ();
284- opts .put (K_FIELD .getPreferredName (), kValue );
285256
286257 List <QueryBuilder > filterQueries = new ArrayList <>();
287258 for (Expression filterExpression : filterExpressions ()) {
@@ -295,7 +266,7 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
295266 }
296267 }
297268
298- return new KnnQuery (source (), fieldName , queryAsFloats , opts , filterQueries );
269+ return new KnnQuery (source (), fieldName , queryAsFloats , k (), queryOptions () , filterQueries );
299270 }
300271
301272 private float [] queryAsFloats () {
@@ -308,7 +279,7 @@ private float[] queryAsFloats() {
308279 }
309280
310281 public Expression withFilters (List <Expression > filterExpressions ) {
311- return new Knn (source (), field (), query (), k (), options (), queryBuilder (), filterExpressions );
282+ return new Knn (source (), field (), query (), options (), k (), queryBuilder (), filterExpressions );
312283 }
313284
314285 private Map <String , Object > queryOptions () throws InvalidArgumentException {
@@ -343,16 +314,16 @@ public Expression replaceChildren(List<Expression> newChildren) {
343314 source (),
344315 newChildren .get (0 ),
345316 newChildren .get (1 ),
346- newChildren .get (2 ),
347- newChildren . size () > 3 ? newChildren . get ( 3 ) : null ,
317+ newChildren .size () > 2 ? newChildren . get (2 ) : null ,
318+ k () ,
348319 queryBuilder (),
349320 filterExpressions ()
350321 );
351322 }
352323
353324 @ Override
354325 protected NodeInfo <? extends Expression > info () {
355- return NodeInfo .create (this , Knn ::new , field (), query (), k (), options (), queryBuilder (), filterExpressions ());
326+ return NodeInfo .create (this , Knn ::new , field (), query (), options (), k (), queryBuilder (), filterExpressions ());
356327 }
357328
358329 @ Override
0 commit comments