Skip to content

Commit ada9dae

Browse files
committed
Create vectorArgumentsCount() on VectorFunction to determine what params to cast
1 parent 1fd0cc0 commit ada9dae

File tree

5 files changed

+28
-3
lines changed

5 files changed

+28
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,8 +1668,11 @@ private static Expression castStringLiteral(Expression from, DataType target) {
16681668
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
16691669
List<Expression> args = vectorFunction.arguments();
16701670
List<Expression> newArgs = new ArrayList<>();
1671-
for (Expression arg : args) {
1672-
if (arg.resolved() && arg.dataType().isNumeric()) {
1671+
// Only the first vector arguments are vectors and considered for casting
1672+
int vectorArgsCount = ((VectorFunction)vectorFunction).vectorArgumentsCount();
1673+
for (int i = 0; i < args.size(); i++) {
1674+
Expression arg = args.get(i);
1675+
if (i < vectorArgsCount && arg.resolved() && arg.dataType().isNumeric()) {
16731676
if (arg.foldable()) {
16741677
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
16751678
if (folded instanceof List) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ public DataType dataType() {
203203
return DataType.BOOLEAN;
204204
}
205205

206+
@Override
207+
public int vectorArgumentsCount() {
208+
return 2;
209+
}
210+
206211
@Override
207212
protected TypeResolution resolveParams() {
208213
return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS));

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Magnitude.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ protected TypeResolution resolveType() {
9696
return isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.FIRST, "dense_vector");
9797
}
9898

99+
@Override
100+
public int vectorArgumentsCount() {
101+
return 1;
102+
}
103+
99104
/**
100105
* Functional interface for evaluating the scalar value of the underlying float array.
101106
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,11 @@
1212
* from multi values to dense_vector field types, so parameters are actually
1313
* processed as dense_vectors in vector functions
1414
*/
15-
public interface VectorFunction {}
15+
public interface VectorFunction {
16+
17+
/**
18+
* Number of arguments that should be treated as vectors. The first vectorArgumentsCount() arguments will be implicitly casted as
19+
* dense_vector according to the value returned of this method
20+
*/
21+
int vectorArgumentsCount();
22+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ public Object fold(FoldContext ctx) {
7777
return EvaluatorMapper.super.fold(source(), ctx);
7878
}
7979

80+
@Override
81+
public int vectorArgumentsCount() {
82+
return 2;
83+
}
84+
8085
@Override
8186
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
8287
return new SimilarityEvaluatorFactory(

0 commit comments

Comments
 (0)