|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.esql.vector; |
9 | 9 |
|
| 10 | +import com.carrotsearch.randomizedtesting.annotations.Name; |
| 11 | +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; |
| 12 | + |
10 | 13 | import org.apache.lucene.index.VectorSimilarityFunction; |
11 | 14 | import org.elasticsearch.action.index.IndexRequestBuilder; |
12 | 15 | import org.elasticsearch.cluster.metadata.IndexMetadata; |
|
20 | 23 |
|
21 | 24 | import java.io.IOException; |
22 | 25 | import java.util.ArrayList; |
| 26 | +import java.util.Arrays; |
23 | 27 | import java.util.List; |
| 28 | +import java.util.Locale; |
24 | 29 |
|
25 | 30 | import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; |
26 | 31 |
|
27 | 32 | public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase { |
28 | 33 |
|
| 34 | + @ParametersFactory |
| 35 | + public static Iterable<Object[]> parameters() throws Exception { |
| 36 | + List<Object[]> params = new ArrayList<>(); |
| 37 | + |
| 38 | + params.add(new Object[] { "v_cosine_similarity", VectorSimilarityFunction.COSINE }); |
| 39 | + |
| 40 | + return params; |
| 41 | + } |
| 42 | + |
// Name of the function under test; the tests substitute it into the query text.
private final String functionName;
// Similarity implementation the tests call to compute the expected value.
private final VectorSimilarityFunction similarityFunction;
// Number of dimensions of the generated vectors; assigned in setup().
private int numDims;

/**
 * Stores the function name and similarity implementation for one parameterized run.
 *
 * @param functionName       name of the function to place in the query
 * @param similarityFunction implementation used to compute expected results
 */
public VectorSimilarityFunctionsIT(
    @Name("functionName") String functionName,
    @Name("similarityFunction") VectorSimilarityFunction similarityFunction
) {
    this.similarityFunction = similarityFunction;
    this.functionName = functionName;
}
| 54 | + |
29 | 55 | @SuppressWarnings("unchecked") |
30 | | - public void testCosineSimilarity() { |
31 | | - var query = """ |
| 56 | + public void testCosineSimilarityBetweenVectors() { |
| 57 | + var query = String.format(Locale.ROOT, """ |
32 | 58 | FROM test |
33 | | - | EVAL similarity = v_cosine_similarity(left_vector, right_vector) |
34 | | - | KEEP id, left_vector, right_vector, similarity |
35 | | - """; |
| 59 | + | EVAL similarity = %s(left_vector, right_vector) |
| 60 | + | KEEP left_vector, right_vector, similarity |
| 61 | + """, functionName); |
36 | 62 |
|
37 | 63 | try (var resp = run(query)) { |
38 | 64 | List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp); |
39 | 65 | valuesList.forEach(values -> { |
40 | | - List<Float> leftVector = (List<Float>) values.get(1); |
41 | | - float[] leftScratch = new float[leftVector.size()]; |
42 | | - for (int i = 0; i < leftVector.size(); i++) { |
43 | | - leftScratch[i] = leftVector.get(i); |
44 | | - } |
45 | | - List<Float> rightVector = (List<Float>) values.get(2); |
46 | | - float[] rightScratch = new float[rightVector.size()]; |
47 | | - for (int i = 0; i < rightVector.size(); i++) { |
48 | | - rightScratch[i] = rightVector.get(i); |
49 | | - } |
50 | | - Double similarity = (Double) values.get(3); |
| 66 | + float[] left = readVector((List<Float>) values.get(0)); |
| 67 | + float[] right = readVector((List<Float>) values.get(1)); |
| 68 | + Double similarity = (Double) values.get(2); |
| 69 | + |
51 | 70 | assertNotNull(similarity); |
| 71 | + float expectedSimilarity = similarityFunction.compare(left, right); |
| 72 | + assertEquals(expectedSimilarity, similarity, 0.0001); |
| 73 | + }); |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + @SuppressWarnings("unchecked") |
| 78 | + public void testCosineSimilarityBetweenConstantVectorAndField() { |
| 79 | + var randomVector = randomVectorArray(); |
| 80 | + var query = String.format(Locale.ROOT, """ |
| 81 | + FROM test |
| 82 | + | EVAL similarity = %s(left_vector, %s) |
| 83 | + | KEEP left_vector, similarity |
| 84 | + """, functionName, Arrays.toString(randomVector)); |
| 85 | + |
| 86 | + try (var resp = run(query)) { |
| 87 | + List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp); |
| 88 | + valuesList.forEach(values -> { |
| 89 | + float[] left = readVector((List<Float>) values.get(0)); |
| 90 | + Double similarity = (Double) values.get(1); |
52 | 91 |
|
53 | | - float expectedSimilarity = VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch); |
| 92 | + assertNotNull(similarity); |
| 93 | + float expectedSimilarity = similarityFunction.compare(left, randomVector); |
54 | 94 | assertEquals(expectedSimilarity, similarity, 0.0001); |
55 | 95 | }); |
56 | 96 | } |
57 | 97 | } |
58 | 98 |
|
| 99 | + private static float[] readVector(List<Float> leftVector) { |
| 100 | + float[] leftScratch = new float[leftVector.size()]; |
| 101 | + for (int i = 0; i < leftVector.size(); i++) { |
| 102 | + leftScratch[i] = leftVector.get(i); |
| 103 | + } |
| 104 | + return leftScratch; |
| 105 | + } |
| 106 | + |
59 | 107 | @Before |
60 | 108 | public void setup() throws IOException { |
61 | 109 | assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); |
62 | 110 |
|
63 | 111 | createIndexWithDenseVector("test"); |
64 | 112 |
|
65 | | - int numDims = randomIntBetween(32, 64) * 2; // min 64, even number |
| 113 | + numDims = randomIntBetween(32, 64) * 2; // min 64, even number |
66 | 114 | int numDocs = randomIntBetween(10, 100); |
67 | 115 | IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; |
68 | 116 | for (int i = 0; i < numDocs; i++) { |
69 | | - List<Float> leftVector = new ArrayList<>(numDims); |
70 | | - for (int j = 0; j < numDims; j++) { |
71 | | - leftVector.add(randomFloat()); |
72 | | - } |
73 | | - List<Float> rightVector = new ArrayList<>(numDims); |
74 | | - for (int j = 0; j < numDims; j++) { |
75 | | - rightVector.add(randomFloat()); |
76 | | - } |
| 117 | + List<Float> leftVector = randomVector(); |
| 118 | + List<Float> rightVector = randomVector(); |
77 | 119 | docs[i] = prepareIndex("test").setId("" + i) |
78 | 120 | .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector); |
79 | 121 | } |
80 | 122 |
|
81 | 123 | indexRandom(true, docs); |
82 | 124 | } |
83 | 125 |
|
| 126 | + private List<Float> randomVector() { |
| 127 | + assert numDims != 0 : "numDims must be set before calling randomVector()"; |
| 128 | + List<Float> vector = new ArrayList<>(numDims); |
| 129 | + for (int j = 0; j < numDims; j++) { |
| 130 | + vector.add(randomFloat()); |
| 131 | + } |
| 132 | + return vector; |
| 133 | + } |
| 134 | + |
| 135 | + private float[] randomVectorArray() { |
| 136 | + assert numDims != 0 : "numDims must be set before calling randomVectorArray()"; |
| 137 | + float[] vector = new float[numDims]; |
| 138 | + for (int j = 0; j < numDims; j++) { |
| 139 | + vector[j] = randomFloat(); |
| 140 | + } |
| 141 | + return vector; |
| 142 | + } |
| 143 | + |
84 | 144 | private void createIndexWithDenseVector(String indexName) throws IOException { |
85 | 145 | var client = client().admin().indices(); |
86 | 146 | XContentBuilder mapping = XContentFactory.jsonBuilder() |
|
0 commit comments