|
17 | 17 | import org.elasticsearch.xpack.esql.core.expression.Expression; |
18 | 18 | import org.elasticsearch.xpack.esql.core.tree.Source; |
19 | 19 | import org.elasticsearch.xpack.esql.core.type.DataType; |
20 | | -import org.elasticsearch.xpack.esql.expression.function.Param; |
21 | 20 | import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; |
22 | 21 |
|
23 | 22 | import java.io.IOException; |
|
28 | 27 | import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; |
29 | 28 | import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; |
30 | 29 |
|
31 | | -public abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction { |
32 | | - protected Expression left; |
33 | | - protected Expression right; |
34 | | - |
35 | | - public VectorSimilarityFunction( |
36 | | - Source source, |
37 | | - List<Expression> fields, |
38 | | - @Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left, |
39 | | - @Param( |
40 | | - name = "right", |
41 | | - type = { "dense_vector" }, |
42 | | - description = "second dense_vector to calculate cosine similarity" |
43 | | - ) Expression right |
44 | | - ) { |
45 | | - super(source, fields); |
| 30 | +/** |
| 31 | + * Base class for vector similarity functions, which compute a similarity score between two dense vectors |
| 32 | + */ |
| 33 | +abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction { |
| 34 | + |
| 35 | + private final Expression left; |
| 36 | + private final Expression right; |
| 37 | + |
| 38 | + protected VectorSimilarityFunction(Source source, Expression left, Expression right) { |
| 39 | + super(source, List.of(left, right)); |
46 | 40 | this.left = left; |
47 | 41 | this.right = right; |
48 | 42 | } |
@@ -85,29 +79,33 @@ public Expression right() { |
85 | 79 | return right; |
86 | 80 | } |
87 | 81 |
|
| 82 | + /** |
| 83 | + * Functional interface for evaluating the similarity between two float arrays |
| 84 | + */ |
88 | 85 | @FunctionalInterface |
89 | 86 | public interface SimilarityEvaluatorFunction { |
90 | 87 | float calculateSimilarity(float[] leftScratch, float[] rightScratch); |
91 | 88 | } |
92 | 89 |
|
93 | | - protected class SimilarityEvaluatorFactory implements EvalOperator.ExpressionEvaluator.Factory { |
| 90 | + @Override |
| 91 | + public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { |
| 92 | + return new SimilarityEvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right()), getSimilarityFunction()); |
| 93 | + } |
94 | 94 |
|
95 | | - private final EvalOperator.ExpressionEvaluator.Factory left; |
96 | | - private final EvalOperator.ExpressionEvaluator.Factory right; |
97 | | - private final SimilarityEvaluatorFunction similarityFunction; |
| 95 | + /** |
| 96 | + * Returns the similarity function to be used for evaluating the similarity between two vectors. |
| 97 | + */ |
| 98 | + protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); |
98 | 99 |
|
99 | | - SimilarityEvaluatorFactory( |
100 | | - EvalOperator.ExpressionEvaluator.Factory left, |
101 | | - EvalOperator.ExpressionEvaluator.Factory right, |
102 | | - SimilarityEvaluatorFunction similarityFunction |
103 | | - ) { |
104 | | - this.left = left; |
105 | | - this.right = right; |
106 | | - this.similarityFunction = similarityFunction; |
107 | | - } |
| 100 | + private record SimilarityEvaluatorFactory( |
| 101 | + EvalOperator.ExpressionEvaluator.Factory left, |
| 102 | + EvalOperator.ExpressionEvaluator.Factory right, |
| 103 | + SimilarityEvaluatorFunction similarityFunction |
| 104 | + ) implements EvalOperator.ExpressionEvaluator.Factory { |
108 | 105 |
|
109 | 106 | @Override |
110 | 107 | public EvalOperator.ExpressionEvaluator get(DriverContext context) { |
| 108 | + // TODO check whether to use this custom evaluator or reuse / define an existing one |
111 | 109 | return new EvalOperator.ExpressionEvaluator() { |
112 | 110 | @Override |
113 | 111 | public Block eval(Page page) { |
|
0 commit comments