1010import org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
1111import org .elasticsearch .common .io .stream .StreamInput ;
1212import org .elasticsearch .common .io .stream .StreamOutput ;
13- import org .elasticsearch .xpack . esql . capabilities . TranslationAware ;
13+ import org .elasticsearch .index . query . QueryBuilder ;
1414import org .elasticsearch .xpack .esql .core .InvalidArgumentException ;
1515import org .elasticsearch .xpack .esql .core .expression .Expression ;
1616import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
1717import org .elasticsearch .xpack .esql .core .expression .MapExpression ;
1818import org .elasticsearch .xpack .esql .core .expression .TypeResolutions ;
19- import org .elasticsearch .xpack .esql .core .expression .function .Function ;
2019import org .elasticsearch .xpack .esql .core .querydsl .query .Query ;
2120import org .elasticsearch .xpack .esql .core .tree .NodeInfo ;
2221import org .elasticsearch .xpack .esql .core .tree .Source ;
2625import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesToLifecycle ;
2726import org .elasticsearch .xpack .esql .expression .function .FunctionInfo ;
2827import org .elasticsearch .xpack .esql .expression .function .OptionalArgument ;
28+ import org .elasticsearch .xpack .esql .expression .function .fulltext .FullTextFunction ;
2929import org .elasticsearch .xpack .esql .expression .function .fulltext .Match ;
3030import org .elasticsearch .xpack .esql .io .stream .PlanStreamInput ;
31- import org .elasticsearch .xpack .esql .optimizer .rules .physical .local .LucenePushdownPredicates ;
3231import org .elasticsearch .xpack .esql .planner .TranslatorHandler ;
3332import org .elasticsearch .xpack .esql .querydsl .query .KnnQuery ;
3433
5150import static org .elasticsearch .xpack .esql .core .type .DataType .DENSE_VECTOR ;
5251import static org .elasticsearch .xpack .esql .core .type .DataType .FLOAT ;
5352import static org .elasticsearch .xpack .esql .core .type .DataType .INTEGER ;
54- import static org .elasticsearch .xpack .esql .expression .function .fulltext .FullTextFunction .populateOptionsMap ;
5553import static org .elasticsearch .xpack .esql .expression .function .fulltext .Match .getNameFromFieldAttribute ;
5654
57- public class Knn extends Function implements TranslationAware , OptionalArgument {
55+ public class Knn extends FullTextFunction implements OptionalArgument {
5856
5957 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (Expression .class , "Knn" , Knn ::readFrom );
6058
6159 private final Expression field ;
62- private final Expression query ;
63- // TODO Options could be serialized via QueryBuilder in case we want to rewrite it in the coordinator node (for query text inference)
6460 private final Expression options ;
6561
6662 public static final Map <String , DataType > ALLOWED_OPTIONS = Map .ofEntries (
@@ -84,20 +80,20 @@ public class Knn extends Function implements TranslationAware, OptionalArgument
8480 ) }
8581 )
8682 public Knn (Source source , Expression field , Expression query , Expression options ) {
87- super (source , options == null ? List .of (field , query ) : List .of (field , query , options ));
83+ this (source , field , query , options , null );
84+ }
85+
86+ public Knn (Source source , Expression field , Expression query , Expression options , QueryBuilder queryBuilder ) {
87+ super (source , query , options == null ? List .of (field , query ) : List .of (field , query , options ), queryBuilder );
8888 this .field = field ;
89- this .query = query ;
9089 this .options = options ;
9190 }
9291
92+
9393 public Expression field () {
9494 return field ;
9595 }
9696
97- public Expression query () {
98- return query ;
99- }
100-
10197 public Expression options () {
10298 return options ;
10399 }
@@ -108,7 +104,7 @@ public DataType dataType() {
108104 }
109105
110106 @ Override
111- protected final TypeResolution resolveType () {
107+ protected TypeResolution resolveParams () {
112108 if (childrenResolved () == false ) {
113109 return new TypeResolution ("Unresolved children" );
114110 }
@@ -118,12 +114,7 @@ protected final TypeResolution resolveType() {
118114 }
119115
120116 @ Override
121- public boolean translatable (LucenePushdownPredicates pushdownPredicates ) {
122- return true ;
123- }
124-
125- @ Override
126- public Query asQuery (LucenePushdownPredicates pushdownPredicates , TranslatorHandler handler ) {
117+ protected Query translate (TranslatorHandler handler ) {
127118 var fieldAttribute = Match .fieldAsFieldAttribute (field ());
128119
129120 Check .notNull (fieldAttribute , "Match must have a field attribute as the first argument" );
@@ -138,6 +129,11 @@ public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHand
138129 return new KnnQuery (source (), fieldName , queryAsFloats , queryOptions ());
139130 }
140131
132+ @ Override
133+ public Expression replaceQueryBuilder (QueryBuilder queryBuilder ) {
134+ return new Knn (source (), field (), query (), options (), queryBuilder );
135+ }
136+
141137 private Map <String , Object > queryOptions () throws InvalidArgumentException {
142138 if (options () == null ) {
143139 return Map .of ();
@@ -167,30 +163,31 @@ private static Knn readFrom(StreamInput in) throws IOException {
167163 Source source = Source .readFrom ((PlanStreamInput ) in );
168164 Expression field = in .readNamedWriteable (Expression .class );
169165 Expression query = in .readNamedWriteable (Expression .class );
170- Expression options = in .readOptionalNamedWriteable (Expression .class );
166+ QueryBuilder queryBuilder = in .readOptionalNamedWriteable (QueryBuilder .class );
171167
172- return new Knn (source , field , query , options );
168+ return new Knn (source , field , query , null , queryBuilder );
173169 }
174170
175171 @ Override
176172 public void writeTo (StreamOutput out ) throws IOException {
177173 source ().writeTo (out );
178174 out .writeNamedWriteable (field ());
179175 out .writeNamedWriteable (query ());
180- out .writeOptionalNamedWriteable (options ());
176+ out .writeOptionalNamedWriteable (queryBuilder ());
181177 }
182178
183179 @ Override
184180 public boolean equals (Object o ) {
185181 if (o == null || getClass () != o .getClass ()) return false ;
186182 if (super .equals (o ) == false ) return false ;
187183 Knn knn = (Knn ) o ;
188- return Objects .equals (field , knn .field ) && Objects .equals (query , knn .query );
184+ return Objects .equals (field , knn .field ) && Objects .equals (query (), knn .query ())
185+ && Objects .equals (queryBuilder (), knn .queryBuilder ());
189186 }
190187
191188 @ Override
192189 public int hashCode () {
193- return Objects .hash (super . hashCode (), field , query );
190+ return Objects .hash (field (), query (), queryBuilder () );
194191 }
195192
196193}
0 commit comments