Skip to content

Commit 9c924bc

Browse files
authored
ES|QL - KNN function option changes (#138372)
1 parent 59a8bb2 commit 9c924bc

File tree

11 files changed

+91
-61
lines changed

11 files changed

+91
-61
lines changed

docs/changelog/138372.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 138372
2+
summary: ES|QL - KNN function options support k and visit_percentage parameters
3+
area: "ES|QL"
4+
type: enhancement
5+
issues: []

docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md

Lines changed: 7 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/definition/functions/knn.json

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,11 @@ azure | [240.0, 255.0, 255.0]
226226

227227
knnWithNonPushableConjunction
228228
required_capability: knn_function_v5
229+
required_capability: knn_function_options_k_visit_percentage
229230

230231
from colors metadata _score
231232
| eval composed_name = locate(color, " ") > 0
232-
| where knn(rgb_vector, [128,128,0], {"min_candidates": 100}) and composed_name == false
233+
| where knn(rgb_vector, [128,128,0], {"k": 100}) and composed_name == false
233234
| sort _score desc, color asc
234235
| keep color, composed_name
235236
| limit 10

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,16 @@ public void testKnnDefaults() {
113113
}
114114
}
115115

116-
public void testKnnOptions() {
116+
public void testKnnKOverridesLimit() {
117117
float[] queryVector = new float[numDims];
118118
Arrays.fill(queryVector, 0.0f);
119119

120120
var query = String.format(Locale.ROOT, """
121121
FROM test METADATA _score
122-
| WHERE knn(vector, %s)
122+
| WHERE knn(vector, %s, {"k": 5, "min_candidates": 20})
123123
| KEEP id, _score, vector
124124
| SORT _score DESC
125-
| LIMIT 5
125+
| LIMIT 10
126126
""", Arrays.toString(queryVector));
127127

128128
try (var resp = run(query)) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,10 @@ public enum Cap {
16791679
*/
16801680
PROMQL_V0(Build.current().isSnapshot()),
16811681

1682+
/**
1683+
* KNN function adds support for k and visit_percentage options
1684+
*/
1685+
KNN_FUNCTION_OPTIONS_K_VISIT_PERCENTAGE,
16821686
// Last capability should still have a comma for fewer merge conflicts when adding new ones :)
16831687
// This comment prevents the semicolon from being on the previous capability when Spotless formats the file.
16841688
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
import static java.util.Map.entry;
5353
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
5454
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
55+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
5556
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
57+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VISIT_PERCENTAGE_FIELD;
5658
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH;
5759
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
5860
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
@@ -64,16 +66,18 @@ public class Knn extends SingleFieldFullTextFunction implements OptionalArgument
6466

6567
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
6668

67-
// k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
68-
private final transient Integer k;
69+
// Implicit k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
70+
private final transient Integer implicitK;
6971
// Expressions to be used as prefilters in knn query
7072
private final List<Expression> filterExpressions;
7173

7274
public static final String MIN_CANDIDATES_OPTION = "min_candidates";
7375

7476
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
77+
entry(K_FIELD.getPreferredName(), INTEGER),
7578
entry(MIN_CANDIDATES_OPTION, INTEGER),
7679
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
80+
entry(VISIT_PERCENTAGE_FIELD.getPreferredName(), FLOAT),
7781
entry(BOOST_FIELD.getPreferredName(), FLOAT),
7882
entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT)
7983
);
@@ -102,6 +106,15 @@ public Knn(
102106
@MapParam(
103107
name = "options",
104108
params = {
109+
@MapParam.MapParamEntry(
110+
name = "k",
111+
type = "integer",
112+
valueHint = { "10" },
113+
description = "The number of nearest neighbors to return from each shard. "
114+
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
115+
+ "This value must be less than or equal to num_candidates. "
116+
+ "This value is automatically set with any LIMIT applied to the function."
117+
),
105118
@MapParam.MapParamEntry(
106119
name = "boost",
107120
type = "float",
@@ -116,7 +129,17 @@ public Knn(
116129
description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. "
117130
+ " KNN may use a higher number of candidates in case the query can't use approximate results. "
118131
+ "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
119-
+ "Defaults to 1.5 * LIMIT used for the query."
132+
+ "Defaults to 1.5 * k (or LIMIT) used for the query."
133+
),
134+
@MapParam.MapParamEntry(
135+
name = "visit_percentage",
136+
type = "float",
137+
valueHint = { "10" },
138+
description = "The percentage of vectors to explore per shard while doing knn search with bbq_disk. "
139+
+ "Must be between 0 and 100. 0 will default to using num_candidates for calculating the percent visited. "
140+
+ "Increasing visit_percentage tends to improve the accuracy of the final results. "
141+
+ "If visit_percentage is set for bbq_disk, num_candidates is ignored. "
142+
+ "Defaults to ~1% per shard for every 1 million vectors."
120143
),
121144
@MapParam.MapParamEntry(
122145
name = "similarity",
@@ -146,12 +169,12 @@ public Knn(
146169
Expression field,
147170
Expression query,
148171
Expression options,
149-
Integer k,
172+
Integer implicitK,
150173
QueryBuilder queryBuilder,
151174
List<Expression> filterExpressions
152175
) {
153176
super(source, field, query, options, expressionList(field, query, options), queryBuilder);
154-
this.k = k;
177+
this.implicitK = implicitK;
155178
this.filterExpressions = filterExpressions;
156179
}
157180

@@ -165,15 +188,15 @@ private static List<Expression> expressionList(Expression field, Expression quer
165188
return result;
166189
}
167190

168-
public Integer k() {
169-
return k;
191+
public Integer implicitK() {
192+
return implicitK;
170193
}
171194

172195
public List<Expression> filterExpressions() {
173196
return filterExpressions;
174197
}
175198

176-
public Knn replaceK(Integer k) {
199+
public Knn withImplicitK(Integer k) {
177200
Check.notNull(k, "k must not be null");
178201
return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
179202
}
@@ -191,7 +214,7 @@ public List<Number> queryAsObject() {
191214

192215
@Override
193216
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
194-
return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions());
217+
return new Knn(source(), field(), query(), options(), implicitK(), queryBuilder, filterExpressions());
195218
}
196219

197220
@Override
@@ -207,7 +230,7 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
207230

208231
@Override
209232
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
210-
assert k() != null : "Knn function must have a k value set before translation";
233+
assert implicitK() != null : "Knn function must have a k value set before translation";
211234
var fieldAttribute = fieldAsFieldAttribute(field());
212235

213236
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
@@ -226,7 +249,10 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
226249
}
227250
}
228251

229-
return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries);
252+
Map<String, Object> options = queryOptions();
253+
Integer explicitK = (Integer) options.get(K_FIELD.getPreferredName());
254+
255+
return new KnnQuery(source(), fieldName, queryAsFloats, explicitK != null ? explicitK : implicitK(), options, filterQueries);
230256
}
231257

232258
private float[] queryAsFloats() {
@@ -239,7 +265,7 @@ private float[] queryAsFloats() {
239265
}
240266

241267
public Expression withFilters(List<Expression> filterExpressions) {
242-
return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions);
268+
return new Knn(source(), field(), query(), options(), implicitK(), queryBuilder(), filterExpressions);
243269
}
244270

245271
private Map<String, Object> queryOptions() throws InvalidArgumentException {
@@ -264,7 +290,7 @@ protected QueryBuilder evaluatorQueryBuilder() {
264290
@Override
265291
public void postOptimizationVerification(Failures failures) {
266292
// Check that a k has been set
267-
if (k() == null) {
293+
if (implicitK() == null) {
268294
failures.add(
269295
Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find")
270296
);
@@ -278,15 +304,15 @@ public Expression replaceChildren(List<Expression> newChildren) {
278304
newChildren.get(0),
279305
newChildren.get(1),
280306
newChildren.size() > 2 ? newChildren.get(2) : null,
281-
k(),
307+
implicitK(),
282308
queryBuilder(),
283309
filterExpressions()
284310
);
285311
}
286312

287313
@Override
288314
protected NodeInfo<? extends Expression> info() {
289-
return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions());
315+
return NodeInfo.create(this, Knn::new, field(), query(), options(), implicitK(), queryBuilder(), filterExpressions());
290316
}
291317

292318
@Override
@@ -334,12 +360,14 @@ public boolean equals(Object o) {
334360
// ignore options when comparing two Knn functions
335361
if (o == null || getClass() != o.getClass()) return false;
336362
Knn knn = (Knn) o;
337-
return super.equals(knn) && Objects.equals(k(), knn.k()) && Objects.equals(filterExpressions(), knn.filterExpressions());
363+
return super.equals(knn)
364+
&& Objects.equals(implicitK(), knn.implicitK())
365+
&& Objects.equals(filterExpressions(), knn.filterExpressions());
338366
}
339367

340368
@Override
341369
public int hashCode() {
342-
return Objects.hash(field(), query(), queryBuilder(), k(), filterExpressions());
370+
return Objects.hash(field(), query(), queryBuilder(), implicitK(), filterExpressions());
343371
}
344372

345373
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
6161
private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
6262
return condition.transformDown(exp -> {
6363
if (exp instanceof Knn knn) {
64-
return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx()));
64+
return knn.withImplicitK((Integer) limit.limit().fold(ctx.foldCtx()));
6565
}
6666
return exp;
6767
});

0 commit comments

Comments
 (0)