Skip to content

Commit 0f77374

Browse files
committed
Make Knn a FullTextFunction
1 parent 9588048 commit 0f77374

File tree

1 file changed

+22
-25
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector

1 file changed

+22
-25
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
13-
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
13+
import org.elasticsearch.index.query.QueryBuilder;
1414
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
1515
import org.elasticsearch.xpack.esql.core.expression.Expression;
1616
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1717
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
1818
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
19-
import org.elasticsearch.xpack.esql.core.expression.function.Function;
2019
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
2120
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2221
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -26,9 +25,9 @@
2625
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
2726
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
2827
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
28+
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
2929
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
3030
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
31-
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3231
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3332
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
3433

@@ -51,16 +50,13 @@
5150
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
5251
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
5352
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
54-
import static org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction.populateOptionsMap;
5553
import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute;
5654

57-
public class Knn extends Function implements TranslationAware, OptionalArgument {
55+
public class Knn extends FullTextFunction implements OptionalArgument {
5856

5957
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
6058

6159
private final Expression field;
62-
private final Expression query;
63-
// TODO Options could be serialized via QueryBuilder in case we want to rewrite it in the coordinator node (for query text inference)
6460
private final Expression options;
6561

6662
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
@@ -84,20 +80,20 @@ public class Knn extends Function implements TranslationAware, OptionalArgument
8480
) }
8581
)
8682
public Knn(Source source, Expression field, Expression query, Expression options) {
87-
super(source, options == null ? List.of(field, query) : List.of(field, query, options));
83+
this(source, field, query, options, null);
84+
}
85+
86+
public Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) {
87+
super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder);
8888
this.field = field;
89-
this.query = query;
9089
this.options = options;
9190
}
9291

92+
9393
public Expression field() {
9494
return field;
9595
}
9696

97-
public Expression query() {
98-
return query;
99-
}
100-
10197
public Expression options() {
10298
return options;
10399
}
@@ -108,7 +104,7 @@ public DataType dataType() {
108104
}
109105

110106
@Override
111-
protected final TypeResolution resolveType() {
107+
protected TypeResolution resolveParams() {
112108
if (childrenResolved() == false) {
113109
return new TypeResolution("Unresolved children");
114110
}
@@ -118,12 +114,7 @@ protected final TypeResolution resolveType() {
118114
}
119115

120116
@Override
121-
public boolean translatable(LucenePushdownPredicates pushdownPredicates) {
122-
return true;
123-
}
124-
125-
@Override
126-
public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
117+
protected Query translate(TranslatorHandler handler) {
127118
var fieldAttribute = Match.fieldAsFieldAttribute(field());
128119

129120
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
@@ -138,6 +129,11 @@ public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHand
138129
return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
139130
}
140131

132+
@Override
133+
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
134+
return new Knn(source(), field(), query(), options(), queryBuilder);
135+
}
136+
141137
private Map<String, Object> queryOptions() throws InvalidArgumentException {
142138
if (options() == null) {
143139
return Map.of();
@@ -167,30 +163,31 @@ private static Knn readFrom(StreamInput in) throws IOException {
167163
Source source = Source.readFrom((PlanStreamInput) in);
168164
Expression field = in.readNamedWriteable(Expression.class);
169165
Expression query = in.readNamedWriteable(Expression.class);
170-
Expression options = in.readOptionalNamedWriteable(Expression.class);
166+
QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
171167

172-
return new Knn(source, field, query, options);
168+
return new Knn(source, field, query, null, queryBuilder);
173169
}
174170

175171
@Override
176172
public void writeTo(StreamOutput out) throws IOException {
177173
source().writeTo(out);
178174
out.writeNamedWriteable(field());
179175
out.writeNamedWriteable(query());
180-
out.writeOptionalNamedWriteable(options());
176+
out.writeOptionalNamedWriteable(queryBuilder());
181177
}
182178

183179
@Override
184180
public boolean equals(Object o) {
185181
if (o == null || getClass() != o.getClass()) return false;
186182
if (super.equals(o) == false) return false;
187183
Knn knn = (Knn) o;
188-
return Objects.equals(field, knn.field) && Objects.equals(query, knn.query);
184+
return Objects.equals(field, knn.field) && Objects.equals(query(), knn.query())
185+
&& Objects.equals(queryBuilder(), knn.queryBuilder());
189186
}
190187

191188
@Override
192189
public int hashCode() {
193-
return Objects.hash(super.hashCode(), field, query);
190+
return Objects.hash(field(), query(), queryBuilder());
194191
}
195192

196193
}

0 commit comments

Comments
 (0)