Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Tests for l2_norm similarity function

similarityWithVectorField
required_capability: l2_norm_vector_similarity_function

// tag::vector-l2-norm[]
from colors
| eval similarity = v_l2_norm(rgb_vector, [0, 255, 255])
| sort similarity desc, color asc
// end::vector-l2-norm[]
| limit 10
| keep color, similarity
;

// tag::vector-l2-norm-result[]
color:text | similarity:double
red | 441.6729431152344
maroon | 382.6669616699219
crimson | 376.36419677734375
orange | 371.68536376953125
gold | 362.8360595703125
black | 360.62445068359375
magenta | 360.62445068359375
yellow | 360.62445068359375
firebrick | 359.67486572265625
tomato | 351.0227966308594
// end::vector-l2-norm-result[]
;

similarityAsPartOfExpression
required_capability: l2_norm_vector_similarity_function

from colors
| eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3)
| sort score desc, color asc
| limit 10
| keep color, score
;

color:text | score:double
red | 221.836
maroon | 192.333
crimson | 189.182
orange | 186.843
gold | 182.418
black | 181.312
magenta | 181.312
yellow | 181.312
firebrick | 180.837
tomato | 176.511
;

similarityWithLiteralVectors
required_capability: l2_norm_vector_similarity_function

row a = 1
| eval similarity = round(v_l2_norm([1, 2, 3], [0, 1, 2]), 3)
| keep similarity
;

similarity:double
1.732
;

similarityWithStats
required_capability: l2_norm_vector_similarity_function

from colors
| eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3)
| stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
;

avg:double | min:double | max:double
274.974 | 0.0 | 441.673
;

# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
similarityWithRow-Ignore
required_capability: l2_norm_vector_similarity_function

row vector = [1, 2, 3]
| eval similarity = round(v_l2_norm(vector, [0, 1, 2]), 3)
| sort similarity desc, color asc
| limit 10
| keep color, similarity
;

similarity:double
0.978
;
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction;
import org.junit.Before;

Expand All @@ -47,6 +48,9 @@ public static Iterable<Object[]> parameters() throws Exception {
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity });
}
if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity });
}

return params;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,11 @@ public enum Cap {
*/
L1_NORM_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),

/**
* l2 norm vector similarity function
*/
L2_NORM_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),

/**
* Support for the options field of CATEGORIZE.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@
import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.session.Configuration;

Expand Down Expand Up @@ -495,7 +496,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"),
def(DotProduct.class, DotProduct::new, "v_dot_product"),
def(L1Norm.class, L1Norm::new, "v_l1_norm") } };
def(L1Norm.class, L1Norm::new, "v_l1_norm"),
def(L2Norm.class, L2Norm::new, "v_l2_norm") } };
}

public EsqlFunctionRegistry snapshotRegistry() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.vector;

import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;

import java.io.IOException;

public class L2Norm extends VectorSimilarityFunction {

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L2Norm", L2Norm::new);
static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = L2Norm::calculateSimilarity;

@FunctionInfo(
returnType = "double",
preview = true,
description = "Calculates the l2 norm between two dense_vectors.",
examples = { @Example(file = "vector-l2-norm", tag = "vector-l2-norm") },
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
)
public L2Norm(
Source source,
@Param(
name = "left",
type = { "dense_vector" },
description = "first dense_vector to calculate l2 norm similarity"
) Expression left,
@Param(
name = "right",
type = { "dense_vector" },
description = "second dense_vector to calculate l2 norm similarity"
) Expression right
) {
super(source, left, right);
}

private L2Norm(StreamInput in) throws IOException {
super(in);
}

@Override
protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
return new L2Norm(source(), newLeft, newRight);
}

@Override
protected SimilarityEvaluatorFunction getSimilarityFunction() {
return SIMILARITY_FUNCTION;
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, L2Norm::new, left(), right());
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

public static float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
return (float) Math.sqrt(VectorUtil.squareDistance(leftScratch, rightScratch));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
entries.add(L1Norm.ENTRY);
}
if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
entries.add(L2Norm.ENTRY);
}

return Collections.unmodifiableList(entries);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2369,6 +2369,10 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() {
checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
}
if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
}
}

private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
Expand Down Expand Up @@ -2398,6 +2402,9 @@ public void testNoDenseVectorFailsSimilarityFunction() {
if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkNoDenseVectorFailsSimilarityFunction("v_l1_norm([0, 1, 2], 0.342)");
}
if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkNoDenseVectorFailsSimilarityFunction("v_l2_norm([0, 1, 2], 0.342)");
}
}

private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,10 @@ public void testVectorSimilarityFunctionsNullArgs() throws Exception {
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)", "first");
checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)", "second");
}
if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)", "first");
checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)", "second");
}
}

private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {
Expand Down
Loading
Loading