1616import org .elasticsearch .compute .data .Page ;
1717import org .elasticsearch .compute .operator .DriverContext ;
1818import org .elasticsearch .compute .operator .EvalOperator ;
19+ import org .elasticsearch .core .Releasable ;
1920import org .elasticsearch .core .Releasables ;
2021import org .elasticsearch .xpack .esql .EsqlClientException ;
2122import org .elasticsearch .xpack .esql .core .expression .Expression ;
2223import org .elasticsearch .xpack .esql .core .expression .FieldAttribute ;
2324import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
25+ import org .elasticsearch .xpack .esql .core .expression .Literal ;
2426import org .elasticsearch .xpack .esql .core .expression .TypeResolutions ;
2527import org .elasticsearch .xpack .esql .core .expression .function .scalar .BinaryScalarFunction ;
2628import org .elasticsearch .xpack .esql .core .tree .Source ;
2729import org .elasticsearch .xpack .esql .core .type .DataType ;
2830import org .elasticsearch .xpack .esql .evaluator .mapper .EvaluatorMapper ;
2931
3032import java .io .IOException ;
33+ import java .util .ArrayList ;
34+ import java .util .Arrays ;
35+ import java .util .List ;
3136
3237import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
3338import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .SECOND ;
@@ -79,23 +84,13 @@ public Object fold(FoldContext ctx) {
7984 return EvaluatorMapper .super .fold (source (), ctx );
8085 }
8186
82- @ Override
83- public final EvalOperator .ExpressionEvaluator .Factory toEvaluator (EvaluatorMapper .ToEvaluator toEvaluator ) {
84- return new SimilarityEvaluatorFactory (
85- toEvaluator .apply (left ()),
86- toEvaluator .apply (right ()),
87- getSimilarityFunction (),
88- getClass ().getSimpleName () + "Evaluator"
89- );
90- }
91-
9287 @ Override
9388 public PushableOptions pushableOptions () {
9489 // Can only push if one of the arguments is a literal, and the other a field name
9590 return left ().foldable () && left ().dataType () != NULL && right () instanceof FieldAttribute
9691 || right ().foldable () && right ().dataType () != NULL && left () instanceof FieldAttribute
97- ? PushableOptions .PREFERRED
98- : PushableOptions .NOT_SUPPORTED ;
92+ ? PushableOptions .PREFERRED
93+ : PushableOptions .NOT_SUPPORTED ;
9994 }
10095
10196 @ Override
@@ -107,14 +102,38 @@ public final String asScript() {
107102
108103 protected abstract String scriptFunctionName ();
109104
105+ @ Override
106+ public final EvalOperator .ExpressionEvaluator .Factory toEvaluator (EvaluatorMapper .ToEvaluator toEvaluator ) {
107+ VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory (left (), toEvaluator );
108+ VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory (right (), toEvaluator );
109+ return new SimilarityEvaluatorFactory (
110+ leftVectorProviderFactory ,
111+ rightVectorProviderFactory ,
112+ getSimilarityFunction (),
113+ getClass ().getSimpleName () + "Evaluator"
114+ );
115+ }
116+
117+ @ SuppressWarnings ("unchecked" )
118+ private static VectorValueProviderFactory getVectorValueProviderFactory (
119+ Expression expression ,
120+ EvaluatorMapper .ToEvaluator toEvaluator
121+ ) {
122+ if (expression instanceof Literal ) {
123+ return new ConstantVectorProvider .Factory ((ArrayList <Float >) ((Literal ) expression ).value ());
124+ } else {
125+ return new ExpressionVectorProvider .Factory (toEvaluator .apply (expression ));
126+ }
127+ }
128+
110129 /**
111130 * Returns the similarity function to be used for evaluating the similarity between two vectors.
112131 */
113132 protected abstract SimilarityEvaluatorFunction getSimilarityFunction ();
114133
115134 private record SimilarityEvaluatorFactory (
116- EvalOperator . ExpressionEvaluator . Factory left ,
117- EvalOperator . ExpressionEvaluator . Factory right ,
135+ VectorValueProviderFactory leftVectorProviderFactory ,
136+ VectorValueProviderFactory rightVectorProviderFactory ,
118137 SimilarityEvaluatorFunction similarityFunction ,
119138 String evaluatorName
120139 ) implements EvalOperator .ExpressionEvaluator .Factory {
@@ -123,8 +142,8 @@ private record SimilarityEvaluatorFactory(
123142 public EvalOperator .ExpressionEvaluator get (DriverContext context ) {
124143 // TODO check whether to use this custom evaluator or reuse / define an existing one
125144 return new SimilarityEvaluator (
126- left . get (context ),
127- right . get (context ),
145+ leftVectorProviderFactory . build (context ),
146+ rightVectorProviderFactory . build (context ),
128147 similarityFunction ,
129148 evaluatorName ,
130149 context .blockFactory ()
@@ -133,13 +152,13 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {
133152
134153 @ Override
135154 public String toString () {
136- return evaluatorName () + "[left=" + left + ", right=" + right + "]" ;
155+ return evaluatorName () + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]" ;
137156 }
138157 }
139158
140159 private record SimilarityEvaluator (
141- EvalOperator . ExpressionEvaluator left ,
142- EvalOperator . ExpressionEvaluator right ,
160+ VectorValueProvider left ,
161+ VectorValueProvider right ,
143162 SimilarityEvaluatorFunction similarityFunction ,
144163 String evaluatorName ,
145164 BlockFactory blockFactory
@@ -149,26 +168,23 @@ private record SimilarityEvaluator(
149168
150169 @ Override
151170 public Block eval (Page page ) {
152- try (FloatBlock leftBlock = (FloatBlock ) left .eval (page ); FloatBlock rightBlock = (FloatBlock ) right .eval (page )) {
171+ try {
172+ left .eval (page );
173+ right .eval (page );
174+
175+ int dimensions = left .getDimensions ();
153176 int positionCount = page .getPositionCount ();
154- int dimensions = 0 ;
155- // Get the first non-empty vector to calculate the dimension
156- for (int p = 0 ; p < positionCount ; p ++) {
157- if (leftBlock .getValueCount (p ) != 0 ) {
158- dimensions = leftBlock .getValueCount (p );
159- break ;
160- }
161- }
162177 if (dimensions == 0 ) {
163178 return blockFactory .newConstantNullBlock (positionCount );
164179 }
165180
166- float [] leftScratch = new float [dimensions ];
167- float [] rightScratch = new float [dimensions ];
168- try (DoubleBlock .Builder builder = blockFactory .newDoubleBlockBuilder (positionCount * dimensions )) {
181+ try (DoubleBlock .Builder builder = blockFactory .newDoubleBlockBuilder (positionCount )) {
169182 for (int p = 0 ; p < positionCount ; p ++) {
170- int dimsLeft = leftBlock .getValueCount (p );
171- int dimsRight = rightBlock .getValueCount (p );
183+ float [] leftVector = left .getVector (p );
184+ float [] rightVector = right .getVector (p );
185+
186+ int dimsLeft = leftVector == null ? 0 : leftVector .length ;
187+ int dimsRight = rightVector == null ? 0 : rightVector .length ;
172188
173189 if (dimsLeft == 0 || dimsRight == 0 ) {
174190 // A null value on the left or right vector. Similarity is null
@@ -181,19 +197,14 @@ public Block eval(Page page) {
181197 dimsRight
182198 );
183199 }
184- readFloatArray (leftBlock , leftBlock .getFirstValueIndex (p ), dimensions , leftScratch );
185- readFloatArray (rightBlock , rightBlock .getFirstValueIndex (p ), dimensions , rightScratch );
186- float result = similarityFunction .calculateSimilarity (leftScratch , rightScratch );
200+ float result = similarityFunction .calculateSimilarity (leftVector , rightVector );
187201 builder .appendDouble (result );
188202 }
189203 return builder .build ();
190204 }
191- }
192- }
193-
194- private static void readFloatArray (FloatBlock block , int position , int dimensions , float [] scratch ) {
195- for (int i = 0 ; i < dimensions ; i ++) {
196- scratch [i ] = block .getFloat (position + i );
205+ } finally {
206+ left .finish ();
207+ right .finish ();
197208 }
198209 }
199210
@@ -212,4 +223,171 @@ public void close() {
212223 Releasables .close (left , right );
213224 }
214225 }
226+
227+ interface VectorValueProvider extends Releasable {
228+
229+ void eval (Page page );
230+
231+ float [] getVector (int position );
232+
233+ int getDimensions ();
234+
235+ void finish ();
236+
237+ long baseRamBytesUsed ();
238+ }
239+
240+ interface VectorValueProviderFactory {
241+ VectorValueProvider build (DriverContext context );
242+ }
243+
244+ private static class ConstantVectorProvider implements VectorValueProvider {
245+
246+ record Factory (List <Float > vector ) implements VectorValueProviderFactory {
247+ public VectorValueProvider build (DriverContext context ) {
248+ return new ConstantVectorProvider (vector );
249+ }
250+
251+ @ Override
252+ public String toString () {
253+ return ConstantVectorProvider .class .getSimpleName () + "[vector=" + Arrays .toString (vector .toArray ()) + "]" ;
254+ }
255+ }
256+
257+ private static final long BASE_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ConstantVectorProvider .class );
258+
259+ private final float [] vector ;
260+
261+ ConstantVectorProvider (List <Float > vector ) {
262+ assert vector != null ;
263+ this .vector = new float [vector .size ()];
264+ for (int i = 0 ; i < vector .size (); i ++) {
265+ this .vector [i ] = vector .get (i );
266+ }
267+ }
268+
269+ @ Override
270+ public void eval (Page page ) {
271+ // no-op
272+ }
273+
274+ @ Override
275+ public float [] getVector (int position ) {
276+ return vector ;
277+ }
278+
279+ @ Override
280+ public int getDimensions () {
281+ return vector .length ;
282+ }
283+
284+ @ Override
285+ public void finish () {
286+ // no-op
287+ }
288+
289+ @ Override
290+ public void close () {
291+ // no-op
292+ }
293+
294+ @ Override
295+ public long baseRamBytesUsed () {
296+ return BASE_RAM_BYTES_USED + RamUsageEstimator .shallowSizeOf (vector );
297+ }
298+
299+ @ Override
300+ public String toString () {
301+ return this .getClass ().getSimpleName () + "[vector=" + Arrays .toString (vector ) + "]" ;
302+ }
303+ }
304+
305+ private static class ExpressionVectorProvider implements VectorValueProvider {
306+
307+ record Factory (EvalOperator .ExpressionEvaluator .Factory expressionEvaluatorFactory ) implements VectorValueProviderFactory {
308+ public VectorValueProvider build (DriverContext context ) {
309+ return new ExpressionVectorProvider (expressionEvaluatorFactory .get (context ));
310+ }
311+
312+ @ Override
313+ public String toString () {
314+ return ExpressionVectorProvider .class .getSimpleName () + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]" ;
315+ }
316+ }
317+
318+ private static final long BASE_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ExpressionVectorProvider .class );
319+
320+ private final EvalOperator .ExpressionEvaluator expressionEvaluator ;
321+ private FloatBlock block ;
322+ private float [] scratch ;
323+
324+ ExpressionVectorProvider (EvalOperator .ExpressionEvaluator expressionEvaluator ) {
325+ assert expressionEvaluator != null ;
326+ this .expressionEvaluator = expressionEvaluator ;
327+ }
328+
329+ @ Override
330+ public void eval (Page page ) {
331+ block = (FloatBlock ) expressionEvaluator .eval (page );
332+ }
333+
334+ @ Override
335+ public float [] getVector (int position ) {
336+ if (block .isNull (position )) {
337+ return null ;
338+ }
339+ if (scratch == null ) {
340+ int dims = block .getValueCount (position );
341+ if (dims > 0 ) {
342+ scratch = new float [dims ];
343+ }
344+ }
345+ if (scratch != null ) {
346+ readFloatArray (block , block .getFirstValueIndex (position ), scratch );
347+ }
348+ return scratch ;
349+ }
350+
351+ @ Override
352+ public int getDimensions () {
353+ for (int p = 0 ; p < block .getPositionCount (); p ++) {
354+ int dims = block .getValueCount (p );
355+ if (dims > 0 ) {
356+ return dims ;
357+ }
358+ }
359+ return 0 ;
360+ }
361+
362+ @ Override
363+ public void finish () {
364+ if (block != null ) {
365+ block .close ();
366+ block = null ;
367+ scratch = null ;
368+ }
369+ }
370+
371+ @ Override
372+ public long baseRamBytesUsed () {
373+ return BASE_RAM_BYTES_USED + expressionEvaluator .baseRamBytesUsed () + (block == null ? 0 : block .ramBytesUsed ())
374+ + (scratch == null ? 0 : RamUsageEstimator .shallowSizeOf (scratch ));
375+ }
376+
377+ @ Override
378+ public void close () {
379+ Releasables .close (expressionEvaluator );
380+ }
381+
382+ private static void readFloatArray (FloatBlock block , int firstValueIndex , float [] scratch ) {
383+ for (int i = 0 ; i < scratch .length ; i ++) {
384+ scratch [i ] = block .getFloat (firstValueIndex + i );
385+ }
386+ }
387+
388+ @ Override
389+ public String toString () {
390+ return this .getClass ().getSimpleName () + "[expressionEvaluator=[" + expressionEvaluator + "]]" ;
391+ }
392+ }
215393}
0 commit comments