Skip to content

Commit 9588048

Browse files
committed
Add options
1 parent 8317911 commit 9588048

File tree

5 files changed

+94
-15
lines changed

5 files changed

+94
-15
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
2828

2929
private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
3030

31-
public void testKnn() {
31+
public void testKnnDefaults() {
3232
var query = """
3333
FROM test METADATA _score
3434
| WHERE knn(vector, [1.0, 1.0, 1.0])
@@ -61,6 +61,23 @@ public void testKnn() {
6161
}
6262
}
6363

64+
public void testKnnOptions() {
65+
var query = """
66+
FROM test METADATA _score
67+
| WHERE knn(vector, [1.0, 1.0, 1.0], {"k": 5})
68+
| KEEP id, floats, _score, vector
69+
| SORT _score DESC
70+
""";
71+
72+
try (var resp = run(query)) {
73+
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
74+
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
75+
76+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
77+
assertEquals(5, valuesList.size());
78+
}
79+
}
80+
6481
@Before
6582
public void setup() throws IOException {
6683
var indexName = "test";

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
448448
def(AvgOverTime.class, uni(AvgOverTime::new), "avg_over_time"),
449449
def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"),
450450
def(Term.class, bi(Term::new), "term"),
451-
def(Knn.class, bi(Knn::new), "knn") } };
451+
def(Knn.class, tri(Knn::new), "knn") } };
452452
}
453453

454454
public EsqlFunctionRegistry snapshotRegistry() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
333333
return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
334334
}
335335

336-
protected static void populateOptionsMap(
336+
public static void populateOptionsMap(
337337
final MapExpression options,
338338
final Map<String, Object> optionsMap,
339339
final TypeResolutions.ParamOrdinal paramOrdinal,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
14+
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
1415
import org.elasticsearch.xpack.esql.core.expression.Expression;
1516
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
17+
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
1618
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
1719
import org.elasticsearch.xpack.esql.core.expression.function.Function;
1820
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
@@ -23,28 +25,51 @@
2325
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
2426
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
2527
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
28+
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
2629
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
2730
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2831
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
2932
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3033
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
3134

3235
import java.io.IOException;
36+
import java.util.HashMap;
3337
import java.util.List;
38+
import java.util.Map;
3439
import java.util.Objects;
3540

41+
import static java.util.Map.entry;
42+
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
43+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
44+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
45+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
46+
import static org.elasticsearch.search.vectors.RescoreVectorBuilder.OVERSAMPLE_FIELD;
3647
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
48+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
3749
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
3850
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
3951
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
52+
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
53+
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
54+
import static org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction.populateOptionsMap;
4055
import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute;
4156

42-
public class Knn extends Function implements TranslationAware {
57+
public class Knn extends Function implements TranslationAware, OptionalArgument {
4358

4459
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
4560

4661
private final Expression field;
4762
private final Expression query;
63+
// TODO Options could be serialized via QueryBuilder in case we want to rewrite it in the coordinator node (for query text inference)
64+
private final Expression options;
65+
66+
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
67+
entry(K_FIELD.getPreferredName(), INTEGER),
68+
entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
69+
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
70+
entry(BOOST_FIELD.getPreferredName(), FLOAT),
71+
entry(OVERSAMPLE_FIELD.getPreferredName(), FLOAT)
72+
);
4873

4974
@FunctionInfo(
5075
returnType = "boolean",
@@ -58,10 +83,11 @@ public class Knn extends Function implements TranslationAware {
5883
lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT
5984
) }
6085
)
61-
public Knn(Source source, Expression field, Expression query) {
62-
super(source, List.of(field, query));
86+
public Knn(Source source, Expression field, Expression query, Expression options) {
87+
super(source, options == null ? List.of(field, query) : List.of(field, query, options));
6388
this.field = field;
6489
this.query = query;
90+
this.options = options;
6591
}
6692

6793
public Expression field() {
@@ -72,6 +98,10 @@ public Expression query() {
7298
return query;
7399
}
74100

101+
public Expression options() {
102+
return options;
103+
}
104+
75105
@Override
76106
public DataType dataType() {
77107
return DataType.BOOLEAN;
@@ -104,17 +134,28 @@ public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHand
104134
for (int i = 0; i < queryFolded.size(); i++) {
105135
queryAsFloats[i] = queryFolded.get(i).floatValue();
106136
}
107-
return new KnnQuery(source(), fieldName, queryAsFloats);
137+
138+
return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
139+
}
140+
141+
private Map<String, Object> queryOptions() throws InvalidArgumentException {
142+
if (options() == null) {
143+
return Map.of();
144+
}
145+
146+
Map<String, Object> options = new HashMap<>();
147+
populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS);
148+
return options;
108149
}
109150

110151
@Override
111152
public Expression replaceChildren(List<Expression> newChildren) {
112-
return new Knn(source(), newChildren.get(0), newChildren.get(1));
153+
return new Knn(source(), newChildren.get(0), newChildren.get(1), newChildren.size() > 2 ? newChildren.get(2) : null);
113154
}
114155

115156
@Override
116157
protected NodeInfo<? extends Expression> info() {
117-
return NodeInfo.create(this, Knn::new, field(), query());
158+
return NodeInfo.create(this, Knn::new, field(), query(), options());
118159
}
119160

120161
@Override
@@ -126,15 +167,17 @@ private static Knn readFrom(StreamInput in) throws IOException {
126167
Source source = Source.readFrom((PlanStreamInput) in);
127168
Expression field = in.readNamedWriteable(Expression.class);
128169
Expression query = in.readNamedWriteable(Expression.class);
170+
Expression options = in.readOptionalNamedWriteable(Expression.class);
129171

130-
return new Knn(source, field, query);
172+
return new Knn(source, field, query, options);
131173
}
132174

133175
@Override
134176
public void writeTo(StreamOutput out) throws IOException {
135177
source().writeTo(out);
136178
out.writeNamedWriteable(field());
137179
out.writeNamedWriteable(query());
180+
out.writeOptionalNamedWriteable(options());
138181
}
139182

140183
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,62 @@
99

1010
import org.elasticsearch.index.query.QueryBuilder;
1111
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
12+
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
1213
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
1314
import org.elasticsearch.xpack.esql.core.tree.Source;
1415

1516
import java.util.Arrays;
17+
import java.util.Map;
1618
import java.util.Objects;
1719

20+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
21+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
22+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
23+
1824
public class KnnQuery extends Query {
1925

2026
private final String field;
2127
private final float[] query;
28+
private final Map<String, Object> options;
2229

23-
public KnnQuery(Source source, String field, float[] query) {
30+
public KnnQuery(Source source, String field, float[] query, Map<String, Object> options) {
2431
super(source);
32+
assert options != null;
2533
this.field = field;
2634
this.query = query;
35+
this.options = options;
2736
}
2837

2938
@Override
3039
protected QueryBuilder asBuilder() {
31-
return new KnnVectorQueryBuilder(field, query, null, null, null, null);
40+
Integer k = (Integer) options.get(K_FIELD.getPreferredName());
41+
Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName());
42+
RescoreVectorBuilder rescoreVectorBuilder = null;
43+
Float oversample = (Float) options.get(RescoreVectorBuilder.OVERSAMPLE_FIELD.getPreferredName());
44+
if (oversample != null) {
45+
rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
46+
}
47+
Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
48+
49+
return new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity);
3250
}
3351

3452
@Override
3553
protected String innerToString() {
36-
return "knn(" + field + ", " + Arrays.toString(query) + ")";
54+
return "knn(" + field + ", " + Arrays.toString(query) + " options={" + options + "}))";
3755
}
3856

3957
@Override
4058
public boolean equals(Object o) {
4159
if (!(o instanceof KnnQuery knnQuery)) return false;
4260
if (super.equals(o) == false) return false;
43-
return Objects.equals(field, knnQuery.field) && Objects.deepEquals(query, knnQuery.query);
61+
return Objects.equals(field, knnQuery.field)
62+
&& Objects.deepEquals(query, knnQuery.query) && Objects.equals(options, knnQuery.options);
4463
}
4564

4665
@Override
4766
public int hashCode() {
48-
return Objects.hash(super.hashCode(), field, Arrays.hashCode(query));
67+
return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options);
4968
}
5069

5170
@Override

0 commit comments

Comments
 (0)