4646import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .NUM_CANDS_FIELD ;
4747import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .VECTOR_SIMILARITY_FIELD ;
4848import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
49+ import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .SECOND ;
4950import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .THIRD ;
51+ import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isMapExpression ;
5052import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isNotNull ;
53+ import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isNotNullAndFoldable ;
5154import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isType ;
5255import static org .elasticsearch .xpack .esql .core .type .DataType .DENSE_VECTOR ;
5356import static org .elasticsearch .xpack .esql .core .type .DataType .FLOAT ;
@@ -72,8 +75,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
7275 @ FunctionInfo (
7376 returnType = "boolean" ,
7477 preview = true ,
75- description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. " +
76- "knn function finds nearest vectors through approximate search on indexed dense_vectors." ,
78+ description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. "
79+ + "knn function finds nearest vectors through approximate search on indexed dense_vectors." ,
7780 examples = {
7881 @ Example (file = "knn-function" , tag = "knn-function" ),
7982 @ Example (file = "knn-function" , tag = "knn-function-options" ), },
@@ -156,12 +159,48 @@ public DataType dataType() {
156159
157160 @ Override
158161 protected TypeResolution resolveParams () {
159- if (childrenResolved () == false ) {
160- return new TypeResolution ("Unresolved children" );
162+ return resolveField ().and (resolveQuery ()).and (resolveOptions ());
163+ }
164+
165+ private TypeResolution resolveField () {
166+ return isNotNull (field (), sourceText (), FIRST ).and (isType (field (), dt -> dt == DENSE_VECTOR , sourceText (), FIRST , "dense_vector" ));
167+ }
168+
169+ private TypeResolution resolveQuery () {
170+ return isType (query (), dt -> dt == DENSE_VECTOR , sourceText (), TypeResolutions .ParamOrdinal .SECOND , "dense_vector" ).and (
171+ isNotNullAndFoldable (query (), sourceText (), SECOND )
172+ );
173+ }
174+
175+ private TypeResolution resolveOptions () {
176+ if (options () != null ) {
177+ TypeResolution resolution = isNotNull (options (), sourceText (), THIRD );
178+ if (resolution .unresolved ()) {
179+ return resolution ;
180+ }
181+ // MapExpression does not have a DataType associated with it
182+ resolution = isMapExpression (options (), sourceText (), THIRD );
183+ if (resolution .unresolved ()) {
184+ return resolution ;
185+ }
186+
187+ try {
188+ knnQueryOptions ();
189+ } catch (InvalidArgumentException e ) {
190+ return new TypeResolution (e .getMessage ());
191+ }
192+ }
193+ return TypeResolution .TYPE_RESOLVED ;
194+ }
195+
196+ private Map <String , Object > knnQueryOptions () throws InvalidArgumentException {
197+ if (options () == null ) {
198+ return Map .of ();
161199 }
162200
163- return isNotNull (field (), sourceText (), FIRST ).and (isType (field (), dt -> dt == DENSE_VECTOR , sourceText (), FIRST , "dense_vector" ))
164- .and (isType (query (), dt -> dt == DENSE_VECTOR , sourceText (), TypeResolutions .ParamOrdinal .SECOND , "dense_vector" ));
201+ Map <String , Object > matchOptions = new HashMap <>();
202+ populateOptionsMap ((MapExpression ) options (), matchOptions , THIRD , sourceText (), ALLOWED_OPTIONS );
203+ return matchOptions ;
165204 }
166205
167206 @ Override
@@ -240,8 +279,8 @@ public boolean equals(Object o) {
240279 if (o == null || getClass () != o .getClass ()) return false ;
241280 Knn knn = (Knn ) o ;
242281 return Objects .equals (field (), knn .field ())
243- && Objects .equals (query (), knn .query ())
244- && Objects .equals (queryBuilder (), knn .queryBuilder ());
282+ && Objects .equals (query (), knn .query ())
283+ && Objects .equals (queryBuilder (), knn .queryBuilder ());
245284 }
246285
247286 @ Override
0 commit comments