1111import org .elasticsearch .common .io .stream .StreamInput ;
1212import org .elasticsearch .common .io .stream .StreamOutput ;
1313import org .elasticsearch .xpack .esql .capabilities .TranslationAware ;
14+ import org .elasticsearch .xpack .esql .core .InvalidArgumentException ;
1415import org .elasticsearch .xpack .esql .core .expression .Expression ;
1516import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
17+ import org .elasticsearch .xpack .esql .core .expression .MapExpression ;
1618import org .elasticsearch .xpack .esql .core .expression .TypeResolutions ;
1719import org .elasticsearch .xpack .esql .core .expression .function .Function ;
1820import org .elasticsearch .xpack .esql .core .querydsl .query .Query ;
2325import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesTo ;
2426import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesToLifecycle ;
2527import org .elasticsearch .xpack .esql .expression .function .FunctionInfo ;
28+ import org .elasticsearch .xpack .esql .expression .function .OptionalArgument ;
2629import org .elasticsearch .xpack .esql .expression .function .fulltext .Match ;
2730import org .elasticsearch .xpack .esql .io .stream .PlanStreamInput ;
2831import org .elasticsearch .xpack .esql .optimizer .rules .physical .local .LucenePushdownPredicates ;
2932import org .elasticsearch .xpack .esql .planner .TranslatorHandler ;
3033import org .elasticsearch .xpack .esql .querydsl .query .KnnQuery ;
3134
3235import java .io .IOException ;
36+ import java .util .HashMap ;
3337import java .util .List ;
38+ import java .util .Map ;
3439import java .util .Objects ;
3540
41+ import static java .util .Map .entry ;
42+ import static org .elasticsearch .index .query .AbstractQueryBuilder .BOOST_FIELD ;
43+ import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .K_FIELD ;
44+ import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .NUM_CANDS_FIELD ;
45+ import static org .elasticsearch .search .vectors .KnnVectorQueryBuilder .VECTOR_SIMILARITY_FIELD ;
46+ import static org .elasticsearch .search .vectors .RescoreVectorBuilder .OVERSAMPLE_FIELD ;
3647import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
48+ import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .THIRD ;
3749import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isNotNull ;
3850import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isType ;
3951import static org .elasticsearch .xpack .esql .core .type .DataType .DENSE_VECTOR ;
52+ import static org .elasticsearch .xpack .esql .core .type .DataType .FLOAT ;
53+ import static org .elasticsearch .xpack .esql .core .type .DataType .INTEGER ;
54+ import static org .elasticsearch .xpack .esql .expression .function .fulltext .FullTextFunction .populateOptionsMap ;
4055import static org .elasticsearch .xpack .esql .expression .function .fulltext .Match .getNameFromFieldAttribute ;
4156
42- public class Knn extends Function implements TranslationAware {
57+ public class Knn extends Function implements TranslationAware , OptionalArgument {
4358
4459 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (Expression .class , "Knn" , Knn ::readFrom );
4560
4661 private final Expression field ;
4762 private final Expression query ;
63+ // TODO Options could be serialized via QueryBuilder in case we want to rewrite it in the coordinator node (for query text inference)
64+ private final Expression options ;
65+
66+ public static final Map <String , DataType > ALLOWED_OPTIONS = Map .ofEntries (
67+ entry (K_FIELD .getPreferredName (), INTEGER ),
68+ entry (NUM_CANDS_FIELD .getPreferredName (), INTEGER ),
69+ entry (VECTOR_SIMILARITY_FIELD .getPreferredName (), FLOAT ),
70+ entry (BOOST_FIELD .getPreferredName (), FLOAT ),
71+ entry (OVERSAMPLE_FIELD .getPreferredName (), FLOAT )
72+ );
4873
4974 @ FunctionInfo (
5075 returnType = "boolean" ,
@@ -58,10 +83,11 @@ public class Knn extends Function implements TranslationAware {
5883 lifeCycle = FunctionAppliesToLifecycle .DEVELOPMENT
5984 ) }
6085 )
61- public Knn (Source source , Expression field , Expression query ) {
62- super (source , List .of (field , query ));
86+ public Knn (Source source , Expression field , Expression query , Expression options ) {
87+ super (source , options == null ? List .of (field , query ) : List . of ( field , query , options ));
6388 this .field = field ;
6489 this .query = query ;
90+ this .options = options ;
6591 }
6692
6793 public Expression field () {
@@ -72,6 +98,10 @@ public Expression query() {
7298 return query ;
7399 }
74100
101+ public Expression options () {
102+ return options ;
103+ }
104+
75105 @ Override
76106 public DataType dataType () {
77107 return DataType .BOOLEAN ;
@@ -104,17 +134,28 @@ public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHand
104134 for (int i = 0 ; i < queryFolded .size (); i ++) {
105135 queryAsFloats [i ] = queryFolded .get (i ).floatValue ();
106136 }
107- return new KnnQuery (source (), fieldName , queryAsFloats );
137+
138+ return new KnnQuery (source (), fieldName , queryAsFloats , queryOptions ());
139+ }
140+
141+ private Map <String , Object > queryOptions () throws InvalidArgumentException {
142+ if (options () == null ) {
143+ return Map .of ();
144+ }
145+
146+ Map <String , Object > options = new HashMap <>();
147+ populateOptionsMap ((MapExpression ) options (), options , THIRD , sourceText (), ALLOWED_OPTIONS );
148+ return options ;
108149 }
109150
110151 @ Override
111152 public Expression replaceChildren (List <Expression > newChildren ) {
112- return new Knn (source (), newChildren .get (0 ), newChildren .get (1 ));
153+ return new Knn (source (), newChildren .get (0 ), newChildren .get (1 ), newChildren . size () > 2 ? newChildren . get ( 2 ) : null );
113154 }
114155
115156 @ Override
116157 protected NodeInfo <? extends Expression > info () {
117- return NodeInfo .create (this , Knn ::new , field (), query ());
158+ return NodeInfo .create (this , Knn ::new , field (), query (), options () );
118159 }
119160
120161 @ Override
@@ -126,15 +167,17 @@ private static Knn readFrom(StreamInput in) throws IOException {
126167 Source source = Source .readFrom ((PlanStreamInput ) in );
127168 Expression field = in .readNamedWriteable (Expression .class );
128169 Expression query = in .readNamedWriteable (Expression .class );
170+ Expression options = in .readOptionalNamedWriteable (Expression .class );
129171
130- return new Knn (source , field , query );
172+ return new Knn (source , field , query , options );
131173 }
132174
133175 @ Override
134176 public void writeTo (StreamOutput out ) throws IOException {
135177 source ().writeTo (out );
136178 out .writeNamedWriteable (field ());
137179 out .writeNamedWriteable (query ());
180+ out .writeOptionalNamedWriteable (options ());
138181 }
139182
140183 @ Override
0 commit comments