Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0c18eba
Add CosineSimilarity function
carlosdelest Jul 4, 2025
31faca1
Add IT
carlosdelest Jul 4, 2025
16542b5
Refactor
carlosdelest Jul 4, 2025
5098373
Extract superclass, use overriden method
carlosdelest Jul 4, 2025
b441d60
Use lambda instead of overriden methods
carlosdelest Jul 4, 2025
4508154
Refactoring
carlosdelest Jul 4, 2025
58bd1c0
Refactoring
carlosdelest Jul 4, 2025
fc88621
Allow using literals for dense_vectors, converting them internally to…
carlosdelest Jul 4, 2025
597741d
Add test
carlosdelest Jul 4, 2025
85e2426
Test non-null args
carlosdelest Jul 4, 2025
6cc2115
Rename to v_cosine, add analyzer tests for implicit casting
carlosdelest Jul 7, 2025
1a9b44c
First CSV tests
carlosdelest Jul 7, 2025
a53eab3
Add tests
carlosdelest Jul 7, 2025
8c80b72
Spotless
carlosdelest Jul 7, 2025
7517557
Add comment
carlosdelest Jul 7, 2025
58e6ac7
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 7, 2025
ebafdf8
[CI] Auto commit changes from spotless
Jul 7, 2025
3ab9a72
Add checks for different number of dimensions
carlosdelest Jul 7, 2025
b1d6f85
Merge remote-tracking branch 'carlosdelest/non-issue/esql-vector-sear…
carlosdelest Jul 7, 2025
4b0b772
Merge remote-tracking branch 'origin/main' into non-issue/esql-vector…
carlosdelest Jul 7, 2025
53e96f9
Add test infrastructure for VectorSimilarityFunction. Change VectorSi…
carlosdelest Jul 8, 2025
290dbe1
Ensure casting is done using floats so we get the appropriate blocks …
carlosdelest Jul 8, 2025
312a727
Generate docs for function
carlosdelest Jul 8, 2025
08364ee
Generate docs for function
carlosdelest Jul 8, 2025
fb3bec7
Fix tests
carlosdelest Jul 8, 2025
3d82d86
Remove unnecessary change to BlockUtis now that we create float eleme…
carlosdelest Jul 8, 2025
7601522
Remove unnecessary change to BlockUtis now that we create float eleme…
carlosdelest Jul 8, 2025
2a9e322
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 8, 2025
dd09bf8
Merge remote-tracking branch 'origin/main' into non-issue/esql-vector…
carlosdelest Jul 8, 2025
2400b7a
Fix telemetry test
carlosdelest Jul 8, 2025
2bdb35d
Merge remote-tracking branch 'carlosdelest/non-issue/esql-vector-sear…
carlosdelest Jul 8, 2025
81f26ec
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 9, 2025
707b3c2
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 9, 2025
7fcdb36
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 15, 2025
5e9bcd2
[CI] Auto commit changes from spotless
Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public static void appendValue(Block.Builder builder, Object val, ElementType ty
case LONG -> ((LongBlock.Builder) builder).appendLong((Long) val);
case INT -> ((IntBlock.Builder) builder).appendInt((Integer) val);
case BYTES_REF -> ((BytesRefBlock.Builder) builder).appendBytesRef(toBytesRef(val));
case FLOAT -> ((FloatBlock.Builder) builder).appendFloat((Float) val);
case FLOAT -> ((FloatBlock.Builder) builder).appendFloat(((Number) val).floatValue());
case DOUBLE -> ((DoubleBlock.Builder) builder).appendDouble((Double) val);
case BOOLEAN -> ((BooleanBlock.Builder) builder).appendBoolean((Boolean) val);
default -> throw new UnsupportedOperationException("unsupported element type [" + type + "]");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Tests for cosine similarity function

# Computes v_cosine between the rgb_vector field and the literal vector [0, 255, 255] for every
# color except "black", rounds it to 3 decimals, and returns the 10 most similar colors.
# Ties on similarity are broken by color name (ascending) so the output order is deterministic.
similarityWithVectorField
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3)
| sort similarity desc, color asc
| limit 10
| keep color, similarity
;

color:text | similarity:double
cyan | 1.0
teal | 1.0
turquoise | 0.989
aqua marine | 0.965
azure | 0.916
lavender | 0.914
honeydew | 0.912
mint cream | 0.912
gainsboro | 0.908
gray | 0.908
;

# Checks that v_cosine can be used as an operand inside a larger arithmetic expression.
# Note: the expression is 1 + (v_cosine(...) / 2), not (1 + v_cosine(...)) / 2, which is why
# a similarity of 1.0 yields a score of 1.5 in the expected results below.
similarityAsPartOfExpression
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
cyan | 1.5
teal | 1.5
turquoise | 1.495
aqua marine | 1.482
azure | 1.458
lavender | 1.457
honeydew | 1.456
mint cream | 1.456
gainsboro | 1.454
gray | 1.454
;

# Checks that v_cosine accepts two literal vectors (no index field involved) and returns
# a single similarity value, rounded to 3 decimals.
similarityWithLiteralVectors
required_capability: cosine_vector_similarity_function

row a = 1
| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3)
| keep similarity
;

similarity:double
0.978
;

# Checks that the result of v_cosine can be fed into aggregations: computes the rounded
# similarity per color (excluding "black") and then its average, minimum and maximum.
similarityWithStats
required_capability: cosine_vector_similarity_function

from colors
| where color != "black"
| eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
0.832 | 0.5 | 1.0
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: cosine_vector_similarity_function

row vector = [1, 2, 3]
| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this to work properly, we need to implement a conversion function so we can convert non-foldable values to dense_vector.

| sort similarity desc, color asc
| limit 10
| keep color, similarity
;

similarity:double
0.978
;
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.vector;

import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;

/**
 * Integration tests for ES|QL vector similarity functions.
 * <p>
 * The test is parameterized with the ES|QL function name and the Lucene
 * {@link VectorSimilarityFunction} that is used to compute the expected value. Every test runs
 * an ES|QL query against the {@code test} index (or a {@code ROW}) and checks that the value
 * returned by the ES|QL function matches the Lucene result within {@link #SIMILARITY_DELTA}.
 */
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {

    /** Maximum absolute difference tolerated between the expected and the returned similarity. */
    private static final double SIMILARITY_DELTA = 0.0001;

    /**
     * Supplies the (function name, Lucene similarity function) pairs the test is run with.
     *
     * @return one {@code Object[]} per similarity function under test
     */
    @ParametersFactory
    public static Iterable<Object[]> parameters() throws Exception {
        List<Object[]> params = new ArrayList<>();

        params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });

        return params;
    }

    private final String functionName;
    private final VectorSimilarityFunction similarityFunction;
    /** Number of dimensions of every vector used by the current test; assigned in {@link #setup()}. */
    private int numDims;

    /**
     * @param functionName       name of the ES|QL function to invoke in the queries
     * @param similarityFunction Lucene function used to compute the expected similarity
     */
    public VectorSimilarityFunctionsIT(
        @Name("functionName") String functionName,
        @Name("similarityFunction") VectorSimilarityFunction similarityFunction
    ) {
        this.functionName = functionName;
        this.similarityFunction = similarityFunction;
    }

    /**
     * Creates the {@code test} index and fills it with a random number of documents, each holding
     * two random vectors ({@code left_vector} and {@code right_vector}) of {@link #numDims} dimensions.
     */
    @Before
    public void setup() throws IOException {
        assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
        // The similarity function is gated by its own capability, so skip the test when the
        // function is not available instead of failing on an unknown function.
        assumeTrue(
            "Cosine vector similarity function is disabled",
            EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()
        );

        createIndexWithDenseVector("test");

        numDims = randomIntBetween(32, 64) * 2; // min 64, even number
        int numDocs = randomIntBetween(10, 100);
        IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
        for (int i = 0; i < numDocs; i++) {
            List<Float> leftVector = randomVector();
            List<Float> rightVector = randomVector();
            docs[i] = prepareIndex("test").setId("" + i)
                .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
        }

        indexRandom(true, docs);
    }

    /** Compares two vector fields of the same document. */
    @SuppressWarnings("unchecked")
    public void testSimilarityBetweenVectors() {
        var query = String.format(Locale.ROOT, """
            FROM test
            | EVAL similarity = %s(left_vector, right_vector)
            | KEEP left_vector, right_vector, similarity
            """, functionName);

        try (var resp = run(query)) {
            for (List<Object> values : EsqlTestUtils.getValuesList(resp)) {
                float[] left = readVector((List<Float>) values.get(0));
                float[] right = readVector((List<Float>) values.get(1));
                assertSimilarity(left, right, (Double) values.get(2));
            }
        }
    }

    /** Compares a vector field against a constant vector written as a literal in the query. */
    @SuppressWarnings("unchecked")
    public void testSimilarityBetweenConstantVectorAndField() {
        var randomVector = randomVectorArray();
        var query = String.format(Locale.ROOT, """
            FROM test
            | EVAL similarity = %s(left_vector, %s)
            | KEEP left_vector, similarity
            """, functionName, Arrays.toString(randomVector));

        try (var resp = run(query)) {
            for (List<Object> values : EsqlTestUtils.getValuesList(resp)) {
                float[] left = readVector((List<Float>) values.get(0));
                assertSimilarity(left, randomVector, (Double) values.get(1));
            }
        }
    }

    /** Compares two constant vectors written as literals in the query; a single row is expected. */
    public void testSimilarityBetweenConstantVectors() {
        var vectorLeft = randomVectorArray();
        var vectorRight = randomVectorArray();
        var query = String.format(Locale.ROOT, """
            ROW a = 1
            | EVAL similarity = %s(%s, %s)
            | KEEP similarity
            """, functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));

        try (var resp = run(query)) {
            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
            assertEquals(1, valuesList.size());

            assertSimilarity(vectorLeft, vectorRight, (Double) valuesList.get(0).get(0));
        }
    }

    /**
     * Asserts that {@code similarity} is not null and equals, within {@link #SIMILARITY_DELTA},
     * the value computed by {@link #similarityFunction} for the two vectors.
     */
    private void assertSimilarity(float[] left, float[] right, Double similarity) {
        assertNotNull(similarity);
        float expectedSimilarity = similarityFunction.compare(left, right);
        assertEquals(expectedSimilarity, similarity, SIMILARITY_DELTA);
    }

    /** Copies a list of boxed floats into a primitive {@code float[]} of the same length. */
    private static float[] readVector(List<Float> vector) {
        float[] result = new float[vector.size()];
        for (int i = 0; i < vector.size(); i++) {
            result[i] = vector.get(i);
        }
        return result;
    }

    /** Returns a list of {@link #numDims} random floats. */
    private List<Float> randomVector() {
        assert numDims != 0 : "numDims must be set before calling randomVector()";
        List<Float> vector = new ArrayList<>(numDims);
        for (int j = 0; j < numDims; j++) {
            vector.add(randomFloat());
        }
        return vector;
    }

    /** Returns an array of {@link #numDims} random floats. */
    private float[] randomVectorArray() {
        assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
        float[] vector = new float[numDims];
        for (int j = 0; j < numDims; j++) {
            vector[j] = randomFloat();
        }
        return vector;
    }

    /**
     * Creates an index with an integer {@code id} field and the two dense_vector fields
     * {@code left_vector} and {@code right_vector}, using no replicas and between 1 and 5 shards.
     */
    private void createIndexWithDenseVector(String indexName) throws IOException {
        var indicesClient = client().admin().indices();
        XContentBuilder mapping = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject("id")
            .field("type", "integer")
            .endObject();
        createDenseVectorField(mapping, "left_vector");
        createDenseVectorField(mapping, "right_vector");
        mapping.endObject().endObject();
        Settings.Builder settingsBuilder = Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));

        // Settings are applied exactly once: an earlier setSettings() call would simply be
        // replaced by this one, so only the randomized settings are passed.
        var createRequest = indicesClient.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
        assertAcked(createRequest);
    }

    /** Appends a dense_vector field named {@code fieldName}, with cosine similarity, to the mapping. */
    private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
        mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
        mapping.endObject();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,9 @@ public enum Cap {
/**
* (Re)Added EXPLAIN command
*/
EXPLAIN(Build.current().isSnapshot());
EXPLAIN(Build.current().isSnapshot()),

COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());

private final boolean enabled;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1399,15 +1399,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func
if (f instanceof In in) {
return processIn(in);
}
if (f instanceof VectorFunction) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to change the order to ensure VectorFunction instances are processed first, as similarity functions are scalar functions as well.

return processVectorFunction(f);
}
if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
return processScalarOrGroupingFunction(f, registry);
}
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
return processBinaryOperator((BinaryOperator) f);
}
if (f instanceof VectorFunction vectorFunction) {
return processVectorFunction(f);
}
return f;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.evaluator.mapper.ExpressionMapper;
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
Expand Down Expand Up @@ -248,7 +249,11 @@ private static Block block(Literal lit, BlockFactory blockFactory, int positions
if (multiValue.isEmpty()) {
return blockFactory.newConstantNullBlock(positions);
}
var wrapper = BlockUtils.wrapperFor(blockFactory, ElementType.fromJava(multiValue.get(0).getClass()), positions);
// dense_vector creates float values internally, even if they are specified as doubles
ElementType elementType = lit.dataType() == DataType.DENSE_VECTOR
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this logic be in its own method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say no, as this is a one-liner for getting the correct ElementType - there's no more logic than doing a specific check for dense_vector. I'd say, if more special cases come into play then let's add it, as it will become confusing.

? ElementType.FLOAT
: ElementType.fromJava(multiValue.get(0).getClass());
var wrapper = BlockUtils.wrapperFor(blockFactory, elementType, positions);
for (int i = 0; i < positions; i++) {
wrapper.accept(multiValue);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package org.elasticsearch.xpack.esql.expression;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
Expand Down Expand Up @@ -85,7 +84,7 @@
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
Expand Down Expand Up @@ -259,9 +258,6 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
}

private static List<NamedWriteableRegistry.Entry> vector() {
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
return List.of(Knn.ENTRY);
}
return List.of();
return VectorWritables.getNamedWritables();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.session.Configuration;
Expand Down Expand Up @@ -487,7 +488,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
def(StGeohex.class, StGeohex::new, "st_geohex"),
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Loading