-
Notifications
You must be signed in to change notification settings - Fork 25.6k
ESQL: dense_vector cosine similarity function #130641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0c18eba
31faca1
16542b5
5098373
b441d60
4508154
58bd1c0
fc88621
597741d
85e2426
6cc2115
1a9b44c
a53eab3
8c80b72
7517557
58e6ac7
ebafdf8
3ab9a72
b1d6f85
4b0b772
53e96f9
290dbe1
312a727
08364ee
fb3bec7
3d82d86
7601522
2a9e322
dd09bf8
2400b7a
2bdb35d
81f26ec
707b3c2
7fcdb36
5e9bcd2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Tests for cosine similarity function | ||
|
|
||
| similarityWithVectorField | ||
| required_capability: cosine_vector_similarity_function | ||
|
|
||
| // tag::vector-cosine-similarity[] | ||
| from colors | ||
| | where color != "black" | ||
| | eval similarity = v_cosine(rgb_vector, [0, 255, 255]) | ||
| | sort similarity desc, color asc | ||
| // end::vector-cosine-similarity[] | ||
| | limit 10 | ||
| | keep color, similarity | ||
| ; | ||
|
|
||
| // tag::vector-cosine-similarity-result[] | ||
| color:text | similarity:double | ||
| cyan | 1.0 | ||
| teal | 1.0 | ||
| turquoise | 0.9890533685684204 | ||
| aqua marine | 0.964962363243103 | ||
| azure | 0.916246771812439 | ||
| lavender | 0.9136701822280884 | ||
| mint cream | 0.9122757911682129 | ||
| honeydew | 0.9122424125671387 | ||
| gainsboro | 0.9082483053207397 | ||
| gray | 0.9082483053207397 | ||
| // end::vector-cosine-similarity-result[] | ||
| ; | ||
|
|
||
| similarityAsPartOfExpression | ||
| required_capability: cosine_vector_similarity_function | ||
|
|
||
| from colors | ||
| | where color != "black" | ||
| | eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3) | ||
| | sort score desc, color asc | ||
| | limit 10 | ||
| | keep color, score | ||
| ; | ||
|
|
||
| color:text | score:double | ||
| cyan | 1.5 | ||
| teal | 1.5 | ||
| turquoise | 1.495 | ||
| aqua marine | 1.482 | ||
| azure | 1.458 | ||
| lavender | 1.457 | ||
| honeydew | 1.456 | ||
| mint cream | 1.456 | ||
| gainsboro | 1.454 | ||
| gray | 1.454 | ||
| ; | ||
|
|
||
| similarityWithLiteralVectors | ||
| required_capability: cosine_vector_similarity_function | ||
|
|
||
| row a = 1 | ||
| | eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3) | ||
| | keep similarity | ||
| ; | ||
|
|
||
| similarity:double | ||
| 0.978 | ||
| ; | ||
|
|
||
| similarityWithStats | ||
| required_capability: cosine_vector_similarity_function | ||
|
|
||
| from colors | ||
| | where color != "black" | ||
| | eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3) | ||
| | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) | ||
| ; | ||
|
|
||
| avg:double | min:double | max:double | ||
| 0.832 | 0.5 | 1.0 | ||
| ; | ||
|
|
||
| # TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector | ||
| similarityWithRow-Ignore | ||
| required_capability: cosine_vector_similarity_function | ||
|
|
||
| row vector = [1, 2, 3] | ||
| | eval similarity = round(v_cosine(vector, [0, 1, 2]), 3) | ||
| | sort similarity desc, color asc | ||
| | limit 10 | ||
| | keep color, similarity | ||
| ; | ||
|
|
||
| similarity:double | ||
| 0.978 | ||
| ; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,208 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.esql.vector; | ||
|
|
||
| import com.carrotsearch.randomizedtesting.annotations.Name; | ||
| import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; | ||
|
|
||
| import org.apache.lucene.index.VectorSimilarityFunction; | ||
| import org.elasticsearch.action.index.IndexRequestBuilder; | ||
| import org.elasticsearch.cluster.metadata.IndexMetadata; | ||
| import org.elasticsearch.common.settings.Settings; | ||
| import org.elasticsearch.xcontent.XContentBuilder; | ||
| import org.elasticsearch.xcontent.XContentFactory; | ||
| import org.elasticsearch.xpack.esql.EsqlClientException; | ||
| import org.elasticsearch.xpack.esql.EsqlTestUtils; | ||
| import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; | ||
| import org.elasticsearch.xpack.esql.action.EsqlCapabilities; | ||
| import org.junit.Before; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.ArrayList; | ||
| import java.util.Arrays; | ||
| import java.util.List; | ||
| import java.util.Locale; | ||
|
|
||
| import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; | ||
|
|
||
| public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase { | ||
|
|
||
| @ParametersFactory | ||
| public static Iterable<Object[]> parameters() throws Exception { | ||
| List<Object[]> params = new ArrayList<>(); | ||
|
|
||
| params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE }); | ||
|
|
||
| return params; | ||
| } | ||
|
|
||
| private final String functionName; | ||
| private final VectorSimilarityFunction similarityFunction; | ||
| private int numDims; | ||
|
|
||
| public VectorSimilarityFunctionsIT( | ||
| @Name("functionName") String functionName, | ||
| @Name("similarityFunction") VectorSimilarityFunction similarityFunction | ||
| ) { | ||
| this.functionName = functionName; | ||
| this.similarityFunction = similarityFunction; | ||
| } | ||
|
|
||
| @SuppressWarnings("unchecked") | ||
| public void testSimilarityBetweenVectors() { | ||
| var query = String.format(Locale.ROOT, """ | ||
| FROM test | ||
| | EVAL similarity = %s(left_vector, right_vector) | ||
| | KEEP left_vector, right_vector, similarity | ||
| """, functionName); | ||
|
|
||
| try (var resp = run(query)) { | ||
| List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp); | ||
| valuesList.forEach(values -> { | ||
| float[] left = readVector((List<Float>) values.get(0)); | ||
| float[] right = readVector((List<Float>) values.get(1)); | ||
| Double similarity = (Double) values.get(2); | ||
|
|
||
| assertNotNull(similarity); | ||
| float expectedSimilarity = similarityFunction.compare(left, right); | ||
| assertEquals(expectedSimilarity, similarity, 0.0001); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| @SuppressWarnings("unchecked") | ||
| public void testSimilarityBetweenConstantVectorAndField() { | ||
| var randomVector = randomVectorArray(); | ||
| var query = String.format(Locale.ROOT, """ | ||
| FROM test | ||
| | EVAL similarity = %s(left_vector, %s) | ||
| | KEEP left_vector, similarity | ||
| """, functionName, Arrays.toString(randomVector)); | ||
|
|
||
| try (var resp = run(query)) { | ||
| List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp); | ||
| valuesList.forEach(values -> { | ||
| float[] left = readVector((List<Float>) values.get(0)); | ||
| Double similarity = (Double) values.get(1); | ||
|
|
||
| assertNotNull(similarity); | ||
| float expectedSimilarity = similarityFunction.compare(left, randomVector); | ||
| assertEquals(expectedSimilarity, similarity, 0.0001); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| public void testDifferentDimensions() { | ||
| var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2)); | ||
| var query = String.format(Locale.ROOT, """ | ||
| FROM test | ||
| | EVAL similarity = %s(left_vector, %s) | ||
| | KEEP left_vector, similarity | ||
| """, functionName, Arrays.toString(randomVector)); | ||
|
|
||
| EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); }); | ||
| assertTrue(iae.getMessage().contains("Vectors must have the same dimensions")); | ||
| } | ||
|
|
||
| @SuppressWarnings("unchecked") | ||
| public void testSimilarityBetweenConstantVectors() { | ||
| var vectorLeft = randomVectorArray(); | ||
| var vectorRight = randomVectorArray(); | ||
| var query = String.format(Locale.ROOT, """ | ||
| ROW a = 1 | ||
| | EVAL similarity = %s(%s, %s) | ||
| | KEEP similarity | ||
| """, functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight)); | ||
|
|
||
| try (var resp = run(query)) { | ||
| List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp); | ||
| assertEquals(1, valuesList.size()); | ||
|
|
||
| Double similarity = (Double) valuesList.get(0).get(0); | ||
| assertNotNull(similarity); | ||
| float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight); | ||
| assertEquals(expectedSimilarity, similarity, 0.0001); | ||
| } | ||
| } | ||
|
|
||
| private static float[] readVector(List<Float> leftVector) { | ||
| float[] leftScratch = new float[leftVector.size()]; | ||
| for (int i = 0; i < leftVector.size(); i++) { | ||
| leftScratch[i] = leftVector.get(i); | ||
| } | ||
| return leftScratch; | ||
| } | ||
|
|
||
| @Before | ||
| public void setup() throws IOException { | ||
| assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); | ||
|
|
||
| createIndexWithDenseVector("test"); | ||
|
|
||
| numDims = randomIntBetween(32, 64) * 2; // min 64, even number | ||
| int numDocs = randomIntBetween(10, 100); | ||
| IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; | ||
| for (int i = 0; i < numDocs; i++) { | ||
| List<Float> leftVector = randomVector(); | ||
| List<Float> rightVector = randomVector(); | ||
| docs[i] = prepareIndex("test").setId("" + i) | ||
| .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector); | ||
| } | ||
|
|
||
| indexRandom(true, docs); | ||
| } | ||
|
|
||
| private List<Float> randomVector() { | ||
| assert numDims != 0 : "numDims must be set before calling randomVector()"; | ||
| List<Float> vector = new ArrayList<>(numDims); | ||
| for (int j = 0; j < numDims; j++) { | ||
| vector.add(randomFloat()); | ||
| } | ||
| return vector; | ||
| } | ||
|
|
||
| private float[] randomVectorArray() { | ||
| assert numDims != 0 : "numDims must be set before calling randomVectorArray()"; | ||
| return randomVectorArray(numDims); | ||
| } | ||
|
|
||
| private static float[] randomVectorArray(int dimensions) { | ||
| float[] vector = new float[dimensions]; | ||
| for (int j = 0; j < dimensions; j++) { | ||
| vector[j] = randomFloat(); | ||
| } | ||
| return vector; | ||
| } | ||
|
|
||
| private void createIndexWithDenseVector(String indexName) throws IOException { | ||
| var client = client().admin().indices(); | ||
| XContentBuilder mapping = XContentFactory.jsonBuilder() | ||
| .startObject() | ||
| .startObject("properties") | ||
| .startObject("id") | ||
| .field("type", "integer") | ||
| .endObject(); | ||
| createDenseVectorField(mapping, "left_vector"); | ||
| createDenseVectorField(mapping, "right_vector"); | ||
| mapping.endObject().endObject(); | ||
| Settings.Builder settingsBuilder = Settings.builder() | ||
| .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) | ||
| .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)); | ||
|
|
||
| var CreateRequest = client.prepareCreate(indexName) | ||
| .setSettings(Settings.builder().put("index.number_of_shards", 1)) | ||
| .setMapping(mapping) | ||
| .setSettings(settingsBuilder.build()); | ||
| assertAcked(CreateRequest); | ||
| } | ||
|
|
||
| private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException { | ||
| mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine"); | ||
| mapping.endObject(); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1400,15 +1400,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func | |
| if (f instanceof In in) { | ||
| return processIn(in); | ||
| } | ||
| if (f instanceof VectorFunction) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needed to change the order to ensure VectorFunction are processed first, as similarity functions are scalar functions as well |
||
| return processVectorFunction(f); | ||
| } | ||
| if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed | ||
| return processScalarOrGroupingFunction(f, registry); | ||
| } | ||
| if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) { | ||
| return processBinaryOperator((BinaryOperator) f); | ||
| } | ||
| if (f instanceof VectorFunction vectorFunction) { | ||
| return processVectorFunction(f); | ||
| } | ||
| return f; | ||
| } | ||
|
|
||
|
|
@@ -1613,14 +1613,22 @@ private static Expression castStringLiteral(Expression from, DataType target) { | |
| } | ||
| } | ||
|
|
||
| @SuppressWarnings("unchecked") | ||
| private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) { | ||
| List<Expression> args = vectorFunction.arguments(); | ||
| List<Expression> newArgs = new ArrayList<>(); | ||
| for (Expression arg : args) { | ||
| if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) { | ||
| Object folded = arg.fold(FoldContext.small() /* TODO remove me */); | ||
| if (folded instanceof List) { | ||
| Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR); | ||
| // Convert to floats so blocks are created accordingly | ||
| List<Float> floatVector; | ||
| if (arg.dataType() == FLOAT) { | ||
| floatVector = (List<Float>) folded; | ||
| } else { | ||
| floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList()); | ||
| } | ||
| Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR); | ||
| newArgs.add(denseVector); | ||
| continue; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this to work properly, we need to implement a conversion function so we can convert non-foldable values to dense_vector.