77
88package org .elasticsearch .xpack .esql .expression .function .vector ;
99
10- import org .apache .lucene .index .VectorSimilarityFunction ;
1110import org .elasticsearch .common .io .stream .NamedWriteableRegistry ;
1211import org .elasticsearch .common .io .stream .StreamInput ;
13- import org .elasticsearch .common .io .stream .StreamOutput ;
14- import org .elasticsearch .compute .data .Block ;
15- import org .elasticsearch .compute .data .DoubleVector ;
16- import org .elasticsearch .compute .data .FloatBlock ;
17- import org .elasticsearch .compute .data .Page ;
1812import org .elasticsearch .compute .operator .DriverContext ;
1913import org .elasticsearch .compute .operator .EvalOperator ;
2014import org .elasticsearch .xpack .esql .core .expression .Expression ;
2115import org .elasticsearch .xpack .esql .core .tree .NodeInfo ;
2216import org .elasticsearch .xpack .esql .core .tree .Source ;
23- import org .elasticsearch .xpack .esql .core .type .DataType ;
2417import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesTo ;
2518import org .elasticsearch .xpack .esql .expression .function .FunctionAppliesToLifecycle ;
2619import org .elasticsearch .xpack .esql .expression .function .FunctionInfo ;
2720import org .elasticsearch .xpack .esql .expression .function .Param ;
28- import org .elasticsearch .xpack .esql .expression .function .scalar .EsqlScalarFunction ;
2921import org .elasticsearch .xpack .esql .expression .function .scalar .math .Pow ;
30- import org .elasticsearch .xpack .esql .expression .function .scalar .multivalue .AbstractMultivalueFunction ;
3122import org .elasticsearch .xpack .esql .io .stream .PlanStreamInput ;
3223
3324import java .io .IOException ;
3425import java .util .List ;
3526
36- import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
37- import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isNotNull ;
38- import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .isType ;
39- import static org .elasticsearch .xpack .esql .core .type .DataType .DENSE_VECTOR ;
27+ import static org .apache .lucene .index .VectorSimilarityFunction .COSINE ;
4028
41- public class CosineSimilarity extends EsqlScalarFunction implements VectorFunction {
29+ public class CosineSimilarity extends org . elasticsearch . xpack . esql . expression . function . vector . VectorSimilarityFunction {
4230
4331 public static final NamedWriteableRegistry .Entry ENTRY = new NamedWriteableRegistry .Entry (
4432 Expression .class ,
4533 "CosineSimilarity" ,
4634 CosineSimilarity ::new
4735 );
4836
49- private Expression left ;
50- private Expression right ;
51-
5237 @ FunctionInfo (
5338 returnType = "double" ,
5439 preview = true ,
@@ -57,58 +42,20 @@ public class CosineSimilarity extends EsqlScalarFunction implements VectorFuncti
5742 )
5843 public CosineSimilarity (
5944 Source source ,
60- @ Param (name = "left" , type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity" )
61- Expression left ,
62- @ Param (name = "right" , type = { "dense_vector" }, description = "second dense_vector to calculate cosine similarity" )
63- Expression right
45+ @ Param (name = "left" , type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity" ) Expression left ,
46+ @ Param (
47+ name = "right" ,
48+ type = { "dense_vector" },
49+ description = "second dense_vector to calculate cosine similarity"
50+ ) Expression right
6451 ) {
65- super (source , List .of (left , right ));
66- this .left = left ;
67- this .right = right ;
52+ super (source , List .of (left , right ), left , right );
6853 }
6954
/**
 * Stream constructor: reads the source first, then the left and right operand expressions,
 * in that order, and delegates to the public constructor.
 *
 * @param in stream to read from; it is cast to {@link PlanStreamInput} to read the source
 * @throws IOException if reading from the stream fails
 */
private CosineSimilarity(StreamInput in) throws IOException {
    this(
        Source.readFrom((PlanStreamInput) in),
        in.readNamedWriteable(Expression.class),
        in.readNamedWriteable(Expression.class)
    );
}
7358
74- @ Override
75- public void writeTo (StreamOutput out ) throws IOException {
76- source ().writeTo (out );
77- out .writeNamedWriteable (left ());
78- out .writeNamedWriteable (right ());
79- }
80-
81- @ Override
82- public DataType dataType () {
83- return DataType .DOUBLE ;
84- }
85-
86- @ Override
87- protected TypeResolution resolveType () {
88- if (childrenResolved () == false ) {
89- return new TypeResolution ("Unresolved children" );
90- }
91-
92- return checkDenseVectorParam (left ()).and (checkDenseVectorParam (right ()));
93- }
94-
95- private TypeResolution checkDenseVectorParam (Expression param ) {
96- return isNotNull (param , sourceText (), FIRST ).and (isType (param , dt -> dt == DENSE_VECTOR , sourceText (), FIRST , "dense_vector" ));
97- }
98-
99- @ Override
100- public Expression replaceChildren (List <Expression > newChildren ) {
101- return new CosineSimilarity (source (), newChildren .get (0 ), newChildren .get (1 ));
102- }
103-
104- public Expression left () {
105- return left ;
106- }
107-
108- public Expression right () {
109- return right ;
110- }
111-
11259 @ Override
11360 protected NodeInfo <? extends Expression > info () {
11461 return NodeInfo .create (this , Pow ::new , left (), right ());
@@ -121,94 +68,22 @@ public String getWriteableName() {
12168
12269 @ Override
12370 public EvalOperator .ExpressionEvaluator .Factory toEvaluator (ToEvaluator toEvaluator ) {
124- return new EvaluatorFactory (toEvaluator .apply (left ()), toEvaluator .apply (right ()));
125- }
126-
127- private record EvaluatorFactory (EvalOperator .ExpressionEvaluator .Factory left , EvalOperator .ExpressionEvaluator .Factory right )
128- implements
129- EvalOperator .ExpressionEvaluator .Factory {
130- @ Override
131- public EvalOperator .ExpressionEvaluator get (DriverContext context ) {
132- return new Evaluator (context , left .get (context ), right .get (context ));
133- }
134-
135- @ Override
136- public String toString () {
137- return "CosineSimilarity[left=" + left + ", right=" + right + "]" ;
138- }
139- }
140-
141- /**
142- * Evaluator for {@link CosineSimilarity}. Not generated and doesn’t extend from
143- * {@link AbstractMultivalueFunction.AbstractEvaluator} because it’s different from {@link org.elasticsearch.compute.ann.MvEvaluator}
144- * or scalar evaluators.
145- *
146- * We can probably generalize to a common class or use its own annotation / evaluator template
147- */
148- private static class Evaluator implements EvalOperator .ExpressionEvaluator {
149- private final DriverContext context ;
150- private final EvalOperator .ExpressionEvaluator left ;
151- private final EvalOperator .ExpressionEvaluator right ;
152-
153- Evaluator (DriverContext context , EvalOperator .ExpressionEvaluator left , EvalOperator .ExpressionEvaluator right ) {
154- this .context = context ;
155- this .left = left ;
156- this .right = right ;
157- }
158-
159- @ Override
160- public final Block eval (Page page ) {
161- try (FloatBlock leftBlock = (FloatBlock ) left .eval (page ); FloatBlock rightBlock = (FloatBlock ) right .eval (page )) {
162- int positionCount = page .getPositionCount ();
163- if (positionCount == 0 ) {
164- return context .blockFactory ().newConstantFloatBlockWith (0F , 0 );
165- }
166-
167- int dimensions = leftBlock .getValueCount (0 );
168- int dimsRight = rightBlock .getValueCount (0 );
169- assert dimensions == dimsRight
170- : "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight ;
171- float [] leftScratch = new float [dimensions ];
172- float [] rightScratch = new float [dimensions ];
173- try (DoubleVector .Builder builder = context .blockFactory ().newDoubleVectorBuilder (positionCount * dimensions )) {
174- for (int p = 0 ; p < positionCount ; p ++) {
175- assert leftBlock .getValueCount (p ) == dimensions
176- : "Left vector must have the same value count for all positions, but got left: "
177- + leftBlock .getValueCount (p )
178- + ", expected: "
179- + dimensions ;
180- assert rightBlock .getValueCount (p ) == dimensions
181- : "Left vector must have the same value count for all positions, but got left: "
182- + rightBlock .getValueCount (p )
183- + ", expected: "
184- + dimensions ;
185-
186- readFloatArray (leftBlock , leftBlock .getFirstValueIndex (p ), dimensions , leftScratch );
187- readFloatArray (rightBlock , rightBlock .getFirstValueIndex (p ), dimensions , rightScratch );
188- float result = calculateSimilarity (leftScratch , rightScratch );
189- builder .appendDouble (result );
190- }
191- return builder .build ().asBlock ();
192- }
71+ return new VectorSimilarityFunction .SimilarityEvaluatorFactory (toEvaluator .apply (left ()), toEvaluator .apply (right ())) {
72+ @ Override
73+ protected SimilarityEvaluator getSimilarityEvaluator (DriverContext context ) {
74+ return new CosineSimilarityEvaluator (context , left .get (context ), right .get (context ));
19375 }
194- }
195-
196- private float calculateSimilarity (float [] leftScratch , float [] rightScratch ) {
197- return VectorSimilarityFunction .COSINE .compare (leftScratch , rightScratch );
198- }
76+ };
77+ }
19978
200- private static void readFloatArray (FloatBlock block , int position , int dimensions , float [] scratch ) {
201- for (int i = 0 ; i < dimensions ; i ++) {
202- scratch [i ] = block .getFloat (position + i );
203- }
79+ private static class CosineSimilarityEvaluator extends VectorSimilarityFunction .SimilarityEvaluator {
80+ CosineSimilarityEvaluator (DriverContext context , EvalOperator .ExpressionEvaluator left , EvalOperator .ExpressionEvaluator right ) {
81+ super (context , left , right );
20482 }
20583
20684 @ Override
207- public final String toString ( ) {
208- return "CosineSimilarity[left=" + left + ", right=" + right + "]" ;
85+ protected float calculateSimilarity ( float [] leftScratch , float [] rightScratch ) {
86+ return COSINE . compare ( leftScratch , rightScratch ) ;
20987 }
210-
211- @ Override
212- public void close () {}
21388 }
21489}
0 commit comments