Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0c18eba
Add CosineSimilarity function
carlosdelest Jul 4, 2025
31faca1
Add IT
carlosdelest Jul 4, 2025
16542b5
Refactor
carlosdelest Jul 4, 2025
5098373
Extract superclass, use overridden method
carlosdelest Jul 4, 2025
b441d60
Use lambda instead of overridden methods
carlosdelest Jul 4, 2025
4508154
Refactoring
carlosdelest Jul 4, 2025
58bd1c0
Refactoring
carlosdelest Jul 4, 2025
fc88621
Allow using literals for dense_vectors, converting them internally to…
carlosdelest Jul 4, 2025
597741d
Add test
carlosdelest Jul 4, 2025
85e2426
Test non-null args
carlosdelest Jul 4, 2025
6cc2115
Rename to v_cosine, add analyzer tests for implicit casting
carlosdelest Jul 7, 2025
1a9b44c
First CSV tests
carlosdelest Jul 7, 2025
a53eab3
Add tests
carlosdelest Jul 7, 2025
8c80b72
Spotless
carlosdelest Jul 7, 2025
7517557
Add comment
carlosdelest Jul 7, 2025
58e6ac7
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 7, 2025
ebafdf8
[CI] Auto commit changes from spotless
Jul 7, 2025
3ab9a72
Add checks for different number of dimensions
carlosdelest Jul 7, 2025
b1d6f85
Merge remote-tracking branch 'carlosdelest/non-issue/esql-vector-sear…
carlosdelest Jul 7, 2025
4b0b772
Merge remote-tracking branch 'origin/main' into non-issue/esql-vector…
carlosdelest Jul 7, 2025
53e96f9
Add test infrastructure for VectorSimilarityFunction. Change VectorSi…
carlosdelest Jul 8, 2025
290dbe1
Ensure casting is done using floats so we get the appropriate blocks …
carlosdelest Jul 8, 2025
312a727
Generate docs for function
carlosdelest Jul 8, 2025
08364ee
Generate docs for function
carlosdelest Jul 8, 2025
fb3bec7
Fix tests
carlosdelest Jul 8, 2025
3d82d86
Remove unnecessary change to BlockUtils now that we create float eleme…
carlosdelest Jul 8, 2025
7601522
Remove unnecessary change to BlockUtils now that we create float eleme…
carlosdelest Jul 8, 2025
2a9e322
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 8, 2025
dd09bf8
Merge remote-tracking branch 'origin/main' into non-issue/esql-vector…
carlosdelest Jul 8, 2025
2400b7a
Fix telemetry test
carlosdelest Jul 8, 2025
2bdb35d
Merge remote-tracking branch 'carlosdelest/non-issue/esql-vector-sear…
carlosdelest Jul 8, 2025
81f26ec
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 9, 2025
707b3c2
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 9, 2025
7fcdb36
Merge branch 'main' into non-issue/esql-vector-search-functions-basics
carlosdelest Jul 15, 2025
5e9bcd2
[CI] Auto commit changes from spotless
Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.vector;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;

/**
 * Integration test for the ES|QL {@code v_cosine_similarity} function.
 * <p>
 * Before each test an index named {@code test} is created with two {@code dense_vector}
 * fields ({@code left_vector}, {@code right_vector}) and filled with random vectors. The
 * test then compares the value computed by ES|QL against Lucene's
 * {@link VectorSimilarityFunction#COSINE} applied to the same pair of vectors.
 */
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {

    /**
     * Evaluates {@code v_cosine_similarity} for every indexed document and checks each
     * result against the Lucene reference implementation, within a tolerance of 0.0001.
     */
    @SuppressWarnings("unchecked") // result cells are untyped Objects; vector columns are read back as lists
    public void testCosineSimilarity() {
        var query = """
            FROM test
            | EVAL similarity = v_cosine_similarity(left_vector, right_vector)
            | KEEP id, left_vector, right_vector, similarity
            """;

        try (var resp = run(query)) {
            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
            // Guard against a vacuous pass: setup() always indexes at least 10 documents.
            assertFalse("Expected at least one row in the result", valuesList.isEmpty());

            for (List<Object> values : valuesList) {
                // Column order follows the KEEP clause: id, left_vector, right_vector, similarity.
                float[] leftScratch = toFloatArray((List<Float>) values.get(1));
                float[] rightScratch = toFloatArray((List<Float>) values.get(2));
                Double similarity = (Double) values.get(3);
                assertNotNull(similarity);

                float expectedSimilarity = VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
                assertEquals(expectedSimilarity, similarity, 0.0001);
            }
        }
    }

    /**
     * Creates the {@code test} index and indexes a random number of documents, each with
     * two random vectors that share the same (randomly chosen, even) number of dimensions.
     * The test is skipped when either required capability is disabled.
     *
     * @throws IOException if building the index mapping fails
     */
    @Before
    public void setup() throws IOException {
        assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
        // The function under test is gated by its own capability, so check that one as well.
        assumeTrue(
            "Cosine vector similarity function is disabled",
            EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()
        );

        createIndexWithDenseVector("test");

        int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
        int numDocs = randomIntBetween(10, 100);
        IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
        for (int i = 0; i < numDocs; i++) {
            List<Float> leftVector = randomVector(numDims);
            List<Float> rightVector = randomVector(numDims);
            docs[i] = prepareIndex("test").setId("" + i)
                .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
        }

        indexRandom(true, docs);
    }

    /**
     * Creates an index with an integer {@code id} field and two {@code dense_vector} fields,
     * using zero replicas and a random number of shards between 1 and 5.
     *
     * @param indexName name of the index to create
     * @throws IOException if building the mapping fails
     */
    private void createIndexWithDenseVector(String indexName) throws IOException {
        var client = client().admin().indices();
        XContentBuilder mapping = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject("id")
            .field("type", "integer")
            .endObject();
        createDenseVectorField(mapping, "left_vector");
        createDenseVectorField(mapping, "right_vector");
        mapping.endObject().endObject();

        Settings.Builder settingsBuilder = Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));

        // Settings are applied exactly once; an earlier hard-coded single-shard setSettings call
        // was dropped because this call replaced it anyway.
        var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
        assertAcked(createRequest);
    }

    /**
     * Appends a {@code dense_vector} field definition with cosine similarity to the mapping.
     *
     * @param mapping   builder positioned inside the {@code properties} object
     * @param fieldName name of the field to add
     * @throws IOException if writing to the builder fails
     */
    private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
        mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
        mapping.endObject();
    }

    /** Builds a list of {@code numDims} random floats. */
    private static List<Float> randomVector(int numDims) {
        List<Float> vector = new ArrayList<>(numDims);
        for (int j = 0; j < numDims; j++) {
            vector.add(randomFloat());
        }
        return vector;
    }

    /** Copies a list of boxed floats into a primitive array, preserving order. */
    private static float[] toFloatArray(List<Float> vector) {
        float[] scratch = new float[vector.size()];
        for (int i = 0; i < scratch.length; i++) {
            scratch[i] = vector.get(i);
        }
        return scratch;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,9 @@ public enum Cap {
/**
* (Re)Added EXPLAIN command
*/
EXPLAIN(Build.current().isSnapshot());
EXPLAIN(Build.current().isSnapshot()),

COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());

private final boolean enabled;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package org.elasticsearch.xpack.esql.expression;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
Expand Down Expand Up @@ -85,7 +84,7 @@
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
Expand Down Expand Up @@ -259,9 +258,6 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
}

/**
 * Returns the named-writeable entries for the vector functions.
 * <p>
 * Delegates to {@code VectorWritables.getNamedWritables()}. The previous inline
 * {@code KNN_FUNCTION_V2} capability check and {@code List.of()} fallback are removed here:
 * left next to the new return they made the method uncompilable (unreachable statement).
 * NOTE(review): presumably the capability gating now lives in {@code VectorWritables} — confirm.
 */
private static List<NamedWriteableRegistry.Entry> vector() {
    return VectorWritables.getNamedWritables();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.session.Configuration;
Expand Down Expand Up @@ -487,7 +488,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
def(StGeohex.class, StGeohex::new, "st_geohex"),
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine_similarity") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.vector;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;

import java.io.IOException;

import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;

/**
 * Vector similarity function that computes the cosine similarity between two
 * {@code dense_vector} expressions, using Lucene's {@code COSINE} similarity for the
 * actual computation.
 */
public class CosineSimilarity extends VectorSimilarityFunction {

    /** Registry entry used to read this expression back from a stream. */
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
        Expression.class,
        "CosineSimilarity",
        CosineSimilarity::new
    );

    /**
     * Creates the function over two vector expressions.
     *
     * @param source location of the function call in the query
     * @param left   first dense_vector operand
     * @param right  second dense_vector operand
     */
    @FunctionInfo(
        returnType = "double",
        preview = true,
        description = "Calculates the cosine similarity between two dense_vectors.",
        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
    )
    public CosineSimilarity(
        Source source,
        @Param(
            name = "left",
            type = { "dense_vector" },
            description = "first dense_vector to calculate cosine similarity"
        ) Expression left,
        @Param(
            name = "right",
            type = { "dense_vector" },
            description = "second dense_vector to calculate cosine similarity"
        ) Expression right
    ) {
        super(source, left, right);
    }

    /**
     * Deserializing constructor: reads the source followed by the left and right
     * expressions, in that order.
     */
    private CosineSimilarity(StreamInput in) throws IOException {
        this(
            Source.readFrom((PlanStreamInput) in),
            in.readNamedWriteable(Expression.class),
            in.readNamedWriteable(Expression.class)
        );
    }

    /** Name under which this expression is registered, taken from {@link #ENTRY}. */
    @Override
    public String getWriteableName() {
        return ENTRY.name;
    }

    /** Describes this node as a constructor reference plus its two children. */
    @Override
    protected NodeInfo<? extends Expression> info() {
        return NodeInfo.create(this, CosineSimilarity::new, left(), right());
    }

    /** Supplies Lucene's cosine comparison as the similarity evaluator. */
    @Override
    protected SimilarityEvaluatorFunction getSimilarityFunction() {
        return COSINE::compare;
    }
}
Loading