Skip to content

Commit 0f58f24

Browse files
committed
Implicit casting
1 parent 8e9b280 commit 0f58f24

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
6666
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong;
6767
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
68+
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
6869
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation;
6970
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
7071
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
@@ -1365,9 +1366,11 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func
13651366
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
13661367
return processBinaryOperator((BinaryOperator) f);
13671368
}
1369+
if (f instanceof VectorFunction vectorFunction) {
1370+
return processVectorFunction(f);
1371+
}
13681372
return f;
13691373
}
1370-
13711374
private static Expression processScalarOrGroupingFunction(
13721375
org.elasticsearch.xpack.esql.core.expression.function.Function f,
13731376
EsqlFunctionRegistry registry
@@ -1564,6 +1567,25 @@ private static Expression castStringLiteral(Expression from, DataType target) {
15641567
return unresolvedAttribute(from, target.toString(), e);
15651568
}
15661569
}
1570+
1571+
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
1572+
List<Expression> args = vectorFunction.arguments();
1573+
List<Expression> newArgs = new ArrayList<>();
1574+
for (Expression arg : args) {
1575+
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
1576+
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
1577+
if (folded instanceof List) {
1578+
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
1579+
newArgs.add(denseVector);
1580+
continue;
1581+
}
1582+
}
1583+
newArgs.add(arg);
1584+
}
1585+
1586+
return vectorFunction.replaceChildren(newArgs);
1587+
}
1588+
15671589
}
15681590

15691591
/**

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
5353
import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute;
5454

55-
public class Knn extends FullTextFunction implements OptionalArgument {
55+
public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction {
5656

5757
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
5858

@@ -110,7 +110,7 @@ protected TypeResolution resolveParams() {
110110
}
111111

112112
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"))
113-
.and(TypeResolutions.isNumeric(query(), sourceText(), TypeResolutions.ParamOrdinal.SECOND));
113+
.and(isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND));
114114
}
115115

116116
@Override
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.vector;
9+
10+
/**
11+
* Marker interface for vector functions. Makes possible to do implicit casting
12+
* from multi values to dense_vector field types, so parameters are actually
13+
* processed as dense_vectors in vector functions
14+
*/
15+
public interface VectorFunction {
16+
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
5050
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
5151
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
52+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
5253
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
5354
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
5455
import org.elasticsearch.xpack.esql.index.EsIndex;
@@ -2347,6 +2348,24 @@ public void testImplicitCasting() {
23472348
assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
23482349
}
23492350

2351+
public void testDenseVectorImplicitCasting() {
2352+
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
2353+
2354+
var plan = analyze("""
2355+
from test | where knn(vector, [0.342, 0.164, 0.234])
2356+
""",
2357+
"mapping-dense_vector.json");
2358+
2359+
var limit = as(plan, Limit.class);
2360+
var filter = as(limit.child(), Filter.class);
2361+
var knn = as(filter.condition(), Knn.class);
2362+
var field = knn.field();
2363+
var queryVector = as(knn.query(), Literal.class);
2364+
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
2365+
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
2366+
}
2367+
2368+
23502369
public void testRateRequiresCounterTypes() {
23512370
assumeTrue("rate requires snapshot builds", Build.current().isSnapshot());
23522371
Analyzer analyzer = analyzer(tsdbIndexResolution());

0 commit comments

Comments
 (0)