diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec index d9e1ff408c739..46d80609a06bf 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec @@ -75,6 +75,19 @@ similarity:double avg:double | min:double | max:double 0.832 | 0.5 | 1.0 +; + +similarityWithNull +required_capability: cosine_vector_similarity_function +required_capability: vector_similarity_functions_support_null + +from colors +| eval similarity = v_cosine(rgb_vector, null) +| stats total_null = count(*) where similarity is null +; + +total_null:long +59 ; # TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec index 65bc4b9a365ce..b6d32b5ae651b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec @@ -27,15 +27,15 @@ old lace | 60563.0 // end::vector-dot-product-result[] ; - similarityAsPartOfExpression - required_capability: dot_product_vector_similarity_function - - from colors - | eval score = round((1 + v_dot_product(rgb_vector, [0, 255, 255]) / 2), 3) - | sort score desc, color asc - | limit 10 - | keep color, score - ; +similarityAsPartOfExpression +required_capability: dot_product_vector_similarity_function + +from colors +| eval score = round((1 + v_dot_product(rgb_vector, [0, 255, 255]) / 2), 3) +| sort score desc, color asc +| limit 10 +| keep color, score +; color:text | score:double azure | 32513.75 @@ -62,18 +62,32 @@ similarity:double 4.5 ; - similarityWithStats - required_capability: dot_product_vector_similarity_function - - from colors - | eval similarity = round(v_dot_product(rgb_vector, [0, 255, 255]), 3) - | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) - ; +similarityWithStats +required_capability: dot_product_vector_similarity_function + +from colors +| eval similarity = round(v_dot_product(rgb_vector, [0, 255, 255]), 3) +| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) +; avg:double | min:double | max:double 39519.017 | 0.5 | 65025.5 ; +similarityWithNull +required_capability: dot_product_vector_similarity_function +required_capability: vector_similarity_functions_support_null + +from colors +| eval similarity = v_dot_product(rgb_vector, null) +| stats total_null = count(*) where similarity is null +; + +total_null:long +59 +; + + # TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector similarityWithRow-Ignore required_capability: dot_product_vector_similarity_function diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec index 4a7b4e004d117..53f550dd4fe1f 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec @@ -27,15 +27,15 @@ gold | 550.0 // end::vector-l1-norm-result[] ; - similarityAsPartOfExpression - required_capability: l1_norm_vector_similarity_function - - from colors - | eval score = round((1 + v_l1_norm(rgb_vector, [0, 255, 255]) / 2), 3) - | sort score desc, color asc - | limit 10 - | keep color, score - ; +similarityAsPartOfExpression +required_capability: l1_norm_vector_similarity_function + +from colors +| eval score = round((1 + v_l1_norm(rgb_vector, [0, 255, 255]) / 2), 3) +| sort score desc, color asc +| limit 10 +| keep color, score +; color:text | score:double red | 383.5 @@ -62,18 +62,31 @@ similarity:double 3.0 ; - similarityWithStats - required_capability: l1_norm_vector_similarity_function - - from colors - | eval similarity = round(v_l1_norm(rgb_vector, [0, 255, 255]), 3) - | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) - ; +similarityWithStats +required_capability: l1_norm_vector_similarity_function + +from colors +| eval similarity = round(v_l1_norm(rgb_vector, [0, 255, 255]), 3) +| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) +; avg:double | min:double | max:double 391.254 | 0.0 | 765.0 ; +similarityWithNull +required_capability: l1_norm_vector_similarity_function +required_capability: vector_similarity_functions_support_null + +from colors +| eval similarity = v_l1_norm(rgb_vector, null) +| stats total_null = count(*) where similarity is null +; + +total_null:long +59 +; + # TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector similarityWithRow-Ignore required_capability: l1_norm_vector_similarity_function diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec index c623a21ca6885..03a094ed93cad 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec @@ -30,12 +30,12 @@ tomato | 351.0227966308594 similarityAsPartOfExpression required_capability: l2_norm_vector_similarity_function - from colors - | eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3) - | sort score desc, color asc - | limit 10 - | keep color, score - ; +from colors +| eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3) +| sort score desc, color asc +| limit 10 +| keep color, score +; color:text | score:double red | 221.836 @@ -62,18 +62,31 @@ similarity:double 1.732 ; - similarityWithStats - required_capability: l2_norm_vector_similarity_function - - from colors - | eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3) - | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) - ; +similarityWithStats +required_capability: l2_norm_vector_similarity_function + +from colors +| eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3) +| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) +; avg:double | min:double | max:double 274.974 | 0.0 | 441.673 ; +similarityWithNull +required_capability: l2_norm_vector_similarity_function +required_capability: vector_similarity_functions_support_null + +from colors +| eval similarity = v_l2_norm(rgb_vector, null) +| stats total_null = count(*) where similarity is null +; + +total_null:long +59 +; + # TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector similarityWithRow-Ignore required_capability: l2_norm_vector_similarity_function diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java index ccde2623fddea..2d85e3bd7f93c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -102,10 +102,13 @@ public void testSimilarityBetweenVectors() { float[] left = readVector((List) values.get(0)); float[] right = readVector((List) values.get(1)); Double similarity = (Double) values.get(2); - - assertNotNull(similarity); - float expectedSimilarity = similarityFunction.calculateSimilarity(left, right); - assertEquals(expectedSimilarity, similarity, 0.0001); + if (left == null || right == null) { + assertNull(similarity); + } else { + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.calculateSimilarity(left, right); + assertEquals(expectedSimilarity, similarity, 0.0001); + } }); } } @@ -124,10 +127,13 @@ public void testSimilarityBetweenConstantVectorAndField() { valuesList.forEach(values -> { float[] left = readVector((List) values.get(0)); Double similarity = (Double) values.get(1); - - assertNotNull(similarity); - float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector); - assertEquals(expectedSimilarity, similarity, 0.0001); + if (left == null) { + assertNull(similarity); + } else { + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector); + assertEquals(expectedSimilarity, similarity, 0.0001); + } }); } } @@ -159,13 +165,20 @@ public void testSimilarityBetweenConstantVectors() { assertEquals(1, valuesList.size()); Double similarity = (Double) valuesList.get(0).get(0); - assertNotNull(similarity); - float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight); - assertEquals(expectedSimilarity, similarity, 0.0001); + if (vectorLeft == null || vectorRight == null) { + assertNull(similarity); + } else { + assertNotNull(similarity); + float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight); + assertEquals(expectedSimilarity, similarity, 0.0001); + } } } private static float[] readVector(List leftVector) { + if (leftVector == null) { + return null; + } float[] leftScratch = new float[leftVector.size()]; for (int i = 0; i < leftVector.size(); i++) { leftScratch[i] = leftVector.get(i); @@ -194,6 +207,9 @@ public void setup() throws IOException { private List randomVector() { assert numDims != 0 : "numDims must be set before calling randomVector()"; + if (rarely()) { + return null; + } List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { vector.add(randomFloat()); @@ -203,7 +219,7 @@ private List randomVector() { private float[] randomVectorArray() { assert numDims != 0 : "numDims must be set before calling randomVectorArray()"; - return randomVectorArray(numDims); + return rarely() ? null : randomVectorArray(numDims); } private static float[] randomVectorArray(int dimensions) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index b71f8f3fe83fc..14a79f54646ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1359,7 +1359,12 @@ public enum Cap { /** * Byte elements dense vector field type support. */ - DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS(EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG); + DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS(EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG), + + /** + * Support null elements on vector similarity functions + */ + VECTOR_SIMILARITY_FUNCTIONS_SUPPORT_NULL; private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java index fc27ae2d876e8..69dcaa17368dc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java @@ -9,7 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; @@ -27,7 +27,6 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; @@ -59,9 +58,7 @@ protected TypeResolution resolveType() { } private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) { - return isNotNull(param, sourceText(), paramOrdinal).and( - isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector") - ); + return isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector"); } /** @@ -124,14 +121,14 @@ public Block eval(Page page) { float[] leftScratch = new float[dimensions]; float[] rightScratch = new float[dimensions]; - try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) { + try (DoubleBlock.Builder builder = context.blockFactory().newDoubleBlockBuilder(positionCount * dimensions)) { for (int p = 0; p < positionCount; p++) { int dimsLeft = leftBlock.getValueCount(p); int dimsRight = rightBlock.getValueCount(p); if (dimsLeft == 0 || dimsRight == 0) { - // A null value on the left or right vector. Similarity is 0 - builder.appendDouble(0.0); + // A null value on the left or right vector. Similarity is null + builder.appendNull(); continue; } else if (dimsLeft != dimsRight) { throw new EsqlClientException( @@ -145,7 +142,7 @@ public Block eval(Page page) { float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch); builder.appendDouble(result); } - return builder.build().asBlock(); + return builder.build(); } } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 32b4ccb768efe..37d6719ddccfc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -2484,28 +2484,25 @@ private void checkFullTextFunctionsInStats(String functionInvocation) { public void testVectorSimilarityFunctionsNullArgs() throws Exception { if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first"); - checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second"); + checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)"); + checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)"); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)", "first"); - checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)", "second"); + checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)"); + checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)"); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)", "first"); - checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)", "second"); + checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)"); + checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)"); } if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)", "first"); - checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)", "second"); + checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)"); + checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)"); } } - private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception { - assertThat( - error("from test | eval similarity = " + functionInvocation, fullTextAnalyzer), - containsString(argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]") - ); + private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation) throws Exception { + query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer); } private void query(String query) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java index 329eba63046f4..791152df5acb0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -12,7 +12,6 @@ import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.hamcrest.Matcher; import org.junit.Before; import java.util.ArrayList; @@ -93,10 +92,4 @@ private static List randomDenseVector(int dimensions) { } return vector; } - - @Override - protected Matcher allNullsMatcher() { - // A null value on the left or right vector. Similarity is 0 - return equalTo(0.0); - } }