diff --git a/docs/changelog/136891.yaml b/docs/changelog/136891.yaml new file mode 100644 index 0000000000000..4e5a566b92d91 --- /dev/null +++ b/docs/changelog/136891.yaml @@ -0,0 +1,6 @@ +pr: 136891 +summary: Fixing bug when handling 1d literal vectors +area: ES|QL +type: bug +issues: + - 136364 diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java index f0fc8955f0b39..9beaf775ed6ef 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -142,6 +142,27 @@ public void testSimilarityBetweenConstantVectorAndField() { } } + @SuppressWarnings("unchecked") + public void testSimilarityWithOneDimVector() { + final float oneDimFloat = randomFloat(); + final float[] randomVector = randomVectorArray(1); + var query = String.format(Locale.ROOT, """ + ROW left_vector = to_dense_vector(%s) + | EVAL similarity = %s(left_vector, %f) + | KEEP left_vector, similarity + """, Arrays.toString(randomVector), functionName, oneDimFloat); + try (var resp = run(query)) { + List> valuesList = EsqlTestUtils.getValuesList(resp); + valuesList.forEach(values -> { + float[] left = new float[] { (float) values.get(0) }; + Double similarity = (Double) values.get(1); + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.calculateSimilarity(left, new float[] { oneDimFloat }); + assertEquals(expectedSimilarity, similarity, 0.0001); + }); + } + } + public void testDifferentDimensions() { var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2)); var query = String.format(Locale.ROOT, """ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index 42d215cc3d201..4bc554643d7fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -100,7 +100,13 @@ private static VectorValueProviderFactory getVectorValueProviderFactory( EvaluatorMapper.ToEvaluator toEvaluator ) { if (expression instanceof Literal) { - return new ConstantVectorProvider.Factory((ArrayList) ((Literal) expression).value()); + ArrayList constantVector; + if (((Literal) expression).value() instanceof Float) { + constantVector = new ArrayList<>(List.of((Float) ((Literal) expression).value())); + } else { + constantVector = (ArrayList) ((Literal) expression).value(); + } + return new ConstantVectorProvider.Factory(constantVector); } else { return new ExpressionVectorProvider.Factory(toEvaluator.apply(expression)); }