Skip to content

Commit 8a278ed

Browse files
carlosdelest and elasticsearchmachine authored
ESQL - Allow null values in vector similarity functions (elastic#132919)
* Allow null values and return null as result
* Fix indentation
* [CI] Auto commit changes from spotless
* Remove unnecessary return
* Fix verifier tests
* Fix bwc tests introducing capability

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent c7fcf8b commit 8a278ed

File tree

9 files changed

+148
-87
lines changed

9 files changed

+148
-87
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ similarity:double
7575

7676
avg:double | min:double | max:double
7777
0.832 | 0.5 | 1.0
78+
;
79+
80+
similarityWithNull
81+
required_capability: cosine_vector_similarity_function
82+
required_capability: vector_similarity_functions_support_null
83+
84+
from colors
85+
| eval similarity = v_cosine(rgb_vector, null)
86+
| stats total_null = count(*) where similarity is null
87+
;
88+
89+
total_null:long
90+
59
7891
;
7992

8093
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ old lace | 60563.0
2727
// end::vector-dot-product-result[]
2828
;
2929

30-
similarityAsPartOfExpression
31-
required_capability: dot_product_vector_similarity_function
32-
33-
from colors
34-
| eval score = round((1 + v_dot_product(rgb_vector, [0, 255, 255]) / 2), 3)
35-
| sort score desc, color asc
36-
| limit 10
37-
| keep color, score
38-
;
30+
similarityAsPartOfExpression
31+
required_capability: dot_product_vector_similarity_function
32+
33+
from colors
34+
| eval score = round((1 + v_dot_product(rgb_vector, [0, 255, 255]) / 2), 3)
35+
| sort score desc, color asc
36+
| limit 10
37+
| keep color, score
38+
;
3939

4040
color:text | score:double
4141
azure | 32513.75
@@ -62,18 +62,32 @@ similarity:double
6262
4.5
6363
;
6464

65-
similarityWithStats
66-
required_capability: dot_product_vector_similarity_function
67-
68-
from colors
69-
| eval similarity = round(v_dot_product(rgb_vector, [0, 255, 255]), 3)
70-
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
71-
;
65+
similarityWithStats
66+
required_capability: dot_product_vector_similarity_function
67+
68+
from colors
69+
| eval similarity = round(v_dot_product(rgb_vector, [0, 255, 255]), 3)
70+
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
71+
;
7272

7373
avg:double | min:double | max:double
7474
39519.017 | 0.5 | 65025.5
7575
;
7676

77+
similarityWithNull
78+
required_capability: dot_product_vector_similarity_function
79+
required_capability: vector_similarity_functions_support_null
80+
81+
from colors
82+
| eval similarity = v_dot_product(rgb_vector, null)
83+
| stats total_null = count(*) where similarity is null
84+
;
85+
86+
total_null:long
87+
59
88+
;
89+
90+
7791
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
7892
similarityWithRow-Ignore
7993
required_capability: dot_product_vector_similarity_function

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ gold | 550.0
2727
// end::vector-l1-norm-result[]
2828
;
2929

30-
similarityAsPartOfExpression
31-
required_capability: l1_norm_vector_similarity_function
32-
33-
from colors
34-
| eval score = round((1 + v_l1_norm(rgb_vector, [0, 255, 255]) / 2), 3)
35-
| sort score desc, color asc
36-
| limit 10
37-
| keep color, score
38-
;
30+
similarityAsPartOfExpression
31+
required_capability: l1_norm_vector_similarity_function
32+
33+
from colors
34+
| eval score = round((1 + v_l1_norm(rgb_vector, [0, 255, 255]) / 2), 3)
35+
| sort score desc, color asc
36+
| limit 10
37+
| keep color, score
38+
;
3939

4040
color:text | score:double
4141
red | 383.5
@@ -62,18 +62,31 @@ similarity:double
6262
3.0
6363
;
6464

65-
similarityWithStats
66-
required_capability: l1_norm_vector_similarity_function
67-
68-
from colors
69-
| eval similarity = round(v_l1_norm(rgb_vector, [0, 255, 255]), 3)
70-
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
71-
;
65+
similarityWithStats
66+
required_capability: l1_norm_vector_similarity_function
67+
68+
from colors
69+
| eval similarity = round(v_l1_norm(rgb_vector, [0, 255, 255]), 3)
70+
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
71+
;
7272

7373
avg:double | min:double | max:double
7474
391.254 | 0.0 | 765.0
7575
;
7676

77+
similarityWithNull
78+
required_capability: l1_norm_vector_similarity_function
79+
required_capability: vector_similarity_functions_support_null
80+
81+
from colors
82+
| eval similarity = v_l1_norm(rgb_vector, null)
83+
| stats total_null = count(*) where similarity is null
84+
;
85+
86+
total_null:long
87+
59
88+
;
89+
7790
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
7891
similarityWithRow-Ignore
7992
required_capability: l1_norm_vector_similarity_function

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ tomato | 351.0227966308594
3030
similarityAsPartOfExpression
3131
required_capability: l2_norm_vector_similarity_function
3232

33-
from colors
34-
| eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3)
35-
| sort score desc, color asc
36-
| limit 10
37-
| keep color, score
38-
;
33+
from colors
34+
| eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3)
35+
| sort score desc, color asc
36+
| limit 10
37+
| keep color, score
38+
;
3939

4040
color:text | score:double
4141
red | 221.836
@@ -62,18 +62,31 @@ similarity:double
6262
1.732
6363
;
6464

65-
similarityWithStats
66-
required_capability: l2_norm_vector_similarity_function
67-
68-
from colors
69-
| eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3)
70-
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
71-
;
65+
similarityWithStats
66+
required_capability: l2_norm_vector_similarity_function
67+
68+
from colors
69+
| eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3)
70+
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
71+
;
7272

7373
avg:double | min:double | max:double
7474
274.974 | 0.0 | 441.673
7575
;
7676

77+
similarityWithNull
78+
required_capability: l2_norm_vector_similarity_function
79+
required_capability: vector_similarity_functions_support_null
80+
81+
from colors
82+
| eval similarity = v_l2_norm(rgb_vector, null)
83+
| stats total_null = count(*) where similarity is null
84+
;
85+
86+
total_null:long
87+
59
88+
;
89+
7790
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
7891
similarityWithRow-Ignore
7992
required_capability: l2_norm_vector_similarity_function

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,13 @@ public void testSimilarityBetweenVectors() {
102102
float[] left = readVector((List<Float>) values.get(0));
103103
float[] right = readVector((List<Float>) values.get(1));
104104
Double similarity = (Double) values.get(2);
105-
106-
assertNotNull(similarity);
107-
float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
108-
assertEquals(expectedSimilarity, similarity, 0.0001);
105+
if (left == null || right == null) {
106+
assertNull(similarity);
107+
} else {
108+
assertNotNull(similarity);
109+
float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
110+
assertEquals(expectedSimilarity, similarity, 0.0001);
111+
}
109112
});
110113
}
111114
}
@@ -124,10 +127,13 @@ public void testSimilarityBetweenConstantVectorAndField() {
124127
valuesList.forEach(values -> {
125128
float[] left = readVector((List<Float>) values.get(0));
126129
Double similarity = (Double) values.get(1);
127-
128-
assertNotNull(similarity);
129-
float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
130-
assertEquals(expectedSimilarity, similarity, 0.0001);
130+
if (left == null) {
131+
assertNull(similarity);
132+
} else {
133+
assertNotNull(similarity);
134+
float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
135+
assertEquals(expectedSimilarity, similarity, 0.0001);
136+
}
131137
});
132138
}
133139
}
@@ -159,13 +165,20 @@ public void testSimilarityBetweenConstantVectors() {
159165
assertEquals(1, valuesList.size());
160166

161167
Double similarity = (Double) valuesList.get(0).get(0);
162-
assertNotNull(similarity);
163-
float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
164-
assertEquals(expectedSimilarity, similarity, 0.0001);
168+
if (vectorLeft == null || vectorRight == null) {
169+
assertNull(similarity);
170+
} else {
171+
assertNotNull(similarity);
172+
float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
173+
assertEquals(expectedSimilarity, similarity, 0.0001);
174+
}
165175
}
166176
}
167177

168178
private static float[] readVector(List<Float> leftVector) {
179+
if (leftVector == null) {
180+
return null;
181+
}
169182
float[] leftScratch = new float[leftVector.size()];
170183
for (int i = 0; i < leftVector.size(); i++) {
171184
leftScratch[i] = leftVector.get(i);
@@ -194,6 +207,9 @@ public void setup() throws IOException {
194207

195208
private List<Float> randomVector() {
196209
assert numDims != 0 : "numDims must be set before calling randomVector()";
210+
if (rarely()) {
211+
return null;
212+
}
197213
List<Float> vector = new ArrayList<>(numDims);
198214
for (int j = 0; j < numDims; j++) {
199215
vector.add(randomFloat());
@@ -203,7 +219,7 @@ private List<Float> randomVector() {
203219

204220
private float[] randomVectorArray() {
205221
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
206-
return randomVectorArray(numDims);
222+
return rarely() ? null : randomVectorArray(numDims);
207223
}
208224

209225
private static float[] randomVectorArray(int dimensions) {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,12 @@ public enum Cap {
13591359
/**
13601360
* Byte elements dense vector field type support.
13611361
*/
1362-
DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS(EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG);
1362+
DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS(EsqlCorePlugin.DENSE_VECTOR_FEATURE_FLAG),
1363+
1364+
/**
1365+
* Support for null arguments in vector similarity functions: a null vector argument yields a null result.
1366+
*/
1367+
VECTOR_SIMILARITY_FUNCTIONS_SUPPORT_NULL;
13631368

13641369
private final boolean enabled;
13651370

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.compute.data.Block;
12-
import org.elasticsearch.compute.data.DoubleVector;
12+
import org.elasticsearch.compute.data.DoubleBlock;
1313
import org.elasticsearch.compute.data.FloatBlock;
1414
import org.elasticsearch.compute.data.Page;
1515
import org.elasticsearch.compute.operator.DriverContext;
@@ -27,7 +27,6 @@
2727

2828
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
2929
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
30-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
3130
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
3231
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
3332

@@ -59,9 +58,7 @@ protected TypeResolution resolveType() {
5958
}
6059

6160
private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
62-
return isNotNull(param, sourceText(), paramOrdinal).and(
63-
isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector")
64-
);
61+
return isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector");
6562
}
6663

6764
/**
@@ -124,14 +121,14 @@ public Block eval(Page page) {
124121

125122
float[] leftScratch = new float[dimensions];
126123
float[] rightScratch = new float[dimensions];
127-
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
124+
try (DoubleBlock.Builder builder = context.blockFactory().newDoubleBlockBuilder(positionCount * dimensions)) {
128125
for (int p = 0; p < positionCount; p++) {
129126
int dimsLeft = leftBlock.getValueCount(p);
130127
int dimsRight = rightBlock.getValueCount(p);
131128

132129
if (dimsLeft == 0 || dimsRight == 0) {
133-
// A null value on the left or right vector. Similarity is 0
134-
builder.appendDouble(0.0);
130+
// A null value on the left or right vector. Similarity is null
131+
builder.appendNull();
135132
continue;
136133
} else if (dimsLeft != dimsRight) {
137134
throw new EsqlClientException(
@@ -145,7 +142,7 @@ public Block eval(Page page) {
145142
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
146143
builder.appendDouble(result);
147144
}
148-
return builder.build().asBlock();
145+
return builder.build();
149146
}
150147
}
151148
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2484,28 +2484,25 @@ private void checkFullTextFunctionsInStats(String functionInvocation) {
24842484

24852485
public void testVectorSimilarityFunctionsNullArgs() throws Exception {
24862486
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2487-
checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first");
2488-
checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second");
2487+
checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)");
2488+
checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)");
24892489
}
24902490
if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2491-
checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)", "first");
2492-
checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)", "second");
2491+
checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)");
2492+
checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)");
24932493
}
24942494
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2495-
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)", "first");
2496-
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)", "second");
2495+
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)");
2496+
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)");
24972497
}
24982498
if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2499-
checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)", "first");
2500-
checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)", "second");
2499+
checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)");
2500+
checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)");
25012501
}
25022502
}
25032503

2504-
private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {
2505-
assertThat(
2506-
error("from test | eval similarity = " + functionInvocation, fullTextAnalyzer),
2507-
containsString(argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]")
2508-
);
2504+
private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation) throws Exception {
2505+
query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer);
25092506
}
25102507

25112508
private void query(String query) {

0 commit comments

Comments
 (0)