Skip to content

Commit f7a9cb5

Browse files
authored
[ES|QL] fixing bug when handling 1d literal vectors (#136891)
1 parent 6beabf4 commit f7a9cb5

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

docs/changelog/136891.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 136891
2+
summary: Fixing bug when handling 1d literal vectors
3+
area: ES|QL
4+
type: bug
5+
issues:
6+
- 136364

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,27 @@ public void testSimilarityBetweenConstantVectorAndField() {
142142
}
143143
}
144144

145+
@SuppressWarnings("unchecked")
146+
public void testSimilarityWithOneDimVector() {
147+
final float oneDimFloat = randomFloat();
148+
final float[] randomVector = randomVectorArray(1);
149+
var query = String.format(Locale.ROOT, """
150+
ROW left_vector = to_dense_vector(%s)
151+
| EVAL similarity = %s(left_vector, %f)
152+
| KEEP left_vector, similarity
153+
""", Arrays.toString(randomVector), functionName, oneDimFloat);
154+
try (var resp = run(query)) {
155+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
156+
valuesList.forEach(values -> {
157+
float[] left = new float[] { (float) values.get(0) };
158+
Double similarity = (Double) values.get(1);
159+
assertNotNull(similarity);
160+
float expectedSimilarity = similarityFunction.calculateSimilarity(left, new float[] { oneDimFloat });
161+
assertEquals(expectedSimilarity, similarity, 0.0001);
162+
});
163+
}
164+
}
165+
145166
public void testDifferentDimensions() {
146167
var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
147168
var query = String.format(Locale.ROOT, """

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,13 @@ private static VectorValueProviderFactory getVectorValueProviderFactory(
100100
EvaluatorMapper.ToEvaluator toEvaluator
101101
) {
102102
if (expression instanceof Literal) {
103-
return new ConstantVectorProvider.Factory((ArrayList<Float>) ((Literal) expression).value());
103+
ArrayList<Float> constantVector;
104+
if (((Literal) expression).value() instanceof Float) {
105+
constantVector = new ArrayList<>(List.of((Float) ((Literal) expression).value()));
106+
} else {
107+
constantVector = (ArrayList<Float>) ((Literal) expression).value();
108+
}
109+
return new ConstantVectorProvider.Factory(constantVector);
104110
} else {
105111
return new ExpressionVectorProvider.Factory(toEvaluator.apply(expression));
106112
}

0 commit comments

Comments
 (0)