Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0c18eba
Add CosineSimilarity function
carlosdelest Jul 4, 2025
31faca1
Add IT
carlosdelest Jul 4, 2025
16542b5
Refactor
carlosdelest Jul 4, 2025
5098373
Extract superclass, use overriden method
carlosdelest Jul 4, 2025
b441d60
Use lambda instead of overriden methods
carlosdelest Jul 4, 2025
4508154
Refactoring
carlosdelest Jul 4, 2025
58bd1c0
Refactoring
carlosdelest Jul 4, 2025
fc88621
Allow using literals for dense_vectors, converting them internally to…
carlosdelest Jul 4, 2025
597741d
Add test
carlosdelest Jul 4, 2025
85e2426
Test non-null args
carlosdelest Jul 4, 2025
6cc2115
Rename to v_cosine, add analyzer tests for implicit casting
carlosdelest Jul 7, 2025
1a9b44c
First CSV tests
carlosdelest Jul 7, 2025
a53eab3
Add tests
carlosdelest Jul 7, 2025
8c80b72
Spotless
carlosdelest Jul 7, 2025
7517557
Add comment
carlosdelest Jul 7, 2025
58e6ac7
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 7, 2025
ebafdf8
[CI] Auto commit changes from spotless
Jul 7, 2025
3ab9a72
Add checks for different number of dimensions
carlosdelest Jul 7, 2025
b1d6f85
Merge remote-tracking branch 'carlosdelest/non-issue/esql-vector-sear…
carlosdelest Jul 7, 2025
4b0b772
Merge remote-tracking branch 'origin/main' into non-issue/esql-vector…
carlosdelest Jul 7, 2025
53e96f9
Add test infrastructure for VectorSimilarityFunction. Change VectorSi…
carlosdelest Jul 8, 2025
290dbe1
Ensure casting is done using floats so we get the appropriate blocks …
carlosdelest Jul 8, 2025
312a727
Generate docs for function
carlosdelest Jul 8, 2025
08364ee
Generate docs for function
carlosdelest Jul 8, 2025
fb3bec7
Fix tests
carlosdelest Jul 8, 2025
3d82d86
Remove unnecessary change to BlockUtis now that we create float eleme…
carlosdelest Jul 8, 2025
7601522
Remove unnecessary change to BlockUtis now that we create float eleme…
carlosdelest Jul 8, 2025
2a9e322
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 8, 2025
dd09bf8
Merge remote-tracking branch 'origin/main' into non-issue/esql-vector…
carlosdelest Jul 8, 2025
2400b7a
Fix telemetry test
carlosdelest Jul 8, 2025
2bdb35d
Merge remote-tracking branch 'carlosdelest/non-issue/esql-vector-sear…
carlosdelest Jul 8, 2025
81f26ec
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 9, 2025
707b3c2
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 9, 2025
7fcdb36
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 15, 2025
5e9bcd2
[CI] Auto commit changes from spotless
Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Tests for cosine similarity function

similarityWithVectorField
required_capability: cosine_vector_similarity_function

// tag::vector-cosine-similarity[]
from colors
| where color != "black"
| eval similarity = v_cosine(rgb_vector, [0, 255, 255])
| sort similarity desc, color asc
// end::vector-cosine-similarity[]
| limit 10
| keep color, similarity
;

// tag::vector-cosine-similarity-result[]
color:text | similarity:double
cyan | 1.0
teal | 1.0
turquoise | 0.9890533685684204
aqua marine | 0.964962363243103
azure | 0.916246771812439
lavender | 0.9136701822280884
mint cream | 0.9122757911682129
honeydew | 0.9122424125671387
gainsboro | 0.9082483053207397
gray | 0.9082483053207397
// end::vector-cosine-similarity-result[]
;

similarityAsPartOfExpression
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
cyan | 1.5
teal | 1.5
turquoise | 1.495
aqua marine | 1.482
azure | 1.458
lavender | 1.457
honeydew | 1.456
mint cream | 1.456
gainsboro | 1.454
gray | 1.454
;

similarityWithLiteralVectors
required_capability: cosine_vector_similarity_function

row a = 1
| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3)
| keep similarity
;

similarity:double
0.978
;

similarityWithStats
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
0.832 | 0.5 | 1.0
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: cosine_vector_similarity_function

row vector = [1, 2, 3]
| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this to work properly, we need to implement a conversion function so we can convert non-foldable values to dense_vector.

| sort similarity desc, color asc
| limit 10
| keep color, similarity
;

similarity:double
0.978
;
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.vector;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.esql.EsqlClientException;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;

public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {

@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();

params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });

return params;
}

private final String functionName;
private final VectorSimilarityFunction similarityFunction;
private int numDims;

public VectorSimilarityFunctionsIT(
@Name("functionName") String functionName,
@Name("similarityFunction") VectorSimilarityFunction similarityFunction
) {
this.functionName = functionName;
this.similarityFunction = similarityFunction;
}

@SuppressWarnings("unchecked")
public void testSimilarityBetweenVectors() {
var query = String.format(Locale.ROOT, """
FROM test
| EVAL similarity = %s(left_vector, right_vector)
| KEEP left_vector, right_vector, similarity
""", functionName);

try (var resp = run(query)) {
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
valuesList.forEach(values -> {
float[] left = readVector((List<Float>) values.get(0));
float[] right = readVector((List<Float>) values.get(1));
Double similarity = (Double) values.get(2);

assertNotNull(similarity);
float expectedSimilarity = similarityFunction.compare(left, right);
assertEquals(expectedSimilarity, similarity, 0.0001);
});
}
}

@SuppressWarnings("unchecked")
public void testSimilarityBetweenConstantVectorAndField() {
var randomVector = randomVectorArray();
var query = String.format(Locale.ROOT, """
FROM test
| EVAL similarity = %s(left_vector, %s)
| KEEP left_vector, similarity
""", functionName, Arrays.toString(randomVector));

try (var resp = run(query)) {
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
valuesList.forEach(values -> {
float[] left = readVector((List<Float>) values.get(0));
Double similarity = (Double) values.get(1);

assertNotNull(similarity);
float expectedSimilarity = similarityFunction.compare(left, randomVector);
assertEquals(expectedSimilarity, similarity, 0.0001);
});
}
}

public void testDifferentDimensions() {
var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
var query = String.format(Locale.ROOT, """
FROM test
| EVAL similarity = %s(left_vector, %s)
| KEEP left_vector, similarity
""", functionName, Arrays.toString(randomVector));

EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); });
assertTrue(iae.getMessage().contains("Vectors must have the same dimensions"));
}

@SuppressWarnings("unchecked")
public void testSimilarityBetweenConstantVectors() {
var vectorLeft = randomVectorArray();
var vectorRight = randomVectorArray();
var query = String.format(Locale.ROOT, """
ROW a = 1
| EVAL similarity = %s(%s, %s)
| KEEP similarity
""", functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));

try (var resp = run(query)) {
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(1, valuesList.size());

Double similarity = (Double) valuesList.get(0).get(0);
assertNotNull(similarity);
float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
assertEquals(expectedSimilarity, similarity, 0.0001);
}
}

private static float[] readVector(List<Float> leftVector) {
float[] leftScratch = new float[leftVector.size()];
for (int i = 0; i < leftVector.size(); i++) {
leftScratch[i] = leftVector.get(i);
}
return leftScratch;
}

@Before
public void setup() throws IOException {
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());

createIndexWithDenseVector("test");

numDims = randomIntBetween(32, 64) * 2; // min 64, even number
int numDocs = randomIntBetween(10, 100);
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
for (int i = 0; i < numDocs; i++) {
List<Float> leftVector = randomVector();
List<Float> rightVector = randomVector();
docs[i] = prepareIndex("test").setId("" + i)
.setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
}

indexRandom(true, docs);
}

private List<Float> randomVector() {
assert numDims != 0 : "numDims must be set before calling randomVector()";
List<Float> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
vector.add(randomFloat());
}
return vector;
}

private float[] randomVectorArray() {
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
return randomVectorArray(numDims);
}

private static float[] randomVectorArray(int dimensions) {
float[] vector = new float[dimensions];
for (int j = 0; j < dimensions; j++) {
vector[j] = randomFloat();
}
return vector;
}

private void createIndexWithDenseVector(String indexName) throws IOException {
var client = client().admin().indices();
XContentBuilder mapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("id")
.field("type", "integer")
.endObject();
createDenseVectorField(mapping, "left_vector");
createDenseVectorField(mapping, "right_vector");
mapping.endObject().endObject();
Settings.Builder settingsBuilder = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));

var CreateRequest = client.prepareCreate(indexName)
.setSettings(Settings.builder().put("index.number_of_shards", 1))
.setMapping(mapping)
.setSettings(settingsBuilder.build());
assertAcked(CreateRequest);
}

private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
mapping.endObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,12 @@ public enum Cap {
/**
* Support avg with aggregate metric doubles
*/
AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG);
AGGREGATE_METRIC_DOUBLE_AVG(AGGREGATE_METRIC_DOUBLE_FEATURE_FLAG),

/**
* Cosine vector similarity function
*/
COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());

private final boolean enabled;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1400,15 +1400,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func
if (f instanceof In in) {
return processIn(in);
}
if (f instanceof VectorFunction) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to change the order to ensure VectorFunction are processed first, as similarity functions are scalar functions as well

return processVectorFunction(f);
}
if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
return processScalarOrGroupingFunction(f, registry);
}
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
return processBinaryOperator((BinaryOperator) f);
}
if (f instanceof VectorFunction vectorFunction) {
return processVectorFunction(f);
}
return f;
}

Expand Down Expand Up @@ -1613,14 +1613,22 @@ private static Expression castStringLiteral(Expression from, DataType target) {
}
}

@SuppressWarnings("unchecked")
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
List<Expression> args = vectorFunction.arguments();
List<Expression> newArgs = new ArrayList<>();
for (Expression arg : args) {
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
if (folded instanceof List) {
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
// Convert to floats so blocks are created accordingly
List<Float> floatVector;
if (arg.dataType() == FLOAT) {
floatVector = (List<Float>) folded;
} else {
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
}
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
newArgs.add(denseVector);
continue;
}
Expand Down
Loading
Loading