77
88package org .elasticsearch .xpack .esql .expression .function .vector ;
99
10- import org .elasticsearch .common .io .stream .StreamOutput ;
10+ import org .elasticsearch .common .io .stream .StreamInput ;
1111import org .elasticsearch .compute .data .Block ;
1212import org .elasticsearch .compute .data .DoubleVector ;
1313import org .elasticsearch .compute .data .FloatBlock ;
1616import org .elasticsearch .compute .operator .EvalOperator ;
1717import org .elasticsearch .xpack .esql .EsqlClientException ;
1818import org .elasticsearch .xpack .esql .core .expression .Expression ;
19+ import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
1920import org .elasticsearch .xpack .esql .core .expression .TypeResolutions ;
21+ import org .elasticsearch .xpack .esql .core .expression .function .scalar .BinaryScalarFunction ;
2022import org .elasticsearch .xpack .esql .core .tree .Source ;
2123import org .elasticsearch .xpack .esql .core .type .DataType ;
22- import org .elasticsearch .xpack .esql .expression . function . scalar . EsqlScalarFunction ;
24+ import org .elasticsearch .xpack .esql .evaluator . mapper . EvaluatorMapper ;
2325
2426import java .io .IOException ;
25- import java .util .List ;
2627
2728import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
2829import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .SECOND ;
3334/**
3435 * Base class for vector similarity functions, which compute a similarity score between two dense vectors
3536 */
36- public abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction {
37-
38- private final Expression left ;
39- private final Expression right ;
37+ public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper , VectorFunction {
4038
4139 protected VectorSimilarityFunction (Source source , Expression left , Expression right ) {
42- super (source , List .of (left , right ));
43- this .left = left ;
44- this .right = right ;
40+ super (source , left , right );
4541 }
4642
47- @ Override
48- public void writeTo (StreamOutput out ) throws IOException {
49- source ().writeTo (out );
50- out .writeNamedWriteable (left ());
51- out .writeNamedWriteable (right ());
43+ protected VectorSimilarityFunction (StreamInput in ) throws IOException {
44+ super (in );
5245 }
5346
5447 @ Override
@@ -71,19 +64,6 @@ private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.P
7164 );
7265 }
7366
74- @ Override
75- public Expression replaceChildren (List <Expression > newChildren ) {
76- return new CosineSimilarity (source (), newChildren .get (0 ), newChildren .get (1 ));
77- }
78-
79- public Expression left () {
80- return left ;
81- }
82-
83- public Expression right () {
84- return right ;
85- }
86-
8767 /**
8868 * Functional interface for evaluating the similarity between two float arrays
8969 */
@@ -93,8 +73,18 @@ public interface SimilarityEvaluatorFunction {
9373 }
9474
9575 @ Override
96- public final EvalOperator .ExpressionEvaluator .Factory toEvaluator (ToEvaluator toEvaluator ) {
97- return new SimilarityEvaluatorFactory (toEvaluator .apply (left ()), toEvaluator .apply (right ()), getSimilarityFunction ());
76+ public Object fold (FoldContext ctx ) {
77+ return EvaluatorMapper .super .fold (source (), ctx );
78+ }
79+
80+ @ Override
81+ public final EvalOperator .ExpressionEvaluator .Factory toEvaluator (EvaluatorMapper .ToEvaluator toEvaluator ) {
82+ return new SimilarityEvaluatorFactory (
83+ toEvaluator .apply (left ()),
84+ toEvaluator .apply (right ()),
85+ getSimilarityFunction (),
86+ getClass ().getSimpleName () + "Evaluator"
87+ );
9888 }
9989
10090 /**
@@ -105,7 +95,8 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator to
10595 private record SimilarityEvaluatorFactory (
10696 EvalOperator .ExpressionEvaluator .Factory left ,
10797 EvalOperator .ExpressionEvaluator .Factory right ,
108- SimilarityEvaluatorFunction similarityFunction
98+ SimilarityEvaluatorFunction similarityFunction ,
99+ String evaluatorName
109100 ) implements EvalOperator .ExpressionEvaluator .Factory {
110101
111102 @ Override
@@ -119,34 +110,36 @@ public Block eval(Page page) {
119110 FloatBlock rightBlock = (FloatBlock ) right .get (context ).eval (page )
120111 ) {
121112 int positionCount = page .getPositionCount ();
122- if (positionCount == 0 ) {
113+ int dimensions = 0 ;
114+ // Get the first non-empty vector to calculate the dimension
115+ for (int p = 0 ; p < positionCount ; p ++) {
116+ if (leftBlock .getValueCount (p ) != 0 ) {
117+ dimensions = leftBlock .getValueCount (p );
118+ break ;
119+ }
120+ }
121+ if (dimensions == 0 ) {
123122 return context .blockFactory ().newConstantFloatBlockWith (0F , 0 );
124123 }
125124
126- int dimensions = leftBlock .getValueCount (0 );
127- int dimsRight = rightBlock .getValueCount (0 );
128- if (dimensions != dimsRight ) {
129- throw new EsqlClientException (
130- "Vectors must have the same dimensions; first vector has {}, and second has {}" ,
131- dimensions ,
132- dimsRight
133- );
134- }
135125 float [] leftScratch = new float [dimensions ];
136126 float [] rightScratch = new float [dimensions ];
137127 try (DoubleVector .Builder builder = context .blockFactory ().newDoubleVectorBuilder (positionCount * dimensions )) {
138128 for (int p = 0 ; p < positionCount ; p ++) {
139- assert leftBlock .getValueCount (p ) == dimensions
140- : "Left vector must have the same value count for all positions, but got left: "
141- + leftBlock .getValueCount (p )
142- + ", expected: "
143- + dimensions ;
144- assert rightBlock .getValueCount (p ) == dimensions
145- : "Left vector must have the same value count for all positions, but got left: "
146- + rightBlock .getValueCount (p )
147- + ", expected: "
148- + dimensions ;
149-
129+ int dimsLeft = leftBlock .getValueCount (p );
130+ int dimsRight = rightBlock .getValueCount (p );
131+
132+ if (dimsLeft == 0 || dimsRight == 0 ) {
133+ // A null value on the left or right vector. Similarity is 0
134+ builder .appendDouble (0.0 );
135+ continue ;
136+ } else if (dimsLeft != dimsRight ) {
137+ throw new EsqlClientException (
138+ "Vectors must have the same dimensions; first vector has {}, and second has {}" ,
139+ dimsLeft ,
140+ dimsRight
141+ );
142+ }
150143 readFloatArray (leftBlock , leftBlock .getFirstValueIndex (p ), dimensions , leftScratch );
151144 readFloatArray (rightBlock , rightBlock .getFirstValueIndex (p ), dimensions , rightScratch );
152145 float result = similarityFunction .calculateSimilarity (leftScratch , rightScratch );
@@ -157,13 +150,13 @@ public Block eval(Page page) {
157150 }
158151 }
159152
160- @ Override
161- public void close () {}
162-
163153 @ Override
164154 public String toString () {
165- return "ExpressionEvaluator [left=" + left + ", right=" + right + "]" ;
155+ return evaluatorName () + " [left=" + left + ", right=" + right + "]" ;
166156 }
157+
158+ @ Override
159+ public void close () {}
167160 };
168161 }
169162
@@ -172,5 +165,10 @@ private static void readFloatArray(FloatBlock block, int position, int dimension
172165 scratch [i ] = block .getFloat (position + i );
173166 }
174167 }
168+
169+ @ Override
170+ public String toString () {
171+ return evaluatorName () + "[left=" + left + ", right=" + right + "]" ;
172+ }
175173 }
176174}
0 commit comments