Skip to content

Commit 4508154

Browse files
committed
Refactoring
1 parent b441d60 commit 4508154

File tree

2 files changed

+34
-43
lines changed

2 files changed

+34
-43
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,16 @@
99

1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
12-
import org.elasticsearch.compute.operator.EvalOperator;
1312
import org.elasticsearch.xpack.esql.core.expression.Expression;
1413
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1514
import org.elasticsearch.xpack.esql.core.tree.Source;
1615
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
1716
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
1817
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
1918
import org.elasticsearch.xpack.esql.expression.function.Param;
20-
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
2119
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2220

2321
import java.io.IOException;
24-
import java.util.List;
2522

2623
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
2724

@@ -48,29 +45,25 @@ public CosineSimilarity(
4845
description = "second dense_vector to calculate cosine similarity"
4946
) Expression right
5047
) {
51-
super(source, List.of(left, right), left, right);
48+
super(source, left, right);
5249
}
5350

5451
private CosineSimilarity(StreamInput in) throws IOException {
5552
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
5653
}
5754

5855
@Override
59-
protected NodeInfo<? extends Expression> info() {
60-
return NodeInfo.create(this, Pow::new, left(), right());
56+
protected SimilarityEvaluatorFunction getSimilarityFunction() {
57+
return COSINE::compare;
6158
}
6259

6360
@Override
64-
public String getWriteableName() {
65-
return ENTRY.name;
61+
protected NodeInfo<? extends Expression> info() {
62+
return NodeInfo.create(this, CosineSimilarity::new, left(), right());
6663
}
6764

6865
@Override
69-
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
70-
return new SimilarityEvaluatorFactory(
71-
toEvaluator.apply(left()),
72-
toEvaluator.apply(right()),
73-
(leftScratch, rightScratch) -> COSINE.compare(leftScratch, rightScratch)
74-
);
66+
public String getWriteableName() {
67+
return ENTRY.name;
7568
}
7669
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.xpack.esql.core.expression.Expression;
1818
import org.elasticsearch.xpack.esql.core.tree.Source;
1919
import org.elasticsearch.xpack.esql.core.type.DataType;
20-
import org.elasticsearch.xpack.esql.expression.function.Param;
2120
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
2221

2322
import java.io.IOException;
@@ -28,21 +27,16 @@
2827
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
2928
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
3029

31-
public abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction {
32-
protected Expression left;
33-
protected Expression right;
34-
35-
public VectorSimilarityFunction(
36-
Source source,
37-
List<Expression> fields,
38-
@Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left,
39-
@Param(
40-
name = "right",
41-
type = { "dense_vector" },
42-
description = "second dense_vector to calculate cosine similarity"
43-
) Expression right
44-
) {
45-
super(source, fields);
30+
/**
31+
* Base class for vector similarity functions, which compute a similarity score between two dense vectors
32+
*/
33+
abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction {
34+
35+
private final Expression left;
36+
private final Expression right;
37+
38+
protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
39+
super(source, List.of(left, right));
4640
this.left = left;
4741
this.right = right;
4842
}
@@ -85,29 +79,33 @@ public Expression right() {
8579
return right;
8680
}
8781

82+
/**
83+
* Functional interface for evaluating the similarity between two float arrays
84+
*/
8885
@FunctionalInterface
8986
public interface SimilarityEvaluatorFunction {
9087
float calculateSimilarity(float[] leftScratch, float[] rightScratch);
9188
}
9289

93-
protected class SimilarityEvaluatorFactory implements EvalOperator.ExpressionEvaluator.Factory {
90+
@Override
91+
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
92+
return new SimilarityEvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right()), getSimilarityFunction());
93+
}
9494

95-
private final EvalOperator.ExpressionEvaluator.Factory left;
96-
private final EvalOperator.ExpressionEvaluator.Factory right;
97-
private final SimilarityEvaluatorFunction similarityFunction;
95+
/**
96+
* Returns the similarity function to be used for evaluating the similarity between two vectors.
97+
*/
98+
protected abstract SimilarityEvaluatorFunction getSimilarityFunction();
9899

99-
SimilarityEvaluatorFactory(
100-
EvalOperator.ExpressionEvaluator.Factory left,
101-
EvalOperator.ExpressionEvaluator.Factory right,
102-
SimilarityEvaluatorFunction similarityFunction
103-
) {
104-
this.left = left;
105-
this.right = right;
106-
this.similarityFunction = similarityFunction;
107-
}
100+
private record SimilarityEvaluatorFactory(
101+
EvalOperator.ExpressionEvaluator.Factory left,
102+
EvalOperator.ExpressionEvaluator.Factory right,
103+
SimilarityEvaluatorFunction similarityFunction
104+
) implements EvalOperator.ExpressionEvaluator.Factory {
108105

109106
@Override
110107
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
108+
// TODO check whether to use this custom evaluator or reuse / define an existing one
111109
return new EvalOperator.ExpressionEvaluator() {
112110
@Override
113111
public Block eval(Page page) {

0 commit comments

Comments
 (0)