Skip to content

Commit 53e96f9

Browse files
committed
Add test infrastructure for VectorSimilarityFunction. Change VectorSimilarityFunction to extend BinaryScalarFunction
1 parent 4b0b772 commit 53e96f9

File tree

4 files changed

+205
-59
lines changed

4 files changed

+205
-59
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.xpack.esql.core.expression.Expression;
13+
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
1314
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1415
import org.elasticsearch.xpack.esql.core.tree.Source;
1516
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
1617
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
1718
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
1819
import org.elasticsearch.xpack.esql.expression.function.Param;
19-
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2020

2121
import java.io.IOException;
2222

@@ -29,6 +29,7 @@ public class CosineSimilarity extends VectorSimilarityFunction {
2929
"CosineSimilarity",
3030
CosineSimilarity::new
3131
);
32+
static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = COSINE::compare;
3233

3334
@FunctionInfo(
3435
returnType = "double",
@@ -49,12 +50,17 @@ public CosineSimilarity(
4950
}
5051

5152
private CosineSimilarity(StreamInput in) throws IOException {
52-
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
53+
super(in);
54+
}
55+
56+
@Override
57+
protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
58+
return new CosineSimilarity(source(), newLeft, newRight);
5359
}
5460

5561
@Override
5662
protected SimilarityEvaluatorFunction getSimilarityFunction() {
57-
return COSINE::compare;
63+
return SIMILARITY_FUNCTION;
5864
}
5965

6066
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 54 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.vector;
99

10-
import org.elasticsearch.common.io.stream.StreamOutput;
10+
import org.elasticsearch.common.io.stream.StreamInput;
1111
import org.elasticsearch.compute.data.Block;
1212
import org.elasticsearch.compute.data.DoubleVector;
1313
import org.elasticsearch.compute.data.FloatBlock;
@@ -16,13 +16,14 @@
1616
import org.elasticsearch.compute.operator.EvalOperator;
1717
import org.elasticsearch.xpack.esql.EsqlClientException;
1818
import org.elasticsearch.xpack.esql.core.expression.Expression;
19+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1920
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
21+
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
2022
import org.elasticsearch.xpack.esql.core.tree.Source;
2123
import org.elasticsearch.xpack.esql.core.type.DataType;
22-
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
24+
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
2325

2426
import java.io.IOException;
25-
import java.util.List;
2627

2728
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
2829
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
@@ -33,22 +34,14 @@
3334
/**
3435
* Base class for vector similarity functions, which compute a similarity score between two dense vectors
3536
*/
36-
public abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction {
37-
38-
private final Expression left;
39-
private final Expression right;
37+
public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction {
4038

4139
protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
42-
super(source, List.of(left, right));
43-
this.left = left;
44-
this.right = right;
40+
super(source, left, right);
4541
}
4642

47-
@Override
48-
public void writeTo(StreamOutput out) throws IOException {
49-
source().writeTo(out);
50-
out.writeNamedWriteable(left());
51-
out.writeNamedWriteable(right());
43+
protected VectorSimilarityFunction(StreamInput in) throws IOException {
44+
super(in);
5245
}
5346

5447
@Override
@@ -71,19 +64,6 @@ private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.P
7164
);
7265
}
7366

74-
@Override
75-
public Expression replaceChildren(List<Expression> newChildren) {
76-
return new CosineSimilarity(source(), newChildren.get(0), newChildren.get(1));
77-
}
78-
79-
public Expression left() {
80-
return left;
81-
}
82-
83-
public Expression right() {
84-
return right;
85-
}
86-
8767
/**
8868
* Functional interface for evaluating the similarity between two float arrays
8969
*/
@@ -93,8 +73,18 @@ public interface SimilarityEvaluatorFunction {
9373
}
9474

9575
@Override
96-
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
97-
return new SimilarityEvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right()), getSimilarityFunction());
76+
public Object fold(FoldContext ctx) {
77+
return EvaluatorMapper.super.fold(source(), ctx);
78+
}
79+
80+
@Override
81+
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
82+
return new SimilarityEvaluatorFactory(
83+
toEvaluator.apply(left()),
84+
toEvaluator.apply(right()),
85+
getSimilarityFunction(),
86+
getClass().getSimpleName() + "Evaluator"
87+
);
9888
}
9989

10090
/**
@@ -105,7 +95,8 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator to
10595
private record SimilarityEvaluatorFactory(
10696
EvalOperator.ExpressionEvaluator.Factory left,
10797
EvalOperator.ExpressionEvaluator.Factory right,
108-
SimilarityEvaluatorFunction similarityFunction
98+
SimilarityEvaluatorFunction similarityFunction,
99+
String evaluatorName
109100
) implements EvalOperator.ExpressionEvaluator.Factory {
110101

111102
@Override
@@ -119,34 +110,36 @@ public Block eval(Page page) {
119110
FloatBlock rightBlock = (FloatBlock) right.get(context).eval(page)
120111
) {
121112
int positionCount = page.getPositionCount();
122-
if (positionCount == 0) {
113+
int dimensions = 0;
114+
// Get the first non-empty vector to calculate the dimension
115+
for (int p = 0; p < positionCount; p++) {
116+
if (leftBlock.getValueCount(p) != 0) {
117+
dimensions = leftBlock.getValueCount(p);
118+
break;
119+
}
120+
}
121+
if (dimensions == 0) {
123122
return context.blockFactory().newConstantFloatBlockWith(0F, 0);
124123
}
125124

126-
int dimensions = leftBlock.getValueCount(0);
127-
int dimsRight = rightBlock.getValueCount(0);
128-
if (dimensions != dimsRight) {
129-
throw new EsqlClientException(
130-
"Vectors must have the same dimensions; first vector has {}, and second has {}",
131-
dimensions,
132-
dimsRight
133-
);
134-
}
135125
float[] leftScratch = new float[dimensions];
136126
float[] rightScratch = new float[dimensions];
137127
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
138128
for (int p = 0; p < positionCount; p++) {
139-
assert leftBlock.getValueCount(p) == dimensions
140-
: "Left vector must have the same value count for all positions, but got left: "
141-
+ leftBlock.getValueCount(p)
142-
+ ", expected: "
143-
+ dimensions;
144-
assert rightBlock.getValueCount(p) == dimensions
145-
: "Left vector must have the same value count for all positions, but got left: "
146-
+ rightBlock.getValueCount(p)
147-
+ ", expected: "
148-
+ dimensions;
149-
129+
int dimsLeft = leftBlock.getValueCount(p);
130+
int dimsRight = rightBlock.getValueCount(p);
131+
132+
if (dimsLeft == 0 || dimsRight == 0) {
133+
// A null value on the left or right vector. Similarity is 0
134+
builder.appendDouble(0.0);
135+
continue;
136+
} else if (dimsLeft != dimsRight) {
137+
throw new EsqlClientException(
138+
"Vectors must have the same dimensions; first vector has {}, and second has {}",
139+
dimsLeft,
140+
dimsRight
141+
);
142+
}
150143
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
151144
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
152145
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
@@ -157,13 +150,13 @@ public Block eval(Page page) {
157150
}
158151
}
159152

160-
@Override
161-
public void close() {}
162-
163153
@Override
164154
public String toString() {
165-
return "ExpressionEvaluator[left=" + left + ", right=" + right + "]";
155+
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
166156
}
157+
158+
@Override
159+
public void close() {}
167160
};
168161
}
169162

@@ -172,5 +165,10 @@ private static void readFloatArray(FloatBlock block, int position, int dimension
172165
scratch[i] = block.getFloat(position + i);
173166
}
174167
}
168+
169+
@Override
170+
public String toString() {
171+
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
172+
}
175173
}
176174
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.vector;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.Name;
11+
12+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
13+
import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
14+
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
15+
import org.hamcrest.Matcher;
16+
import org.junit.Before;
17+
18+
import java.util.ArrayList;
19+
import java.util.List;
20+
import java.util.function.Supplier;
21+
22+
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
23+
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
24+
import static org.hamcrest.Matchers.equalTo;
25+
26+
public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase {
27+
28+
protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
29+
this.testCase = testCaseSupplier.get();
30+
}
31+
32+
@Before
33+
public void checkCapability() {
34+
assumeTrue("Similarity function is not enabled", capability().isEnabled());
35+
}
36+
37+
/**
38+
* Get the capability of the vector similarity function to check
39+
*/
40+
protected abstract EsqlCapabilities.Cap capability();
41+
42+
protected static Iterable<Object[]> similarityParameters(
43+
String className,
44+
VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction
45+
) {
46+
47+
final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]";
48+
49+
List<TestCaseSupplier> suppliers = new ArrayList<>();
50+
51+
// Basic test with two dense vectors
52+
suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DENSE_VECTOR), () -> {
53+
int dimensions = between(64, 128);
54+
List<Float> left = randomDenseVector(dimensions);
55+
List<Float> right = randomDenseVector(dimensions);
56+
float[] leftArray = listToFloatArray(left);
57+
float[] rightArray = listToFloatArray(right);
58+
double expected = similarityFunction.calculateSimilarity(leftArray, rightArray);
59+
return new TestCaseSupplier.TestCase(
60+
List.of(
61+
new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"),
62+
new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2")
63+
),
64+
evaluatorName,
65+
DOUBLE,
66+
equalTo(expected) // Random vectors should have cosine similarity close to 0
67+
);
68+
}));
69+
70+
return parameterSuppliersFromTypedData(suppliers);
71+
}
72+
73+
private static float[] listToFloatArray(List<Float> floatList) {
74+
float[] floatArray = new float[floatList.size()];
75+
for (int i = 0; i < floatList.size(); i++) {
76+
floatArray[i] = floatList.get(i);
77+
}
78+
return floatArray;
79+
}
80+
81+
protected double calculateSimilarity(List<Float> left, List<Float> right) {
82+
return 0;
83+
}
84+
85+
/**
86+
* @return A random dense vector for testing
87+
* @param dimensions
88+
*/
89+
private static List<Float> randomDenseVector(int dimensions) {
90+
List<Float> vector = new ArrayList<>();
91+
for (int i = 0; i < dimensions; i++) {
92+
vector.add(randomFloat());
93+
}
94+
return vector;
95+
}
96+
97+
@Override
98+
protected Matcher<Object> allNullsMatcher() {
99+
// A null value on the left or right vector. Similarity is 0
100+
return equalTo(0.0);
101+
}
102+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.vector;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.Name;
11+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
12+
13+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
14+
import org.elasticsearch.xpack.esql.core.expression.Expression;
15+
import org.elasticsearch.xpack.esql.core.tree.Source;
16+
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
17+
18+
import java.util.List;
19+
import java.util.function.Supplier;
20+
21+
public class CosineSimilarityTests extends AbstractVectorSimilarityFunctionTestCase {
22+
23+
public CosineSimilarityTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
24+
super(testCaseSupplier);
25+
}
26+
27+
@ParametersFactory
28+
public static Iterable<Object[]> parameters() {
29+
return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION);
30+
}
31+
32+
protected EsqlCapabilities.Cap capability() {
33+
return EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION;
34+
}
35+
36+
@Override
37+
protected Expression build(Source source, List<Expression> args) {
38+
return new CosineSimilarity(source, args.get(0), args.get(1));
39+
}
40+
}

0 commit comments

Comments
 (0)