Skip to content

Commit d25776a

Browse files
committed
Add k as a param
1 parent 9b2ca99 commit d25776a

File tree

3 files changed

+51
-29
lines changed

3 files changed

+51
-29
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ static TransportVersion def(int id) {
306306
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);
307307
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
308308
public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
309+
public static final TransportVersion ESQL_KNN_K_PARAM_MANDATORY = def(9_105_0_00);
309310

310311
/*
311312
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
487487
def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"),
488488
def(FirstOverTime.class, FirstOverTime::withUnresolvedTimestamp, "first_over_time"),
489489
def(Term.class, bi(Term::new), "term"),
490-
def(Knn.class, tri(Knn::new), "knn") } };
490+
def(Knn.class, Knn::new, "knn") } };
491491
}
492492

493493
public EsqlFunctionRegistry snapshotRegistry() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.vector;
99

10+
import org.elasticsearch.TransportVersions;
1011
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1112
import org.elasticsearch.common.io.stream.StreamInput;
1213
import org.elasticsearch.common.io.stream.StreamOutput;
@@ -41,6 +42,7 @@
4142
import java.util.Objects;
4243

4344
import static java.util.Map.entry;
45+
import static org.elasticsearch.TransportVersions.ESQL_KNN_K_PARAM_MANDATORY;
4446
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
4547
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
4648
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
@@ -62,10 +64,10 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
6264
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
6365

6466
private final Expression field;
67+
private final Expression k;
6568
private final Expression options;
6669

6770
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
68-
entry(K_FIELD.getPreferredName(), INTEGER),
6971
entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
7072
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
7173
entry(BOOST_FIELD.getPreferredName(), FLOAT),
@@ -90,6 +92,13 @@ public Knn(
9092
type = { "dense_vector" },
9193
description = "Vector value to find top nearest neighbours for."
9294
) Expression query,
95+
@Param(
96+
name = "k",
97+
type = { "integer" },
98+
description = "The number of nearest neighbors to return from each shard. "
99+
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
100+
+ "This value must be less than or equal to num_candidates."
101+
) Expression k,
93102
@MapParam(
94103
name = "options",
95104
params = {
@@ -100,14 +109,6 @@ public Knn(
100109
description = "Floating point number used to decrease or increase the relevance scores of the query."
101110
+ "Defaults to 1.0."
102111
),
103-
@MapParam.MapParamEntry(
104-
name = "k",
105-
type = "integer",
106-
valueHint = { "10" },
107-
description = "The number of nearest neighbors to return from each shard. "
108-
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
109-
+ "This value must be less than or equal to num_candidates. Defaults to 10."
110-
),
111112
@MapParam.MapParamEntry(
112113
name = "num_candidates",
113114
type = "integer",
@@ -136,19 +137,24 @@ public Knn(
136137
optional = true
137138
) Expression options
138139
) {
139-
this(source, field, query, options, null);
140+
this(source, field, query, k, options, null);
140141
}
141142

142-
private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) {
143-
super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder);
143+
private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) {
144+
super(source, query, options == null ? List.of(field, query, k) : List.of(field, query, k, options), queryBuilder);
144145
this.field = field;
146+
this.k = k;
145147
this.options = options;
146148
}
147149

148150
public Expression field() {
149151
return field;
150152
}
151153

154+
public Expression k() {
155+
return k;
156+
}
157+
152158
public Expression options() {
153159
return options;
154160
}
@@ -160,7 +166,7 @@ public DataType dataType() {
160166

161167
@Override
162168
protected TypeResolution resolveParams() {
163-
return resolveField().and(resolveQuery()).and(resolveOptions());
169+
return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions());
164170
}
165171

166172
private TypeResolution resolveField() {
@@ -173,14 +179,19 @@ private TypeResolution resolveQuery() {
173179
);
174180
}
175181

182+
private TypeResolution resolveK() {
183+
return isNotNull(k(), sourceText(), TypeResolutions.ParamOrdinal.THIRD)
184+
.and(isType(k(), dt -> dt == INTEGER, sourceText(), TypeResolutions.ParamOrdinal.THIRD, "integer"));
185+
}
186+
176187
private TypeResolution resolveOptions() {
177188
if (options() != null) {
178-
TypeResolution resolution = isNotNull(options(), sourceText(), THIRD);
189+
TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
179190
if (resolution.unresolved()) {
180191
return resolution;
181192
}
182193
// MapExpression does not have a DataType associated with it
183-
resolution = isMapExpression(options(), sourceText(), THIRD);
194+
resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
184195
if (resolution.unresolved()) {
185196
return resolution;
186197
}
@@ -200,7 +211,7 @@ private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
200211
}
201212

202213
Map<String, Object> matchOptions = new HashMap<>();
203-
populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS);
214+
populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
204215
return matchOptions;
205216
}
206217

@@ -216,22 +227,24 @@ protected Query translate(TranslatorHandler handler) {
216227
for (int i = 0; i < queryFolded.size(); i++) {
217228
queryAsFloats[i] = queryFolded.get(i).floatValue();
218229
}
230+
int kValue = ((Number) k().fold(FoldContext.small())).intValue();
231+
232+
Map<String, Object> opts = queryOptions();
233+
opts.put(K_FIELD.getPreferredName(), kValue);
219234

220-
return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
235+
return new KnnQuery(source(), fieldName, queryAsFloats, opts);
221236
}
222237

223238
@Override
224239
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
225-
return new Knn(source(), field(), query(), options(), queryBuilder);
240+
return new Knn(source(), field(), query(), k(), options(), queryBuilder);
226241
}
227242

228243
private Map<String, Object> queryOptions() throws InvalidArgumentException {
229-
if (options() == null) {
230-
return Map.of();
231-
}
232-
233244
Map<String, Object> options = new HashMap<>();
234-
populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS);
245+
if (options() != null) {
246+
populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
247+
}
235248
return options;
236249
}
237250

@@ -241,14 +254,15 @@ public Expression replaceChildren(List<Expression> newChildren) {
241254
source(),
242255
newChildren.get(0),
243256
newChildren.get(1),
244-
newChildren.size() > 2 ? newChildren.get(2) : null,
257+
newChildren.get(2),
258+
newChildren.size() > 3 ? newChildren.get(3) : null,
245259
queryBuilder()
246260
);
247261
}
248262

249263
@Override
250264
protected NodeInfo<? extends Expression> info() {
251-
return NodeInfo.create(this, Knn::new, field(), query(), options());
265+
return NodeInfo.create(this, Knn::new, field(), query(), k(), options());
252266
}
253267

254268
@Override
@@ -261,8 +275,11 @@ private static Knn readFrom(StreamInput in) throws IOException {
261275
Expression field = in.readNamedWriteable(Expression.class);
262276
Expression query = in.readNamedWriteable(Expression.class);
263277
QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
264-
265-
return new Knn(source, field, query, null, queryBuilder);
278+
Expression k = null;
279+
if (in.getTransportVersion().onOrAfter(ESQL_KNN_K_PARAM_MANDATORY)) {
280+
k = in.readNamedWriteable(Expression.class);
281+
}
282+
return new Knn(source, field, query, k, null, queryBuilder);
266283
}
267284

268285
@Override
@@ -271,6 +288,9 @@ public void writeTo(StreamOutput out) throws IOException {
271288
out.writeNamedWriteable(field());
272289
out.writeNamedWriteable(query());
273290
out.writeOptionalNamedWriteable(queryBuilder());
291+
if (out.getTransportVersion().onOrAfter(ESQL_KNN_K_PARAM_MANDATORY)) {
292+
out.writeNamedWriteable(k());
293+
}
274294
}
275295

276296
@Override
@@ -281,12 +301,13 @@ public boolean equals(Object o) {
281301
Knn knn = (Knn) o;
282302
return Objects.equals(field(), knn.field())
283303
&& Objects.equals(query(), knn.query())
304+
&& Objects.equals(k(), knn.k())
284305
&& Objects.equals(queryBuilder(), knn.queryBuilder());
285306
}
286307

287308
@Override
288309
public int hashCode() {
289-
return Objects.hash(field(), query(), queryBuilder());
310+
return Objects.hash(field(), query(), k(), queryBuilder());
290311
}
291312

292313
}

0 commit comments

Comments
 (0)