diff --git a/docs/reference/query-languages/esql/images/functions/v_cosine.svg b/docs/reference/query-languages/esql/images/functions/v_cosine.svg new file mode 100644 index 0000000000000..fb7a2ed91fa8d --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/v_cosine.svg @@ -0,0 +1 @@ +V_COSINE(left,right) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json b/docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json new file mode 100644 index 0000000000000..f3b3df1d88c6a --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json @@ -0,0 +1,12 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "v_cosine", + "description" : "Calculates the cosine similarity between two dense_vectors.", + "signatures" : [ ], + "examples" : [ + " from colors\n | where color != \"black\"\n | eval similarity = v_cosine(rgb_vector, [0, 255, 255])\n | sort similarity desc, color asc" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md b/docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md new file mode 100644 index 0000000000000..22e4626fe38ad --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md @@ -0,0 +1,11 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### V COSINE +Calculates the cosine similarity between two dense_vectors. 
+ +```esql + from colors + | where color != "black" + | eval similarity = v_cosine(rgb_vector, [0, 255, 255]) + | sort similarity desc, color asc +``` diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec new file mode 100644 index 0000000000000..d9e1ff408c739 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec @@ -0,0 +1,93 @@ + # Tests for cosine similarity function + + similarityWithVectorField + required_capability: cosine_vector_similarity_function + +// tag::vector-cosine-similarity[] + from colors + | where color != "black" + | eval similarity = v_cosine(rgb_vector, [0, 255, 255]) + | sort similarity desc, color asc +// end::vector-cosine-similarity[] + | limit 10 + | keep color, similarity + ; + +// tag::vector-cosine-similarity-result[] +color:text | similarity:double +cyan | 1.0 +teal | 1.0 +turquoise | 0.9890533685684204 +aqua marine | 0.964962363243103 +azure | 0.916246771812439 +lavender | 0.9136701822280884 +mint cream | 0.9122757911682129 +honeydew | 0.9122424125671387 +gainsboro | 0.9082483053207397 +gray | 0.9082483053207397 +// end::vector-cosine-similarity-result[] +; + + similarityAsPartOfExpression + required_capability: cosine_vector_similarity_function + + from colors + | where color != "black" + | eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3) + | sort score desc, color asc + | limit 10 + | keep color, score + ; + +color:text | score:double +cyan | 1.5 +teal | 1.5 +turquoise | 1.495 +aqua marine | 1.482 +azure | 1.458 +lavender | 1.457 +honeydew | 1.456 +mint cream | 1.456 +gainsboro | 1.454 +gray | 1.454 +; + +similarityWithLiteralVectors +required_capability: cosine_vector_similarity_function + +row a = 1 +| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3) +| keep similarity +; + +similarity:double +0.978 +; + + 
similarityWithStats + required_capability: cosine_vector_similarity_function + + from colors + | where color != "black" + | eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3) + | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity) + ; + +avg:double | min:double | max:double +0.832 | 0.5 | 1.0 +; + +# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector +similarityWithRow-Ignore +required_capability: cosine_vector_similarity_function + +row vector = [1, 2, 3] +| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3) +| sort similarity desc +| limit 10 +| keep similarity +; + +similarity:double +0.978 +; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java new file mode 100644 index 0000000000000..6a861746facfd --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java @@ -0,0 +1,208 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.esql.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.esql.EsqlClientException;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
+import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+
+/**
+ * Integration tests for the ES|QL vector similarity functions. Each parameter set pairs an ES|QL function name
+ * with the Lucene similarity function used to compute the expected value.
+ */
+public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() throws Exception {
+        List<Object[]> params = new ArrayList<>();
+
+        params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
+
+        return params;
+    }
+
+    private final String functionName;
+    private final VectorSimilarityFunction similarityFunction;
+    // Dimension shared by every indexed vector; assigned in setup() before any document is generated.
+    private int numDims;
+
+    public VectorSimilarityFunctionsIT(
+        @Name("functionName") String functionName,
+        @Name("similarityFunction") VectorSimilarityFunction similarityFunction
+    ) {
+        this.functionName = functionName;
+        this.similarityFunction = similarityFunction;
+    }
+
+    /** Both arguments are dense_vector fields read from the index. */
+    @SuppressWarnings("unchecked")
+    public void testSimilarityBetweenVectors() {
+        var query = String.format(Locale.ROOT, """
+            FROM test
+            | EVAL similarity = %s(left_vector, right_vector)
+            | KEEP left_vector, right_vector, similarity
+            """, functionName);
+
+        try (var resp = run(query)) {
+            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
+            valuesList.forEach(values -> {
+                float[] left = readVector((List<Float>) values.get(0));
+                float[] right = readVector((List<Float>) values.get(1));
+                Double similarity = (Double) values.get(2);
+
+                assertNotNull(similarity);
+                float expectedSimilarity = similarityFunction.compare(left, right);
+                assertEquals(expectedSimilarity, similarity, 0.0001);
+            });
+        }
+    }
+
+    /** One argument is a field, the other a literal vector written into the query text. */
+    @SuppressWarnings("unchecked")
+    public void testSimilarityBetweenConstantVectorAndField() {
+        var randomVector = randomVectorArray();
+        var query = String.format(Locale.ROOT, """
+            FROM test
+            | EVAL similarity = %s(left_vector, %s)
+            | KEEP left_vector, similarity
+            """, functionName, Arrays.toString(randomVector));
+
+        try (var resp = run(query)) {
+            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
+            valuesList.forEach(values -> {
+                float[] left = readVector((List<Float>) values.get(0));
+                Double similarity = (Double) values.get(1);
+
+                assertNotNull(similarity);
+                float expectedSimilarity = similarityFunction.compare(left, randomVector);
+                assertEquals(expectedSimilarity, similarity, 0.0001);
+            });
+        }
+    }
+
+    /** A literal vector whose dimension differs from the indexed one must be rejected. */
+    public void testDifferentDimensions() {
+        var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
+        var query = String.format(Locale.ROOT, """
+            FROM test
+            | EVAL similarity = %s(left_vector, %s)
+            | KEEP left_vector, similarity
+            """, functionName, Arrays.toString(randomVector));
+
+        EsqlClientException exception = expectThrows(EsqlClientException.class, () -> { run(query); });
+        assertTrue(exception.getMessage().contains("Vectors must have the same dimensions"));
+    }
+
+    /** Both arguments are literal vectors, so no index data is involved. */
+    @SuppressWarnings("unchecked")
+    public void testSimilarityBetweenConstantVectors() {
+        var vectorLeft = randomVectorArray();
+        var vectorRight = randomVectorArray();
+        var query = String.format(Locale.ROOT, """
+            ROW a = 1
+            | EVAL similarity = %s(%s, %s)
+            | KEEP similarity
+            """, functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));
+
+        try (var resp = run(query)) {
+            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
+            assertEquals(1, valuesList.size());
+
+            Double similarity = (Double) valuesList.get(0).get(0);
+            assertNotNull(similarity);
+            float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
+            assertEquals(expectedSimilarity, similarity, 0.0001);
+        }
+    }
+
+    /** Copies a boxed vector from a result row into a primitive array. */
+    private static float[] readVector(List<Float> vector) {
+        float[] values = new float[vector.size()];
+        for (int i = 0; i < vector.size(); i++) {
+            values[i] = vector.get(i);
+        }
+        return values;
+    }
+
+    @Before
+    public void setup() throws IOException {
+        assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
+
+        createIndexWithDenseVector("test");
+
+        numDims = randomIntBetween(32, 64) * 2; // min 64, even number
+        int numDocs = randomIntBetween(10, 100);
+        IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
+        for (int i = 0; i < numDocs; i++) {
+            List<Float> leftVector = randomVector();
+            List<Float> rightVector = randomVector();
+            docs[i] = prepareIndex("test").setId("" + i)
+                .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
+        }
+
+        indexRandom(true, docs);
+    }
+
+    private List<Float> randomVector() {
+        assert numDims != 0 : "numDims must be set before calling randomVector()";
+        List<Float> vector = new ArrayList<>(numDims);
+        for (int j = 0; j < numDims; j++) {
+            vector.add(randomFloat());
+        }
+        return vector;
+    }
+
+    private float[] randomVectorArray() {
+        assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
+        return randomVectorArray(numDims);
+    }
+
+    private static float[] randomVectorArray(int dimensions) {
+        float[] vector = new float[dimensions];
+        for (int j = 0; j < dimensions; j++) {
+            vector[j] = randomFloat();
+        }
+        return vector;
+    }
+
+    /** Creates {@code indexName} with an integer {@code id} field and two cosine dense_vector fields. */
+    private void createIndexWithDenseVector(String indexName) throws IOException {
+        var client = client().admin().indices();
+        XContentBuilder mapping = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("properties")
+            .startObject("id")
+            .field("type", "integer")
+            .endObject();
+        createDenseVectorField(mapping, "left_vector");
+        createDenseVectorField(mapping, "right_vector");
+        mapping.endObject().endObject();
+        Settings.Builder settingsBuilder = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
+
+        // Settings are applied exactly once: an earlier single-shard setSettings call was overwritten by this one.
+        var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
+        assertAcked(createRequest);
+    }
+
+    private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
+        mapping.startObject(fieldName).field("type", "dense_vector").field("similarity", "cosine");
+        mapping.endObject();
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 4759579b94d24..e2ac8a4bef894 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1254,7 +1254,12 @@ public enum Cap {
          * Forbid usage of brackets in unquoted index and enrich policy names
          * https://github.com/elastic/elasticsearch/issues/130378
          */
-        NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES;
+        NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES,
+
+        /*
+         * Cosine vector similarity function
+         */
+        COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
 
         private final boolean enabled;
 
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index e4b8949af5bdb..0058a8691f818 100644
--- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -1400,15 +1400,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func if (f instanceof In in) { return processIn(in); } + if (f instanceof VectorFunction) { + return processVectorFunction(f); + } if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed return processScalarOrGroupingFunction(f, registry); } if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) { return processBinaryOperator((BinaryOperator) f); } - if (f instanceof VectorFunction vectorFunction) { - return processVectorFunction(f); - } return f; } @@ -1613,6 +1613,7 @@ private static Expression castStringLiteral(Expression from, DataType target) { } } + @SuppressWarnings("unchecked") private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) { List args = vectorFunction.arguments(); List newArgs = new ArrayList<>(); @@ -1620,7 +1621,14 @@ private static Expression processVectorFunction(org.elasticsearch.xpack.esql.cor if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) { Object folded = arg.fold(FoldContext.small() /* TODO remove me */); if (folded instanceof List) { - Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR); + // Convert to floats so blocks are created accordingly + List floatVector; + if (arg.dataType() == FLOAT) { + floatVector = (List) folded; + } else { + floatVector = ((List) folded).stream().map(Number::floatValue).collect(Collectors.toList()); + } + Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR); newArgs.add(denseVector); continue; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index a3f6d3a089d49..28181d11b6833 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; @@ -85,7 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; -import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull; @@ -259,9 +258,6 @@ private static List fullText() { } private static List vector() { - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { - return List.of(Knn.ENTRY); - } - return List.of(); + return VectorWritables.getNamedWritables(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index fd7f853eec089..f51ec914bf500 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -180,6 +180,7 @@
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
 import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
+import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
 import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
 import org.elasticsearch.xpack.esql.parser.ParsingException;
 import org.elasticsearch.xpack.esql.session.Configuration;
@@ -489,7 +490,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
                 def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
                 def(StGeohex.class, StGeohex::new, "st_geohex"),
                 def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
-                def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
+                def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
+                def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } };
     }
 
     public EsqlFunctionRegistry snapshotRegistry() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java
new file mode 100644
index 0000000000000..a86eb5633f729
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+
+/**
+ * The {@code v_cosine} function: scores two dense vectors by delegating to Lucene's
+ * {@code VectorSimilarityFunction.COSINE#compare}.
+ */
+public class CosineSimilarity extends VectorSimilarityFunction {
+
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        Expression.class,
+        "CosineSimilarity",
+        CosineSimilarity::new
+    );
+    /** Similarity computation handed to the evaluator built by the base class. */
+    static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = COSINE::compare;
+
+    @FunctionInfo(
+        returnType = "double",
+        preview = true,
+        description = "Calculates the cosine similarity between two dense_vectors.",
+        examples = { @Example(file = "vector-cosine-similarity", tag = "vector-cosine-similarity") },
+        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
+    )
+    public CosineSimilarity(
+        Source source,
+        @Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left,
+        @Param(
+            name = "right",
+            type = { "dense_vector" },
+            description = "second dense_vector to calculate cosine similarity"
+        ) Expression right
+    ) {
+        super(source, left, right);
+    }
+
+    /** Deserialization constructor referenced by {@link #ENTRY}. */
+    private CosineSimilarity(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
+        return new CosineSimilarity(source(), newLeft, newRight);
+    }
+
+    @Override
+    protected SimilarityEvaluatorFunction getSimilarityFunction() {
+        return SIMILARITY_FUNCTION;
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, CosineSimilarity::new, left(), right());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java
new file mode 100644
index 0000000000000..fc27ae2d876e8
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.xpack.esql.EsqlClientException;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
+
+import java.io.IOException;
+
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+
+/**
+ * Base class for vector similarity functions, which compute a similarity score between two dense vectors
+ */
+public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction {
+
+    protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
+        super(source, left, right);
+    }
+
+    protected VectorSimilarityFunction(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public DataType dataType() {
+        return DataType.DOUBLE;
+    }
+
+    @Override
+    protected TypeResolution resolveType() {
+        if (childrenResolved() == false) {
+            return new TypeResolution("Unresolved children");
+        }
+
+        return checkDenseVectorParam(left(), FIRST).and(checkDenseVectorParam(right(), SECOND));
+    }
+
+    /** Requires {@code param} to be a non-null expression of type dense_vector. */
+    private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
+        return isNotNull(param, sourceText(), paramOrdinal).and(
+            isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector")
+        );
+    }
+
+    /**
+     * Functional interface for evaluating the similarity between two float arrays
+     */
+    @FunctionalInterface
+    public interface SimilarityEvaluatorFunction {
+        float calculateSimilarity(float[] leftScratch, float[] rightScratch);
+    }
+
+    @Override
+    public Object fold(FoldContext ctx) {
+        return EvaluatorMapper.super.fold(source(), ctx);
+    }
+
+    @Override
+    public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
+        return new SimilarityEvaluatorFactory(
+            toEvaluator.apply(left()),
+            toEvaluator.apply(right()),
+            getSimilarityFunction(),
+            getClass().getSimpleName() + "Evaluator"
+        );
+    }
+
+    /**
+     * Returns the similarity function to be used for evaluating the similarity between two vectors.
+     */
+    protected abstract SimilarityEvaluatorFunction getSimilarityFunction();
+
+    /**
+     * Builds evaluators that emit one double per page position: the similarity of the two vectors at that
+     * position, or {@code 0.0} when either side has no values.
+     */
+    private record SimilarityEvaluatorFactory(
+        EvalOperator.ExpressionEvaluator.Factory left,
+        EvalOperator.ExpressionEvaluator.Factory right,
+        SimilarityEvaluatorFunction similarityFunction,
+        String evaluatorName
+    ) implements EvalOperator.ExpressionEvaluator.Factory {
+
+        @Override
+        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
+            // TODO check whether to use this custom evaluator or reuse / define an existing one
+            // Child evaluators are created once per evaluator and released in close(); creating them on every
+            // eval() call leaked them because nothing ever closed the per-page instances.
+            final EvalOperator.ExpressionEvaluator leftEvaluator = left.get(context);
+            final EvalOperator.ExpressionEvaluator rightEvaluator = right.get(context);
+            return new EvalOperator.ExpressionEvaluator() {
+                @Override
+                public Block eval(Page page) {
+                    try (
+                        FloatBlock leftBlock = (FloatBlock) leftEvaluator.eval(page);
+                        FloatBlock rightBlock = (FloatBlock) rightEvaluator.eval(page)
+                    ) {
+                        int positionCount = page.getPositionCount();
+                        // Scratch arrays are (re)allocated only when the vector dimension changes between rows.
+                        float[] leftScratch = new float[0];
+                        float[] rightScratch = new float[0];
+                        // Exactly one double is appended per position, including positions without vectors, so
+                        // the result always has positionCount entries.
+                        try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount)) {
+                            for (int p = 0; p < positionCount; p++) {
+                                int dimsLeft = leftBlock.getValueCount(p);
+                                int dimsRight = rightBlock.getValueCount(p);
+
+                                if (dimsLeft == 0 || dimsRight == 0) {
+                                    // A null value on the left or right vector. Similarity is 0
+                                    builder.appendDouble(0.0);
+                                    continue;
+                                } else if (dimsLeft != dimsRight) {
+                                    throw new EsqlClientException(
+                                        "Vectors must have the same dimensions; first vector has {}, and second has {}",
+                                        dimsLeft,
+                                        dimsRight
+                                    );
+                                }
+                                if (leftScratch.length != dimsLeft) {
+                                    leftScratch = new float[dimsLeft];
+                                    rightScratch = new float[dimsLeft];
+                                }
+                                readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimsLeft, leftScratch);
+                                readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimsLeft, rightScratch);
+                                float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
+                                builder.appendDouble(result);
+                            }
+                            return builder.build().asBlock();
+                        }
+                    }
+                }
+
+                @Override
+                public String toString() {
+                    return evaluatorName() + "[left=" + left + ", right=" + right + "]";
+                }
+
+                @Override
+                public void close() {
+                    // Release both children even if closing the first one fails.
+                    try {
+                        leftEvaluator.close();
+                    } finally {
+                        rightEvaluator.close();
+                    }
+                }
+            };
+        }
+
+        /** Copies {@code dimensions} consecutive floats of {@code block}, starting at value index {@code position}. */
+        private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
+            for (int i = 0; i < dimensions; i++) {
+                scratch[i] = block.getFloat(position + i);
+            }
+        }
+
+        @Override
+        public String toString() {
+            return evaluatorName() + "[left=" + left + ", right=" + right + "]";
+        }
+    }
+}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
new file mode 100644
index 0000000000000..f1bf291b7715e
--- /dev/null
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Defines the named writables for vector functions in ESQL.
+ */
+public final class VectorWritables {
+
+    private VectorWritables() {
+        // Utility class
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Returns the writable entries of the vector functions whose capability is enabled in this build.
+     *
+     * @return an unmodifiable list; empty when neither capability is enabled
+     */
+    public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+            entries.add(Knn.ENTRY);
+        }
+        if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            entries.add(CosineSimilarity.ENTRY);
+        }
+
+        return Collections.unmodifiableList(entries);
+    }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index b2521bddfb47b..439e10cce27d4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -57,6 +57,7 @@
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
 import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
@@ -92,6 +93,7 @@
 import java.time.Period;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Locale;
 import 
java.util.Map; import java.util.Set; import java.util.function.Function; @@ -123,6 +125,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED; @@ -2337,7 +2340,7 @@ public void testImplicitCasting() { assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]")); } - public void testDenseVectorImplicitCasting() { + public void testDenseVectorImplicitCastingKnn() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); @@ -2351,7 +2354,46 @@ public void testDenseVectorImplicitCasting() { var field = knn.field(); var queryVector = as(knn.query(), Literal.class); assertEquals(DataType.DENSE_VECTOR, queryVector.dataType()); - assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234))); + assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f))); + } + + public void testDenseVectorImplicitCastingSimilarityFunctions() { + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + } + } + + private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List expectedElems) { + var plan = analyze(String.format(Locale.ROOT, """ + from 
test | eval similarity = %s + """, similarityFunction), "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("similarity", alias.name()); + var similarity = as(alias.child(), VectorSimilarityFunction.class); + var left = as(similarity.left(), FieldAttribute.class); + assertEquals("vector", left.name()); + var right = as(similarity.right(), Literal.class); + assertThat(right.dataType(), is(DENSE_VECTOR)); + assertThat(right.value(), equalTo(expectedElems)); + } + + public void testNoDenseVectorFailsSimilarityFunction() { + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkNoDenseVectorFailsSimilarityFunction("v_cosine([0, 1, 2], 0.342)"); + } + } + + private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) { + var query = String.format(Locale.ROOT, "row a = 1 | eval similarity = %s", similarityFunction); + VerificationException error = expectThrows(VerificationException.class, () -> analyze(query)); + assertThat( + error.getMessage(), + containsString("second argument of [" + similarityFunction + "] must be" + " [dense_vector], found value [0.342] type [double]") + ); } public void testRateRequiresCounterTypes() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 407360f0bf5f2..5224199eb5277 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -2300,6 +2300,20 @@ private void checkFullTextFunctionsInStats(String functionInvocation) { ); } + public void testVectorSimilarityFunctionsNullArgs() throws Exception { + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + 
checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first"); + checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second"); + } + } + + private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception { + assertThat( + error("from test | eval similarity = " + functionInvocation, fullTextAnalyzer), + containsString(argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]") + ); + } + private void query(String query) { query(query, defaultAnalyzer); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java new file mode 100644 index 0000000000000..329eba63046f4 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import com.carrotsearch.randomizedtesting.annotations.Name; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.hamcrest.Matchers.equalTo; + +public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase { + + protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @Before + public void checkCapability() { + assumeTrue("Similarity function is not enabled", capability().isEnabled()); + } + + /** + * Get the capability of the vector similarity function to check + */ + protected abstract EsqlCapabilities.Cap capability(); + + protected static Iterable similarityParameters( + String className, + VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction + ) { + + final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]"; + + List suppliers = new ArrayList<>(); + + // Basic test with two dense vectors + suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DENSE_VECTOR), () -> { + int dimensions = between(64, 128); + List left = randomDenseVector(dimensions); + List right = randomDenseVector(dimensions); + float[] leftArray = listToFloatArray(left); + float[] rightArray = listToFloatArray(right); + double expected = similarityFunction.calculateSimilarity(leftArray, rightArray); + return new 
TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"), + new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2") + ), + evaluatorName, + DOUBLE, + equalTo(expected) // Random vectors should have cosine similarity close to 0 + ); + })); + + return parameterSuppliersFromTypedData(suppliers); + } + + private static float[] listToFloatArray(List floatList) { + float[] floatArray = new float[floatList.size()]; + for (int i = 0; i < floatList.size(); i++) { + floatArray[i] = floatList.get(i); + } + return floatArray; + } + + protected double calculateSimilarity(List left, List right) { + return 0; + } + + /** + * @return A random dense vector for testing + * @param dimensions + */ + private static List randomDenseVector(int dimensions) { + List vector = new ArrayList<>(); + for (int i = 0; i < dimensions; i++) { + vector.add(randomFloat()); + } + return vector; + } + + @Override + protected Matcher allNullsMatcher() { + // A null value on the left or right vector. Similarity is 0 + return equalTo(0.0); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java new file mode 100644 index 0000000000000..32ba95ee0af27 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.List; +import java.util.function.Supplier; + +@FunctionName("v_cosine") +public class CosineSimilarityTests extends AbstractVectorSimilarityFunctionTestCase { + + public CosineSimilarityTests(@Name("TestCase") Supplier testCaseSupplier) { + super(testCaseSupplier); + } + + @ParametersFactory + public static Iterable parameters() { + return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION); + } + + protected EsqlCapabilities.Cap capability() { + return EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION; + } + + @Override + protected Expression build(Source source, List args) { + return new CosineSimilarity(source, args.get(0), args.get(1)); + } +} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 72b518e2228ee..94f56c3c85367 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -41,6 +41,7 @@ setup: - sum_over_time - count_over_time - distinct_over_time + - cosine_vector_similarity_function reason: "Test that should only be executed on snapshot versions" - do: {xpack.usage: {}} @@ -130,7 +131,7 @@ setup: - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} # 
Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 156} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 157} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version":