2424import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesTo ;
2525import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesToLifecycle ;
2626import org .elasticsearch .xpack .esql .expression .function .FunctionInfo ;
27+ import org .elasticsearch .xpack .esql .expression .function .Param ;
2728import org .elasticsearch .xpack .esql .expression .function .scalar .EsqlScalarFunction ;
2829import org .elasticsearch .xpack .esql .expression .function .scalar .math .Pow ;
2930import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .AbstractMultivalueFunction ;
@@ -54,19 +55,32 @@ public class CosineSimilarity extends EsqlScalarFunction implements VectorFuncti
5455 description = "Calculates the cosine similarity between two dense_vectors." ,
5556 appliesTo = { @ FunctionAppliesTo (lifeCycle = FunctionAppliesToLifecycle .DEVELOPMENT ) }
5657 )
57- public CosineSimilarity (Source source , Expression left , Expression right ) {
58+ public CosineSimilarity (
59+ Source source ,
60+ @ Param (name = "left" , type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity" )
61+ Expression left ,
62+ @ Param (name = "right" , type = { "dense_vector" }, description = "second dense_vector to calculate cosine similarity" )
63+ Expression right
64+ ) {
5865 super (source , List .of (left , right ));
5966 this .left = left ;
6067 this .right = right ;
6168 }
6269
70+ private CosineSimilarity (StreamInput in ) throws IOException {
71+ this (Source .readFrom ((PlanStreamInput ) in ), in .readNamedWriteable (Expression .class ), in .readNamedWriteable (Expression .class ));
72+ }
73+
6374 @ Override
64- public DataType dataType () {
65- return DataType .DOUBLE ;
75+ public void writeTo (StreamOutput out ) throws IOException {
76+ source ().writeTo (out );
77+ out .writeNamedWriteable (left ());
78+ out .writeNamedWriteable (right ());
6679 }
6780
68- private CosineSimilarity (StreamInput in ) throws IOException {
69- this (Source .readFrom ((PlanStreamInput ) in ), in .readNamedWriteable (Expression .class ), in .readNamedWriteable (Expression .class ));
81+ @ Override
82+ public DataType dataType () {
83+ return DataType .DOUBLE ;
7084 }
7185
7286 @ Override
@@ -75,20 +89,13 @@ protected TypeResolution resolveType() {
7589 return new TypeResolution ("Unresolved children" );
7690 }
7791
78- return checkParam (left ()).and (checkParam (right ()));
92+ return checkDenseVectorParam (left ()).and (checkDenseVectorParam (right ()));
7993 }
8094
81- private TypeResolution checkParam (Expression param ) {
95+ private TypeResolution checkDenseVectorParam (Expression param ) {
8296 return isNotNull (param , sourceText (), FIRST ).and (isType (param , dt -> dt == DENSE_VECTOR , sourceText (), FIRST , "dense_vector" ));
8397 }
8498
85- @ Override
86- public void writeTo (StreamOutput out ) throws IOException {
87- source ().writeTo (out );
88- out .writeNamedWriteable (left ());
89- out .writeNamedWriteable (right ());
90- }
91-
9299 @ Override
93100 public Expression replaceChildren (List <Expression > newChildren ) {
94101 return new CosineSimilarity (source (), newChildren .get (0 ), newChildren .get (1 ));
@@ -178,15 +185,19 @@ public final Block eval(Page page) {
178185
179186 readFloatArray (leftBlock , leftBlock .getFirstValueIndex (p ), dimensions , leftScratch );
180187 readFloatArray (rightBlock , rightBlock .getFirstValueIndex (p ), dimensions , rightScratch );
181- float result = VectorSimilarityFunction . COSINE . compare (leftScratch , rightScratch );
188+ float result = calculateSimilarity (leftScratch , rightScratch );
182189 builder .appendDouble (result );
183190 }
184191 return builder .build ().asBlock ();
185192 }
186193 }
187194 }
188195
189- private void readFloatArray (FloatBlock block , int position , int dimensions , float [] scratch ) {
196+ private float calculateSimilarity (float [] leftScratch , float [] rightScratch ) {
197+ return VectorSimilarityFunction .COSINE .compare (leftScratch , rightScratch );
198+ }
199+
200+ private static void readFloatArray (FloatBlock block , int position , int dimensions , float [] scratch ) {
190201 for (int i = 0 ; i < dimensions ; i ++) {
191202 scratch [i ] = block .getFloat (position + i );
192203 }
@@ -200,5 +211,4 @@ public final String toString() {
200211 @ Override
201212 public void close () {}
202213 }
203-
204214}
0 commit comments