Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ similarity:double

avg:double | min:double | max:double
0.832 | 0.5 | 1.0
;

similarityWithNull
required_capability: cosine_vector_similarity_function

from colors
| eval similarity = v_cosine(rgb_vector, null)
| stats total_null = count(*) where similarity is null
;

total_null:long
59
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ old lace | 60563.0
// end::vector-dot-product-result[]
;

similarityAsPartOfExpression
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change: the indentation was wrong, so I fixed it in all of the test files.

required_capability: dot_product_vector_similarity_function
from colors
| eval score = round((1 + v_dot_product(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;
similarityAsPartOfExpression
required_capability: dot_product_vector_similarity_function

from colors
| eval score = round((1 + v_dot_product(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
azure | 32513.75
Expand All @@ -62,18 +62,31 @@ similarity:double
4.5
;

similarityWithStats
required_capability: dot_product_vector_similarity_function
from colors
| eval similarity = round(v_dot_product(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;
similarityWithStats
required_capability: dot_product_vector_similarity_function

from colors
| eval similarity = round(v_dot_product(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
39519.017 | 0.5 | 65025.5
;

similarityWithNull
required_capability: dot_product_vector_similarity_function

from colors
| eval similarity = v_dot_product(rgb_vector, null)
| stats total_null = count(*) where similarity is null
;

total_null:long
59
;


# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: dot_product_vector_similarity_function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ avg:double | min:double | max:double
391.254 | 0.0 | 765.0
;

similarityWithNull
required_capability: l1_norm_vector_similarity_function

from colors
| eval similarity = v_l1_norm(rgb_vector, null)
| stats total_null = count(*) where similarity is null
;

total_null:long
59
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: l1_norm_vector_similarity_function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ tomato | 351.0227966308594
similarityAsPartOfExpression
required_capability: l2_norm_vector_similarity_function

from colors
| eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;
from colors
| eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
red | 221.836
Expand All @@ -62,18 +62,30 @@ similarity:double
1.732
;

similarityWithStats
required_capability: l2_norm_vector_similarity_function
from colors
| eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;
similarityWithStats
required_capability: l2_norm_vector_similarity_function

from colors
| eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
274.974 | 0.0 | 441.673
;

similarityWithNull
required_capability: l2_norm_vector_similarity_function

from colors
| eval similarity = v_l2_norm(rgb_vector, null)
| stats total_null = count(*) where similarity is null
;

total_null:long
59
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: l2_norm_vector_similarity_function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,13 @@ public void testSimilarityBetweenVectors() {
float[] left = readVector((List<Float>) values.get(0));
float[] right = readVector((List<Float>) values.get(1));
Double similarity = (Double) values.get(2);

assertNotNull(similarity);
float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
assertEquals(expectedSimilarity, similarity, 0.0001);
if (left == null || right == null) {
assertNull(similarity);
} else {
assertNotNull(similarity);
float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
assertEquals(expectedSimilarity, similarity, 0.0001);
}
});
}
}
Expand All @@ -124,10 +127,14 @@ public void testSimilarityBetweenConstantVectorAndField() {
valuesList.forEach(values -> {
float[] left = readVector((List<Float>) values.get(0));
Double similarity = (Double) values.get(1);

assertNotNull(similarity);
float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
assertEquals(expectedSimilarity, similarity, 0.0001);
if (left == null) {
assertNull(similarity);
return;
} else {
assertNotNull(similarity);
float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
assertEquals(expectedSimilarity, similarity, 0.0001);
}
});
}
}
Expand Down Expand Up @@ -159,13 +166,20 @@ public void testSimilarityBetweenConstantVectors() {
assertEquals(1, valuesList.size());

Double similarity = (Double) valuesList.get(0).get(0);
assertNotNull(similarity);
float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
assertEquals(expectedSimilarity, similarity, 0.0001);
if (vectorLeft == null || vectorRight == null) {
assertNull(similarity);
} else {
assertNotNull(similarity);
float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
assertEquals(expectedSimilarity, similarity, 0.0001);
}
}
}

private static float[] readVector(List<Float> leftVector) {
if (leftVector == null) {
return null;
}
float[] leftScratch = new float[leftVector.size()];
for (int i = 0; i < leftVector.size(); i++) {
leftScratch[i] = leftVector.get(i);
Expand Down Expand Up @@ -194,6 +208,9 @@ public void setup() throws IOException {

private List<Float> randomVector() {
assert numDims != 0 : "numDims must be set before calling randomVector()";
if (rarely()) {
return null;
}
List<Float> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
vector.add(randomFloat());
Expand All @@ -203,7 +220,7 @@ private List<Float> randomVector() {

private float[] randomVectorArray() {
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
return randomVectorArray(numDims);
return rarely() ? null : randomVectorArray(numDims);
}

private static float[] randomVectorArray(int dimensions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.Page;
Expand Down Expand Up @@ -59,9 +60,7 @@ protected TypeResolution resolveType() {
}

private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
return isNotNull(param, sourceText(), paramOrdinal).and(
isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector")
);
return isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector");
}

/**
Expand Down Expand Up @@ -124,14 +123,14 @@ public Block eval(Page page) {

float[] leftScratch = new float[dimensions];
float[] rightScratch = new float[dimensions];
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
try (DoubleBlock.Builder builder = context.blockFactory().newDoubleBlockBuilder(positionCount * dimensions)) {
for (int p = 0; p < positionCount; p++) {
int dimsLeft = leftBlock.getValueCount(p);
int dimsRight = rightBlock.getValueCount(p);

if (dimsLeft == 0 || dimsRight == 0) {
// A null value on the left or right vector. Similarity is 0
builder.appendDouble(0.0);
// A null value on the left or right vector. Similarity is null
builder.appendNull();
continue;
} else if (dimsLeft != dimsRight) {
throw new EsqlClientException(
Expand All @@ -145,7 +144,7 @@ public Block eval(Page page) {
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
builder.appendDouble(result);
}
return builder.build().asBlock();
return builder.build();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.hamcrest.Matcher;
import org.junit.Before;

import java.util.ArrayList;
Expand Down Expand Up @@ -93,10 +92,4 @@ private static List<Float> randomDenseVector(int dimensions) {
}
return vector;
}

@Override
protected Matcher<Object> allNullsMatcher() {
// A null value on the left or right vector. Similarity is 0
return equalTo(0.0);
}
}