Skip to content

Commit 0fb162a

Browse files
committed
KNN k is set via optimizer and limit
1 parent 82779ec commit 0fb162a

File tree

8 files changed

+172
-56
lines changed

8 files changed

+172
-56
lines changed

docs/reference/query-languages/esql/images/functions/knn.svg

Lines changed: 1 addition & 1 deletion
Loading

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
490490
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
491491
def(Score.class, uni(Score::new), Score.NAME),
492492
def(Term.class, bi(Term::new), "term"),
493-
def(Knn.class, quad(Knn::new), "knn"),
493+
def(Knn.class, tri(Knn::new), "knn"),
494494
def(StGeohash.class, StGeohash::new, "st_geohash"),
495495
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
496496
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.vector;
99

10-
import org.apache.logging.log4j.LogManager;
11-
import org.apache.logging.log4j.Logger;
1210
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1311
import org.elasticsearch.common.io.stream.StreamInput;
1412
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -56,14 +54,12 @@
5654
import static java.util.Map.entry;
5755
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
5856
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
59-
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
6057
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
6158
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
6259
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
6360
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH;
6461
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
6562
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
66-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
6763
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
6864
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
6965
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
@@ -73,13 +69,12 @@
7369
import static org.elasticsearch.xpack.esql.expression.function.FunctionUtils.resolveTypeQuery;
7470

7571
public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware {
76-
private final Logger log = LogManager.getLogger(getClass());
7772

7873
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
7974

8075
private final Expression field;
8176
// k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
82-
private final transient Expression k;
77+
private final transient Integer k;
8378
private final Expression options;
8479
// Expressions to be used as prefilters in knn query
8580
private final List<Expression> filterExpressions;
@@ -107,13 +102,6 @@ public Knn(
107102
type = { "dense_vector" },
108103
description = "Vector value to find top nearest neighbours for."
109104
) Expression query,
110-
@Param(
111-
name = "k",
112-
type = { "integer" },
113-
description = "The number of nearest neighbors to return from each shard. "
114-
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
115-
+ "This value must be less than or equal to num_candidates."
116-
) Expression k,
117105
@MapParam(
118106
name = "options",
119107
params = {
@@ -125,12 +113,13 @@ public Knn(
125113
+ "Defaults to 1.0."
126114
),
127115
@MapParam.MapParamEntry(
128-
name = "num_candidates",
116+
name = "min_candidates",
129117
type = "integer",
130118
valueHint = { "10" },
131-
description = "The number of nearest neighbor candidates to consider per shard while doing knn search. "
132-
+ "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. "
133-
+ "Defaults to 1.5 * k"
119+
description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. " +
120+
" KNN may use a higher number of candidates in case the query can't use a approximate results. "
121+
+ "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
122+
+ "Defaults to 1.5 * LIMIT used for the query."
134123
),
135124
@MapParam.MapParamEntry(
136125
name = "similarity",
@@ -152,32 +141,29 @@ public Knn(
152141
optional = true
153142
) Expression options
154143
) {
155-
this(source, field, query, k, options, null, List.of());
144+
this(source, field, query, options, null, null, List.of());
156145
}
157146

158147
public Knn(
159148
Source source,
160149
Expression field,
161150
Expression query,
162-
Expression k,
163151
Expression options,
152+
Integer k,
164153
QueryBuilder queryBuilder,
165154
List<Expression> filterExpressions
166155
) {
167-
super(source, query, expressionList(field, query, k, options), queryBuilder);
156+
super(source, query, expressionList(field, query, options), queryBuilder);
168157
this.field = field;
169158
this.k = k;
170159
this.options = options;
171160
this.filterExpressions = filterExpressions;
172161
}
173162

174-
private static List<Expression> expressionList(Expression field, Expression query, Expression k, Expression options) {
163+
private static List<Expression> expressionList(Expression field, Expression query, Expression options) {
175164
List<Expression> result = new ArrayList<>();
176165
result.add(field);
177166
result.add(query);
178-
if (k != null) {
179-
result.add(k);
180-
}
181167
if (options != null) {
182168
result.add(options);
183169
}
@@ -188,7 +174,7 @@ public Expression field() {
188174
return field;
189175
}
190176

191-
public Expression k() {
177+
public Integer k() {
192178
return k;
193179
}
194180

@@ -207,7 +193,7 @@ public DataType dataType() {
207193

208194
@Override
209195
protected TypeResolution resolveParams() {
210-
return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS));
196+
return resolveField().and(resolveQuery()).and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS));
211197
}
212198

213199
private TypeResolution resolveField() {
@@ -227,14 +213,9 @@ private TypeResolution resolveQuery() {
227213
return TypeResolution.TYPE_RESOLVED;
228214
}
229215

230-
private TypeResolution resolveK() {
231-
if (k == null) {
232-
// Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed
233-
return TypeResolution.TYPE_RESOLVED;
234-
}
235-
236-
return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD))
237-
.and(isNotNull(k(), sourceText(), THIRD));
216+
public Knn replaceK(Integer k) {
217+
Check.notNull(k, "k must not be null");
218+
return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
238219
}
239220

240221
public List<Number> queryAsObject() {
@@ -248,16 +229,9 @@ public List<Number> queryAsObject() {
248229
throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query));
249230
}
250231

251-
int getKIntValue() {
252-
if (k() instanceof Literal literal) {
253-
return (int) (Number) literal.value();
254-
}
255-
throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
256-
}
257-
258232
@Override
259233
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
260-
return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions());
234+
return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions());
261235
}
262236

263237
@Override
@@ -273,15 +247,12 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
273247

274248
@Override
275249
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
250+
assert k() != null : "Knn function must have a k value set before translation";
276251
var fieldAttribute = Match.fieldAsFieldAttribute(field());
277252

278253
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
279254
String fieldName = getNameFromFieldAttribute(fieldAttribute);
280255
float[] queryAsFloats = queryAsFloats();
281-
int kValue = getKIntValue();
282-
283-
Map<String, Object> opts = queryOptions();
284-
opts.put(K_FIELD.getPreferredName(), kValue);
285256

286257
List<QueryBuilder> filterQueries = new ArrayList<>();
287258
for (Expression filterExpression : filterExpressions()) {
@@ -295,7 +266,7 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
295266
}
296267
}
297268

298-
return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries);
269+
return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries);
299270
}
300271

301272
private float[] queryAsFloats() {
@@ -308,7 +279,7 @@ private float[] queryAsFloats() {
308279
}
309280

310281
public Expression withFilters(List<Expression> filterExpressions) {
311-
return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions);
282+
return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions);
312283
}
313284

314285
private Map<String, Object> queryOptions() throws InvalidArgumentException {
@@ -343,16 +314,16 @@ public Expression replaceChildren(List<Expression> newChildren) {
343314
source(),
344315
newChildren.get(0),
345316
newChildren.get(1),
346-
newChildren.get(2),
347-
newChildren.size() > 3 ? newChildren.get(3) : null,
317+
newChildren.size() > 2 ? newChildren.get(2) : null,
318+
k(),
348319
queryBuilder(),
349320
filterExpressions()
350321
);
351322
}
352323

353324
@Override
354325
protected NodeInfo<? extends Expression> info() {
355-
return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions());
326+
return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions());
356327
}
357328

358329
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan;
4545
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownJoinPastProject;
4646
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract;
47+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushLimitToKnn;
4748
import org.elasticsearch.xpack.esql.optimizer.rules.logical.RemoveStatsOverride;
4849
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateAggExpressionWithEval;
4950
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateNestedExpressionWithEval;
@@ -192,6 +193,7 @@ protected static Batch<LogicalPlan> operators(boolean local) {
192193
new PruneColumns(),
193194
new PruneLiteralsInOrderBy(),
194195
new PushDownAndCombineLimits(),
196+
new PushLimitToKnn(),
195197
new PushDownAndCombineFilters(),
196198
new PushDownConjunctionsToKnnPrefilters(),
197199
new PushDownAndCombineSample(),
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Expression;
11+
import org.elasticsearch.xpack.esql.core.util.Holder;
12+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
13+
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
14+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
15+
import org.elasticsearch.xpack.esql.plan.logical.Filter;
16+
import org.elasticsearch.xpack.esql.plan.logical.Limit;
17+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
18+
import org.elasticsearch.xpack.esql.plan.logical.TopN;
19+
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
20+
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
21+
22+
/**
23+
* Traverses the logical plan and pushes down the limit to the KNN function(s) in filter expressions, so KNN can use
24+
* it as the value of k.
25+
*/
26+
public class PushLimitToKnn extends OptimizerRules.ParameterizedOptimizerRule<Limit, LogicalOptimizerContext> {
27+
28+
public PushLimitToKnn() {
29+
super(OptimizerRules.TransformDirection.DOWN);
30+
}
31+
32+
@Override
33+
public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
34+
Holder<Boolean> breakerReached = new Holder<>(false);
35+
Holder<Boolean> firstLimit = new Holder<>(false);
36+
return limit.transformDown(plan -> {
37+
if (breakerReached.get()) {
38+
// We reached a breaker and don't want to continue processing
39+
return plan;
40+
}
41+
if (plan instanceof Filter filter) {
42+
Expression limitAppliedExpression = limitFilterExpressions(filter.condition(), limit, ctx);
43+
if (limitAppliedExpression.equals(filter.condition()) == false) {
44+
return filter.with(limitAppliedExpression);
45+
}
46+
} else if (plan instanceof Limit) {
47+
// Break if it's not the initial limit
48+
breakerReached.set(firstLimit.get());
49+
firstLimit.set(true);
50+
} else if (plan instanceof TopN || plan instanceof Rerank || plan instanceof Aggregate) {
51+
breakerReached.set(true);
52+
}
53+
54+
return plan;
55+
});
56+
}
57+
58+
/**
59+
* Applies a limit to the filter expressions of a condition. Some filter expressions, such as the KNN function,
60+
* can be optimized by applying the limit directly to them.
61+
*/
62+
private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
63+
return condition.transformDown(exp -> {
64+
if (exp instanceof Knn knn) {
65+
return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx()));
66+
}
67+
return exp;
68+
});
69+
}
70+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ public class KnnQuery extends Query {
3232
private final List<QueryBuilder> filterQueries;
3333

3434
public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
35+
private final Integer k;
3536

36-
public KnnQuery(Source source, String field, float[] query, Map<String, Object> options, List<QueryBuilder> filterQueries) {
37+
public KnnQuery(Source source, String field, float[] query, Integer k, Map<String, Object> options, List<QueryBuilder> filterQueries) {
3738
super(source);
39+
this.k = k;
3840
assert options != null;
3941
this.field = field;
4042
this.query = query;
@@ -44,7 +46,6 @@ public KnnQuery(Source source, String field, float[] query, Map<String, Object>
4446

4547
@Override
4648
protected QueryBuilder asBuilder() {
47-
Integer k = (Integer) options.get(K_FIELD.getPreferredName());
4849
Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName());
4950
RescoreVectorBuilder rescoreVectorBuilder = null;
5051
Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ private static List<TestCaseSupplier> addFunctionNamedParams(List<TestCaseSuppli
121121

122122
@Override
123123
protected Expression build(Source source, List<Expression> args) {
124-
Knn knn = new Knn(source, args.get(0), args.get(1), args.get(2), args.size() > 3 ? args.get(3) : null);
124+
Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
125125
// We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and
126126
// thus test the serialization methods. But we can only do this if the parameters make sense.
127127
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {

0 commit comments

Comments
 (0)