5252import static java .util .Map .entry ;
5353import static org .elasticsearch .common .logging .LoggerMessageFormat .format ;
5454import static org .elasticsearch .index .query .AbstractQueryBuilder .BOOST_FIELD ;
55+ import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .K_FIELD ;
5556import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .VECTOR_SIMILARITY_FIELD ;
57+ import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .VISIT_PERCENTAGE_FIELD ;
5658import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FOURTH ;
5759import static org .elasticsearch .xpack .esql .core .type .DataType .DENSE_VECTOR ;
5860import static org .elasticsearch .xpack .esql .core .type .DataType .FLOAT ;
@@ -64,16 +66,18 @@ public class Knn extends SingleFieldFullTextFunction implements OptionalArgument
6466
6567 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (Expression .class , "Knn" , Knn ::readFrom );
6668
67- // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
68- private final transient Integer k ;
69+ // Implicit k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
70+ private final transient Integer implicitK ;
6971 // Expressions to be used as prefilters in knn query
7072 private final List <Expression > filterExpressions ;
7173
7274 public static final String MIN_CANDIDATES_OPTION = "min_candidates" ;
7375
7476 public static final Map <String , DataType > ALLOWED_OPTIONS = Map .ofEntries (
77+ entry (K_FIELD .getPreferredName (), INTEGER ),
7578 entry (MIN_CANDIDATES_OPTION , INTEGER ),
7679 entry (VECTOR_SIMILARITY_FIELD .getPreferredName (), FLOAT ),
80+ entry (VISIT_PERCENTAGE_FIELD .getPreferredName (), FLOAT ),
7781 entry (BOOST_FIELD .getPreferredName (), FLOAT ),
7882 entry (KnnQuery .RESCORE_OVERSAMPLE_FIELD , FLOAT )
7983 );
@@ -102,6 +106,15 @@ public Knn(
102106 @ MapParam (
103107 name = "options" ,
104108 params = {
109+ @ MapParam .MapParamEntry (
110+ name = "k" ,
111+ type = "integer" ,
112+ valueHint = { "10" },
113+ description = "The number of nearest neighbors to return from each shard. "
114+ + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
115+ + "This value must be less than or equal to num_candidates. "
116+ + "This value is automatically set with any LIMIT applied to the function."
117+ ),
105118 @ MapParam .MapParamEntry (
106119 name = "boost" ,
107120 type = "float" ,
@@ -116,7 +129,17 @@ public Knn(
116129 description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. "
117130 + " KNN may use a higher number of candidates in case the query can't use a approximate results. "
118131 + "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
119- + "Defaults to 1.5 * LIMIT used for the query."
132+ + "Defaults to 1.5 * k (or LIMIT) used for the query."
133+ ),
134+ @ MapParam .MapParamEntry (
135+ name = "visit_percentage" ,
136+ type = "float" ,
137+ valueHint = { "10" },
138+ description = "The percentage of vectors to explore per shard while doing knn search with bbq_disk. "
139+ + "Must be between 0 and 100. 0 will default to using num_candidates for calculating the percent visited. "
140+ + "Increasing visit_percentage tends to improve the accuracy of the final results. "
141+ + "If visit_percentage is set for bbq_disk, num_candidates is ignored. "
142+ + "Defaults to ~1% per shard for every 1 million vectors"
120143 ),
121144 @ MapParam .MapParamEntry (
122145 name = "similarity" ,
@@ -146,12 +169,12 @@ public Knn(
146169 Expression field ,
147170 Expression query ,
148171 Expression options ,
149- Integer k ,
172+ Integer implicitK ,
150173 QueryBuilder queryBuilder ,
151174 List <Expression > filterExpressions
152175 ) {
153176 super (source , field , query , options , expressionList (field , query , options ), queryBuilder );
154- this .k = k ;
177+ this .implicitK = implicitK ;
155178 this .filterExpressions = filterExpressions ;
156179 }
157180
@@ -165,15 +188,15 @@ private static List<Expression> expressionList(Expression field, Expression quer
165188 return result ;
166189 }
167190
168- public Integer k () {
169- return k ;
191+ public Integer implicitK () {
192+ return implicitK ;
170193 }
171194
172195 public List <Expression > filterExpressions () {
173196 return filterExpressions ;
174197 }
175198
176- public Knn replaceK (Integer k ) {
199+ public Knn withImplicitK (Integer k ) {
177200 Check .notNull (k , "k must not be null" );
178201 return new Knn (source (), field (), query (), options (), k , queryBuilder (), filterExpressions ());
179202 }
@@ -191,7 +214,7 @@ public List<Number> queryAsObject() {
191214
192215 @ Override
193216 public Expression replaceQueryBuilder (QueryBuilder queryBuilder ) {
194- return new Knn (source (), field (), query (), options (), k (), queryBuilder , filterExpressions ());
217+ return new Knn (source (), field (), query (), options (), implicitK (), queryBuilder , filterExpressions ());
195218 }
196219
197220 @ Override
@@ -207,7 +230,7 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
207230
208231 @ Override
209232 protected Query translate (LucenePushdownPredicates pushdownPredicates , TranslatorHandler handler ) {
210- assert k () != null : "Knn function must have a k value set before translation" ;
233+ assert implicitK () != null : "Knn function must have a k value set before translation" ;
211234 var fieldAttribute = fieldAsFieldAttribute (field ());
212235
213236 Check .notNull (fieldAttribute , "Knn must have a field attribute as the first argument" );
@@ -226,7 +249,10 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
226249 }
227250 }
228251
229- return new KnnQuery (source (), fieldName , queryAsFloats , k (), queryOptions (), filterQueries );
252+ Map <String , Object > options = queryOptions ();
253+ Integer explicitK = (Integer ) options .get (K_FIELD .getPreferredName ());
254+
255+ return new KnnQuery (source (), fieldName , queryAsFloats , explicitK != null ? explicitK : implicitK (), options , filterQueries );
230256 }
231257
232258 private float [] queryAsFloats () {
@@ -239,7 +265,7 @@ private float[] queryAsFloats() {
239265 }
240266
241267 public Expression withFilters (List <Expression > filterExpressions ) {
242- return new Knn (source (), field (), query (), options (), k (), queryBuilder (), filterExpressions );
268+ return new Knn (source (), field (), query (), options (), implicitK (), queryBuilder (), filterExpressions );
243269 }
244270
245271 private Map <String , Object > queryOptions () throws InvalidArgumentException {
@@ -264,7 +290,7 @@ protected QueryBuilder evaluatorQueryBuilder() {
264290 @ Override
265291 public void postOptimizationVerification (Failures failures ) {
266292 // Check that a k has been set
267- if (k () == null ) {
293+ if (implicitK () == null ) {
268294 failures .add (
269295 Failure .fail (this , "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find" )
270296 );
@@ -278,15 +304,15 @@ public Expression replaceChildren(List<Expression> newChildren) {
278304 newChildren .get (0 ),
279305 newChildren .get (1 ),
280306 newChildren .size () > 2 ? newChildren .get (2 ) : null ,
281- k (),
307+ implicitK (),
282308 queryBuilder (),
283309 filterExpressions ()
284310 );
285311 }
286312
287313 @ Override
288314 protected NodeInfo <? extends Expression > info () {
289- return NodeInfo .create (this , Knn ::new , field (), query (), options (), k (), queryBuilder (), filterExpressions ());
315+ return NodeInfo .create (this , Knn ::new , field (), query (), options (), implicitK (), queryBuilder (), filterExpressions ());
290316 }
291317
292318 @ Override
@@ -334,12 +360,14 @@ public boolean equals(Object o) {
334360 // ignore options when comparing two Knn functions
335361 if (o == null || getClass () != o .getClass ()) return false ;
336362 Knn knn = (Knn ) o ;
337- return super .equals (knn ) && Objects .equals (k (), knn .k ()) && Objects .equals (filterExpressions (), knn .filterExpressions ());
363+ return super .equals (knn )
364+ && Objects .equals (implicitK (), knn .implicitK ())
365+ && Objects .equals (filterExpressions (), knn .filterExpressions ());
338366 }
339367
340368 @ Override
341369 public int hashCode () {
342- return Objects .hash (field (), query (), queryBuilder (), k (), filterExpressions ());
370+ return Objects .hash (field (), query (), queryBuilder (), implicitK (), filterExpressions ());
343371 }
344372
345373}
0 commit comments