Skip to content

Commit e8dd1e8

Browse files
committed
Accept text field type for knn
1 parent 4d2d91f commit e8dd1e8

File tree

2 files changed

+13
-1
lines changed
  • x-pack/plugin
    • esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression
    • esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector

2 files changed

+13
-1
lines changed

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public TypeResolution and(TypeResolution other) {
5555
return failed ? this : other;
5656
}
5757

58+
public TypeResolution or(TypeResolution other) {
59+
return failed ? other : this;
60+
}
61+
5862
public TypeResolution and(Supplier<TypeResolution> other) {
5963
return failed ? this : other.get();
6064
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,14 @@
6767
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
6868
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
6969
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
70+
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
71+
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
7072
import static org.elasticsearch.xpack.esql.expression.Foldables.TypeResolutionValidator.forPreOptimizationValidation;
7173
import static org.elasticsearch.xpack.esql.expression.Foldables.resolveTypeQuery;
7274

7375
public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware {
76+
77+
private static final String[] ACCEPTED_FIELD_TYPES = { "dense_vector", "semantic_text" };
7478
private final Logger log = LogManager.getLogger(getClass());
7579

7680
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
@@ -209,7 +213,11 @@ protected TypeResolution resolveParams() {
209213
}
210214

211215
private TypeResolution resolveField() {
212-
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
216+
return isNotNull(field(), sourceText(), FIRST).and(
217+
isType(field(), dt -> dt == TEXT, sourceText(), FIRST, ACCEPTED_FIELD_TYPES).or(
218+
isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, ACCEPTED_FIELD_TYPES)
219+
)
220+
);
213221
}
214222

215223
private TypeResolution resolveQuery() {

0 commit comments

Comments
 (0)