Skip to content

Commit 16542b5

Browse files
committed
Refactor
1 parent 31faca1 commit 16542b5

File tree

1 file changed

+27
-17
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector

1 file changed

+27
-17
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
2525
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
2626
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
27+
import org.elasticsearch.xpack.esql.expression.function.Param;
2728
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
2829
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
2930
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
@@ -54,19 +55,32 @@ public class CosineSimilarity extends EsqlScalarFunction implements VectorFuncti
5455
description = "Calculates the cosine similarity between two dense_vectors.",
5556
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
5657
)
57-
public CosineSimilarity(Source source, Expression left, Expression right) {
58+
public CosineSimilarity(
59+
Source source,
60+
@Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity")
61+
Expression left,
62+
@Param(name = "right", type = { "dense_vector" }, description = "second dense_vector to calculate cosine similarity")
63+
Expression right
64+
) {
5865
super(source, List.of(left, right));
5966
this.left = left;
6067
this.right = right;
6168
}
6269

70+
private CosineSimilarity(StreamInput in) throws IOException {
71+
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
72+
}
73+
6374
@Override
64-
public DataType dataType() {
65-
return DataType.DOUBLE;
75+
public void writeTo(StreamOutput out) throws IOException {
76+
source().writeTo(out);
77+
out.writeNamedWriteable(left());
78+
out.writeNamedWriteable(right());
6679
}
6780

68-
private CosineSimilarity(StreamInput in) throws IOException {
69-
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
81+
@Override
82+
public DataType dataType() {
83+
return DataType.DOUBLE;
7084
}
7185

7286
@Override
@@ -75,20 +89,13 @@ protected TypeResolution resolveType() {
7589
return new TypeResolution("Unresolved children");
7690
}
7791

78-
return checkParam(left()).and(checkParam(right()));
92+
return checkDenseVectorParam(left()).and(checkDenseVectorParam(right()));
7993
}
8094

81-
private TypeResolution checkParam(Expression param) {
95+
private TypeResolution checkDenseVectorParam(Expression param) {
8296
return isNotNull(param, sourceText(), FIRST).and(isType(param, dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
8397
}
8498

85-
@Override
86-
public void writeTo(StreamOutput out) throws IOException {
87-
source().writeTo(out);
88-
out.writeNamedWriteable(left());
89-
out.writeNamedWriteable(right());
90-
}
91-
9299
@Override
93100
public Expression replaceChildren(List<Expression> newChildren) {
94101
return new CosineSimilarity(source(), newChildren.get(0), newChildren.get(1));
@@ -178,15 +185,19 @@ public final Block eval(Page page) {
178185

179186
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
180187
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
181-
float result = VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
188+
float result = calculateSimilarity(leftScratch, rightScratch);
182189
builder.appendDouble(result);
183190
}
184191
return builder.build().asBlock();
185192
}
186193
}
187194
}
188195

189-
private void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
196+
private float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
197+
return VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
198+
}
199+
200+
private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
190201
for (int i = 0; i < dimensions; i++) {
191202
scratch[i] = block.getFloat(position + i);
192203
}
@@ -200,5 +211,4 @@ public final String toString() {
200211
@Override
201212
public void close() {}
202213
}
203-
204214
}

0 commit comments

Comments
 (0)