Skip to content

Commit f1ddd4c

Browse files
authored
ESQL: dense_vector cosine similarity function (#130641)
1 parent 730308c commit f1ddd4c

File tree

17 files changed

+842
-15
lines changed

17 files changed

+842
-15
lines changed

docs/reference/query-languages/esql/images/functions/v_cosine.svg

Lines changed: 1 addition & 0 deletions
Loading

docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Tests for cosine similarity function

# v_cosine between a dense_vector field and a literal vector; the tagged regions are used as documentation snippets
similarityWithVectorField
required_capability: cosine_vector_similarity_function

// tag::vector-cosine-similarity[]
from colors
| where color != "black"
| eval similarity = v_cosine(rgb_vector, [0, 255, 255])
| sort similarity desc, color asc
// end::vector-cosine-similarity[]
| limit 10
| keep color, similarity
;

// tag::vector-cosine-similarity-result[]
color:text | similarity:double
cyan       | 1.0
teal       | 1.0
turquoise  | 0.9890533685684204
aqua marine | 0.964962363243103
azure      | 0.916246771812439
lavender   | 0.9136701822280884
mint cream | 0.9122757911682129
honeydew   | 0.9122424125671387
gainsboro  | 0.9082483053207397
gray       | 0.9082483053207397
// end::vector-cosine-similarity-result[]
;

# v_cosine used as an operand of a larger arithmetic expression
similarityAsPartOfExpression
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
cyan       | 1.5
teal       | 1.5
turquoise  | 1.495
aqua marine | 1.482
azure      | 1.458
lavender   | 1.457
honeydew   | 1.456
mint cream | 1.456
gainsboro  | 1.454
gray       | 1.454
;

# v_cosine with both arguments given as literal vectors
similarityWithLiteralVectors
required_capability: cosine_vector_similarity_function

row a = 1
| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3)
| keep similarity
;

similarity:double
0.978
;

# v_cosine results fed into aggregations
similarityWithStats
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
0.832      | 0.5        | 1.0
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
# The query only references columns defined by the ROW source (there is no "color" column here)
similarityWithRow-Ignore
required_capability: cosine_vector_similarity_function

row vector = [1, 2, 3]
| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
| keep similarity
;

similarity:double
0.978
;
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.vector;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.Name;
11+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
12+
13+
import org.apache.lucene.index.VectorSimilarityFunction;
14+
import org.elasticsearch.action.index.IndexRequestBuilder;
15+
import org.elasticsearch.cluster.metadata.IndexMetadata;
16+
import org.elasticsearch.common.settings.Settings;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
import org.elasticsearch.xcontent.XContentFactory;
19+
import org.elasticsearch.xpack.esql.EsqlClientException;
20+
import org.elasticsearch.xpack.esql.EsqlTestUtils;
21+
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
22+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
23+
import org.junit.Before;
24+
25+
import java.io.IOException;
26+
import java.util.ArrayList;
27+
import java.util.Arrays;
28+
import java.util.List;
29+
import java.util.Locale;
30+
31+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
32+
33+
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
34+
35+
@ParametersFactory
36+
public static Iterable<Object[]> parameters() throws Exception {
37+
List<Object[]> params = new ArrayList<>();
38+
39+
params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
40+
41+
return params;
42+
}
43+
44+
private final String functionName;
45+
private final VectorSimilarityFunction similarityFunction;
46+
private int numDims;
47+
48+
public VectorSimilarityFunctionsIT(
49+
@Name("functionName") String functionName,
50+
@Name("similarityFunction") VectorSimilarityFunction similarityFunction
51+
) {
52+
this.functionName = functionName;
53+
this.similarityFunction = similarityFunction;
54+
}
55+
56+
@SuppressWarnings("unchecked")
57+
public void testSimilarityBetweenVectors() {
58+
var query = String.format(Locale.ROOT, """
59+
FROM test
60+
| EVAL similarity = %s(left_vector, right_vector)
61+
| KEEP left_vector, right_vector, similarity
62+
""", functionName);
63+
64+
try (var resp = run(query)) {
65+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
66+
valuesList.forEach(values -> {
67+
float[] left = readVector((List<Float>) values.get(0));
68+
float[] right = readVector((List<Float>) values.get(1));
69+
Double similarity = (Double) values.get(2);
70+
71+
assertNotNull(similarity);
72+
float expectedSimilarity = similarityFunction.compare(left, right);
73+
assertEquals(expectedSimilarity, similarity, 0.0001);
74+
});
75+
}
76+
}
77+
78+
@SuppressWarnings("unchecked")
79+
public void testSimilarityBetweenConstantVectorAndField() {
80+
var randomVector = randomVectorArray();
81+
var query = String.format(Locale.ROOT, """
82+
FROM test
83+
| EVAL similarity = %s(left_vector, %s)
84+
| KEEP left_vector, similarity
85+
""", functionName, Arrays.toString(randomVector));
86+
87+
try (var resp = run(query)) {
88+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
89+
valuesList.forEach(values -> {
90+
float[] left = readVector((List<Float>) values.get(0));
91+
Double similarity = (Double) values.get(1);
92+
93+
assertNotNull(similarity);
94+
float expectedSimilarity = similarityFunction.compare(left, randomVector);
95+
assertEquals(expectedSimilarity, similarity, 0.0001);
96+
});
97+
}
98+
}
99+
100+
public void testDifferentDimensions() {
101+
var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
102+
var query = String.format(Locale.ROOT, """
103+
FROM test
104+
| EVAL similarity = %s(left_vector, %s)
105+
| KEEP left_vector, similarity
106+
""", functionName, Arrays.toString(randomVector));
107+
108+
EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); });
109+
assertTrue(iae.getMessage().contains("Vectors must have the same dimensions"));
110+
}
111+
112+
@SuppressWarnings("unchecked")
113+
public void testSimilarityBetweenConstantVectors() {
114+
var vectorLeft = randomVectorArray();
115+
var vectorRight = randomVectorArray();
116+
var query = String.format(Locale.ROOT, """
117+
ROW a = 1
118+
| EVAL similarity = %s(%s, %s)
119+
| KEEP similarity
120+
""", functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));
121+
122+
try (var resp = run(query)) {
123+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
124+
assertEquals(1, valuesList.size());
125+
126+
Double similarity = (Double) valuesList.get(0).get(0);
127+
assertNotNull(similarity);
128+
float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
129+
assertEquals(expectedSimilarity, similarity, 0.0001);
130+
}
131+
}
132+
133+
private static float[] readVector(List<Float> leftVector) {
134+
float[] leftScratch = new float[leftVector.size()];
135+
for (int i = 0; i < leftVector.size(); i++) {
136+
leftScratch[i] = leftVector.get(i);
137+
}
138+
return leftScratch;
139+
}
140+
141+
@Before
142+
public void setup() throws IOException {
143+
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
144+
145+
createIndexWithDenseVector("test");
146+
147+
numDims = randomIntBetween(32, 64) * 2; // min 64, even number
148+
int numDocs = randomIntBetween(10, 100);
149+
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
150+
for (int i = 0; i < numDocs; i++) {
151+
List<Float> leftVector = randomVector();
152+
List<Float> rightVector = randomVector();
153+
docs[i] = prepareIndex("test").setId("" + i)
154+
.setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
155+
}
156+
157+
indexRandom(true, docs);
158+
}
159+
160+
private List<Float> randomVector() {
161+
assert numDims != 0 : "numDims must be set before calling randomVector()";
162+
List<Float> vector = new ArrayList<>(numDims);
163+
for (int j = 0; j < numDims; j++) {
164+
vector.add(randomFloat());
165+
}
166+
return vector;
167+
}
168+
169+
private float[] randomVectorArray() {
170+
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
171+
return randomVectorArray(numDims);
172+
}
173+
174+
private static float[] randomVectorArray(int dimensions) {
175+
float[] vector = new float[dimensions];
176+
for (int j = 0; j < dimensions; j++) {
177+
vector[j] = randomFloat();
178+
}
179+
return vector;
180+
}
181+
182+
private void createIndexWithDenseVector(String indexName) throws IOException {
183+
var client = client().admin().indices();
184+
XContentBuilder mapping = XContentFactory.jsonBuilder()
185+
.startObject()
186+
.startObject("properties")
187+
.startObject("id")
188+
.field("type", "integer")
189+
.endObject();
190+
createDenseVectorField(mapping, "left_vector");
191+
createDenseVectorField(mapping, "right_vector");
192+
mapping.endObject().endObject();
193+
Settings.Builder settingsBuilder = Settings.builder()
194+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
195+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
196+
197+
var CreateRequest = client.prepareCreate(indexName)
198+
.setSettings(Settings.builder().put("index.number_of_shards", 1))
199+
.setMapping(mapping)
200+
.setSettings(settingsBuilder.build());
201+
assertAcked(CreateRequest);
202+
}
203+
204+
private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
205+
mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
206+
mapping.endObject();
207+
}
208+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,12 @@ public enum Cap {
12541254
* Forbid usage of brackets in unquoted index and enrich policy names
12551255
* https://github.com/elastic/elasticsearch/issues/130378
12561256
*/
1257-
NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES;
1257+
NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES,
1258+
1259+
/*
1260+
* Cosine vector similarity function
1261+
*/
1262+
COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
12581263

12591264
private final boolean enabled;
12601265

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,15 +1400,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func
14001400
if (f instanceof In in) {
14011401
return processIn(in);
14021402
}
1403+
if (f instanceof VectorFunction) {
1404+
return processVectorFunction(f);
1405+
}
14031406
if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
14041407
return processScalarOrGroupingFunction(f, registry);
14051408
}
14061409
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
14071410
return processBinaryOperator((BinaryOperator) f);
14081411
}
1409-
if (f instanceof VectorFunction vectorFunction) {
1410-
return processVectorFunction(f);
1411-
}
14121412
return f;
14131413
}
14141414

@@ -1613,14 +1613,22 @@ private static Expression castStringLiteral(Expression from, DataType target) {
16131613
}
16141614
}
16151615

1616+
@SuppressWarnings("unchecked")
16161617
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
16171618
List<Expression> args = vectorFunction.arguments();
16181619
List<Expression> newArgs = new ArrayList<>();
16191620
for (Expression arg : args) {
16201621
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
16211622
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
16221623
if (folded instanceof List) {
1623-
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
1624+
// Convert to floats so blocks are created accordingly
1625+
List<Float> floatVector;
1626+
if (arg.dataType() == FLOAT) {
1627+
floatVector = (List<Float>) folded;
1628+
} else {
1629+
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
1630+
}
1631+
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
16241632
newArgs.add(denseVector);
16251633
continue;
16261634
}

0 commit comments

Comments
 (0)