Skip to content

Commit 290dbe1

Browse files
committed
Ensure casting is done using floats so we get the appropriate blocks and folding done
1 parent 53e96f9 commit 290dbe1

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1613,14 +1613,22 @@ private static Expression castStringLiteral(Expression from, DataType target) {
16131613
}
16141614
}
16151615

1616+
@SuppressWarnings("unchecked")
16161617
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
16171618
List<Expression> args = vectorFunction.arguments();
16181619
List<Expression> newArgs = new ArrayList<>();
16191620
for (Expression arg : args) {
16201621
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
16211622
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
16221623
if (folded instanceof List) {
1623-
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
1624+
// Convert to floats so blocks are created accordingly
1625+
List<Float> floatVector;
1626+
if (arg.dataType() == FLOAT) {
1627+
floatVector = (List<Float>) folded;
1628+
} else {
1629+
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
1630+
}
1631+
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
16241632
newArgs.add(denseVector);
16251633
continue;
16261634
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.elasticsearch.xpack.esql.core.expression.Expression;
2525
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
2626
import org.elasticsearch.xpack.esql.core.expression.Literal;
27-
import org.elasticsearch.xpack.esql.core.type.DataType;
2827
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
2928
import org.elasticsearch.xpack.esql.evaluator.mapper.ExpressionMapper;
3029
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
@@ -249,11 +248,7 @@ private static Block block(Literal lit, BlockFactory blockFactory, int positions
249248
if (multiValue.isEmpty()) {
250249
return blockFactory.newConstantNullBlock(positions);
251250
}
252-
// dense_vector create internally float values, even if they are specified as doubles
253-
ElementType elementType = lit.dataType() == DataType.DENSE_VECTOR
254-
? ElementType.FLOAT
255-
: ElementType.fromJava(multiValue.get(0).getClass());
256-
var wrapper = BlockUtils.wrapperFor(blockFactory, elementType, positions);
251+
var wrapper = BlockUtils.wrapperFor(blockFactory, ElementType.fromJava(multiValue.get(0).getClass()), positions);
257252
for (int i = 0; i < positions; i++) {
258253
wrapper.accept(multiValue);
259254
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2387,16 +2387,17 @@ public void testDenseVectorImplicitCastingKnn() {
23872387
var field = knn.field();
23882388
var queryVector = as(knn.query(), Literal.class);
23892389
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
2390-
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
2390+
assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f)));
23912391
}
23922392

23932393
public void testDenseVectorImplicitCastingSimilarityFunctions() {
23942394
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2395-
checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])");
2395+
checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
2396+
checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
23962397
}
23972398
}
23982399

2399-
private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction) {
2400+
private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
24002401
var plan = analyze(String.format(Locale.ROOT, """
24012402
from test | eval similarity = %s
24022403
""", similarityFunction), "mapping-dense_vector.json");
@@ -2410,8 +2411,7 @@ private void checkDenseVectorImplicitCastingSimilarityFunction(String similarity
24102411
assertEquals("vector", left.name());
24112412
var right = as(similarity.right(), Literal.class);
24122413
assertThat(right.dataType(), is(DENSE_VECTOR));
2413-
assertThat(right.value(), equalTo(List.of(0.342, 0.164, 0.234)));
2414-
;
2414+
assertThat(right.value(), equalTo(expectedElems));
24152415
}
24162416

24172417
public void testNoDenseVectorFailsSimilarityFunction() {

0 commit comments

Comments
 (0)