Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Tests for l1_norm similarity function

similarityWithVectorField
required_capability: l1_norm_vector_similarity_function

// tag::vector-l1-norm-similarity[]
from colors
| eval similarity = v_l1_norm(rgb_vector, [0, 255, 255])
| sort similarity desc, color asc
// end::vector-l1-norm-similarity[]
| limit 10
| keep color, similarity
;

// tag::vector-l1-norm-similarity-result[]
color:text | similarity:double
red | 765.0
crimson | 650.0
maroon | 638.0
firebrick | 620.0
orange | 600.0
tomato | 595.0
brown | 591.0
chocolate | 585.0
coral | 558.0
gold | 550.0
// end::vector-l1-norm-similarity-result[]
;

similarityAsPartOfExpression
required_capability: l1_norm_vector_similarity_function

from colors
| eval score = round((1 + v_l1_norm(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
red | 383.5
crimson | 326.0
maroon | 320.0
firebrick | 311.0
orange | 301.0
tomato | 298.5
brown | 296.5
chocolate | 293.5
coral | 280.0
gold | 276.0
;

similarityWithLiteralVectors
required_capability: l1_norm_vector_similarity_function

row a = 1
| eval similarity = round(v_l1_norm([1, 2, 3], [0, 1, 2]), 3)
| keep similarity
;

similarity:double
3.0
;

similarityWithStats
required_capability: l1_norm_vector_similarity_function

from colors
| eval similarity = round(v_l1_norm(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
391.254 | 0.0 | 765.0
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: l1_norm_vector_similarity_function

row vector = [1, 2, 3]
| eval similarity = round(v_l1_norm(vector, [0, 1, 2]), 3)
| sort similarity desc, color asc
| limit 10
| keep color, similarity
;

similarity:double
0.978
;
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction;
import org.junit.Before;

import java.io.IOException;
Expand All @@ -37,22 +39,25 @@ public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();

if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
params.add(new Object[] { "v_cosine", (SimilarityEvaluatorFunction) VectorSimilarityFunction.COSINE::compare });
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

}
if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
params.add(new Object[] { "v_dot_product", VectorSimilarityFunction.DOT_PRODUCT });
params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare });
}
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity });
}

return params;
}

private final String functionName;
private final VectorSimilarityFunction similarityFunction;
private final SimilarityEvaluatorFunction similarityFunction;
private int numDims;

public VectorSimilarityFunctionsIT(
@Name("functionName") String functionName,
@Name("similarityFunction") VectorSimilarityFunction similarityFunction
@Name("similarityFunction") SimilarityEvaluatorFunction similarityFunction
) {
this.functionName = functionName;
this.similarityFunction = similarityFunction;
Expand All @@ -74,7 +79,7 @@ public void testSimilarityBetweenVectors() {
Double similarity = (Double) values.get(2);

assertNotNull(similarity);
float expectedSimilarity = similarityFunction.compare(left, right);
float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
assertEquals(expectedSimilarity, similarity, 0.0001);
});
}
Expand All @@ -96,7 +101,7 @@ public void testSimilarityBetweenConstantVectorAndField() {
Double similarity = (Double) values.get(1);

assertNotNull(similarity);
float expectedSimilarity = similarityFunction.compare(left, randomVector);
float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
assertEquals(expectedSimilarity, similarity, 0.0001);
});
}
Expand Down Expand Up @@ -130,7 +135,7 @@ public void testSimilarityBetweenConstantVectors() {

Double similarity = (Double) valuesList.get(0).get(0);
assertNotNull(similarity);
float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
assertEquals(expectedSimilarity, similarity, 0.0001);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,11 @@ public enum Cap {
*/
DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),

/**
* l1 norm vector similarity function
*/
L1_NORM_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),

/**
* Support for the options field of CATEGORIZE.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.session.Configuration;

Expand Down Expand Up @@ -493,7 +494,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"),
def(DotProduct.class, DotProduct::new, "v_dot_product") } };
def(DotProduct.class, DotProduct::new, "v_dot_product"),
def(L1Norm.class, L1Norm::new, "v_l1_norm") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.vector;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;

import java.io.IOException;

public class L1Norm extends VectorSimilarityFunction {

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L1Norm", L1Norm::new);
static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = L1Norm::calculateSimilarity;

@FunctionInfo(
returnType = "double",
preview = true,
description = "Calculates the l1 norm between two dense_vectors.",
examples = { @Example(file = "vector-l1-norm", tag = "vector-l1-norm") },
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
)
public L1Norm(
Source source,
@Param(
name = "left",
type = { "dense_vector" },
description = "first dense_vector to calculate l1 norm similarity"
) Expression left,
@Param(
name = "right",
type = { "dense_vector" },
description = "second dense_vector to calculate l1 norm similarity"
) Expression right
) {
super(source, left, right);
}

private L1Norm(StreamInput in) throws IOException {
super(in);
}

@Override
protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
return new L1Norm(source(), newLeft, newRight);
}

@Override
protected SimilarityEvaluatorFunction getSimilarityFunction() {
return SIMILARITY_FUNCTION;
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, L1Norm::new, left(), right());
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

public static float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
float result = 0f;
for (int i = 0; i < leftScratch.length; i++) {
result += Math.abs(leftScratch[i] - rightScratch[i]);
}
return result;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
entries.add(DotProduct.ENTRY);
}
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
entries.add(L1Norm.ENTRY);
}

return Collections.unmodifiableList(entries);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,10 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() {
);
checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
}
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
}
}

private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
Expand All @@ -2391,6 +2395,9 @@ public void testNoDenseVectorFailsSimilarityFunction() {
if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkNoDenseVectorFailsSimilarityFunction("v_dot_product([0, 1, 2], 0.342)");
}
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkNoDenseVectorFailsSimilarityFunction("v_l1_norm([0, 1, 2], 0.342)");
}
}

private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2508,6 +2508,10 @@ public void testVectorSimilarityFunctionsNullArgs() throws Exception {
checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)", "first");
checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)", "second");
}
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)", "first");
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)", "second");
}
}

private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {
Expand Down
Loading
Loading