Skip to content

Commit 66f8496

Browse files
committed
Add verifier tests
1 parent b352673 commit 66f8496

File tree

2 files changed

+86
-206
lines changed
  • x-pack/plugin/esql/src
    • main/java/org/elasticsearch/xpack/esql/expression/function/vector
    • test/java/org/elasticsearch/xpack/esql/analysis

2 files changed

+86
-206
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@
4646
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
4747
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
4848
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
49+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
4950
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
51+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
5052
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
53+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
5154
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
5255
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
5356
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
@@ -72,8 +75,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
7275
@FunctionInfo(
7376
returnType = "boolean",
7477
preview = true,
75-
description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. " +
76-
"knn function finds nearest vectors through approximate search on indexed dense_vectors.",
78+
description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. "
79+
+ "knn function finds nearest vectors through approximate search on indexed dense_vectors.",
7780
examples = {
7881
@Example(file = "knn-function", tag = "knn-function"),
7982
@Example(file = "knn-function", tag = "knn-function-options"), },
@@ -156,12 +159,48 @@ public DataType dataType() {
156159

157160
@Override
158161
protected TypeResolution resolveParams() {
159-
if (childrenResolved() == false) {
160-
return new TypeResolution("Unresolved children");
162+
return resolveField().and(resolveQuery()).and(resolveOptions());
163+
}
164+
165+
private TypeResolution resolveField() {
166+
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
167+
}
168+
169+
private TypeResolution resolveQuery() {
170+
return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and(
171+
isNotNullAndFoldable(query(), sourceText(), SECOND)
172+
);
173+
}
174+
175+
private TypeResolution resolveOptions() {
176+
if (options() != null) {
177+
TypeResolution resolution = isNotNull(options(), sourceText(), THIRD);
178+
if (resolution.unresolved()) {
179+
return resolution;
180+
}
181+
// MapExpression does not have a DataType associated with it
182+
resolution = isMapExpression(options(), sourceText(), THIRD);
183+
if (resolution.unresolved()) {
184+
return resolution;
185+
}
186+
187+
try {
188+
knnQueryOptions();
189+
} catch (InvalidArgumentException e) {
190+
return new TypeResolution(e.getMessage());
191+
}
192+
}
193+
return TypeResolution.TYPE_RESOLVED;
194+
}
195+
196+
private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
197+
if (options() == null) {
198+
return Map.of();
161199
}
162200

163-
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"))
164-
.and(isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector"));
201+
Map<String, Object> matchOptions = new HashMap<>();
202+
populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS);
203+
return matchOptions;
165204
}
166205

167206
@Override
@@ -240,8 +279,8 @@ public boolean equals(Object o) {
240279
if (o == null || getClass() != o.getClass()) return false;
241280
Knn knn = (Knn) o;
242281
return Objects.equals(field(), knn.field())
243-
&& Objects.equals(query(), knn.query())
244-
&& Objects.equals(queryBuilder(), knn.queryBuilder());
282+
&& Objects.equals(query(), knn.query())
283+
&& Objects.equals(queryBuilder(), knn.queryBuilder());
245284
}
246285

247286
@Override

0 commit comments

Comments
 (0)