1616import org .elasticsearch .compute .data .Page ;
1717import org .elasticsearch .compute .operator .DriverContext ;
1818import org .elasticsearch .compute .operator .EvalOperator ;
19+ import org .elasticsearch .core .Releasable ;
1920import org .elasticsearch .core .Releasables ;
2021import org .elasticsearch .xpack .esql .EsqlClientException ;
2122import org .elasticsearch .xpack .esql .core .expression .Expression ;
2223import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
24+ import org .elasticsearch .xpack .esql .core .expression .Literal ;
2325import org .elasticsearch .xpack .esql .core .expression .TypeResolutions ;
2426import org .elasticsearch .xpack .esql .core .expression .function .scalar .BinaryScalarFunction ;
2527import org .elasticsearch .xpack .esql .core .tree .Source ;
2628import org .elasticsearch .xpack .esql .core .type .DataType ;
2729import org .elasticsearch .xpack .esql .evaluator .mapper .EvaluatorMapper ;
2830
2931import java .io .IOException ;
32+ import java .util .ArrayList ;
33+ import java .util .Arrays ;
34+ import java .util .List ;
3035
3136import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
3237import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .SECOND ;
@@ -79,22 +84,36 @@ public Object fold(FoldContext ctx) {
7984
8085 @ Override
8186 public final EvalOperator .ExpressionEvaluator .Factory toEvaluator (EvaluatorMapper .ToEvaluator toEvaluator ) {
87+ VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory (left (), toEvaluator );
88+ VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory (right (), toEvaluator );
8289 return new SimilarityEvaluatorFactory (
83- toEvaluator . apply ( left ()) ,
84- toEvaluator . apply ( right ()) ,
90+ leftVectorProviderFactory ,
91+ rightVectorProviderFactory ,
8592 getSimilarityFunction (),
8693 getClass ().getSimpleName () + "Evaluator"
8794 );
8895 }
8996
97+ @ SuppressWarnings ("unchecked" )
98+ private static VectorValueProviderFactory getVectorValueProviderFactory (
99+ Expression expression ,
100+ EvaluatorMapper .ToEvaluator toEvaluator
101+ ) {
102+ if (expression instanceof Literal ) {
103+ return new ConstantVectorProvider .Factory ((ArrayList <Float >) ((Literal ) expression ).value ());
104+ } else {
105+ return new ExpressionVectorProvider .Factory (toEvaluator .apply (expression ));
106+ }
107+ }
108+
90109 /**
91110 * Returns the similarity function to be used for evaluating the similarity between two vectors.
92111 */
93112 protected abstract SimilarityEvaluatorFunction getSimilarityFunction ();
94113
95114 private record SimilarityEvaluatorFactory (
96- EvalOperator . ExpressionEvaluator . Factory left ,
97- EvalOperator . ExpressionEvaluator . Factory right ,
115+ VectorValueProviderFactory leftVectorProviderFactory ,
116+ VectorValueProviderFactory rightVectorProviderFactory ,
98117 SimilarityEvaluatorFunction similarityFunction ,
99118 String evaluatorName
100119 ) implements EvalOperator .ExpressionEvaluator .Factory {
@@ -103,8 +122,8 @@ private record SimilarityEvaluatorFactory(
103122 public EvalOperator .ExpressionEvaluator get (DriverContext context ) {
104123 // TODO check whether to use this custom evaluator or reuse / define an existing one
105124 return new SimilarityEvaluator (
106- left . get (context ),
107- right . get (context ),
125+ leftVectorProviderFactory . build (context ),
126+ rightVectorProviderFactory . build (context ),
108127 similarityFunction ,
109128 evaluatorName ,
110129 context .blockFactory ()
@@ -113,13 +132,13 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {
113132
114133 @ Override
115134 public String toString () {
116- return evaluatorName () + "[left=" + left + ", right=" + right + "]" ;
135+ return evaluatorName () + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]" ;
117136 }
118137 }
119138
120139 private record SimilarityEvaluator (
121- EvalOperator . ExpressionEvaluator left ,
122- EvalOperator . ExpressionEvaluator right ,
140+ VectorValueProvider left ,
141+ VectorValueProvider right ,
123142 SimilarityEvaluatorFunction similarityFunction ,
124143 String evaluatorName ,
125144 BlockFactory blockFactory
@@ -129,26 +148,23 @@ private record SimilarityEvaluator(
129148
130149 @ Override
131150 public Block eval (Page page ) {
132- try (FloatBlock leftBlock = (FloatBlock ) left .eval (page ); FloatBlock rightBlock = (FloatBlock ) right .eval (page )) {
151+ try {
152+ left .eval (page );
153+ right .eval (page );
154+
155+ int dimensions = left .getDimensions ();
133156 int positionCount = page .getPositionCount ();
134- int dimensions = 0 ;
135- // Get the first non-empty vector to calculate the dimension
136- for (int p = 0 ; p < positionCount ; p ++) {
137- if (leftBlock .getValueCount (p ) != 0 ) {
138- dimensions = leftBlock .getValueCount (p );
139- break ;
140- }
141- }
142157 if (dimensions == 0 ) {
143158 return blockFactory .newConstantNullBlock (positionCount );
144159 }
145160
146- float [] leftScratch = new float [dimensions ];
147- float [] rightScratch = new float [dimensions ];
148- try (DoubleBlock .Builder builder = blockFactory .newDoubleBlockBuilder (positionCount * dimensions )) {
161+ try (DoubleBlock .Builder builder = blockFactory .newDoubleBlockBuilder (positionCount )) {
149162 for (int p = 0 ; p < positionCount ; p ++) {
150- int dimsLeft = leftBlock .getValueCount (p );
151- int dimsRight = rightBlock .getValueCount (p );
163+ float [] leftVector = left .getVector (p );
164+ float [] rightVector = right .getVector (p );
165+
166+ int dimsLeft = leftVector == null ? 0 : leftVector .length ;
167+ int dimsRight = rightVector == null ? 0 : rightVector .length ;
152168
153169 if (dimsLeft == 0 || dimsRight == 0 ) {
154170 // A null value on the left or right vector. Similarity is null
@@ -161,19 +177,14 @@ public Block eval(Page page) {
161177 dimsRight
162178 );
163179 }
164- readFloatArray (leftBlock , leftBlock .getFirstValueIndex (p ), dimensions , leftScratch );
165- readFloatArray (rightBlock , rightBlock .getFirstValueIndex (p ), dimensions , rightScratch );
166- float result = similarityFunction .calculateSimilarity (leftScratch , rightScratch );
180+ float result = similarityFunction .calculateSimilarity (leftVector , rightVector );
167181 builder .appendDouble (result );
168182 }
169183 return builder .build ();
170184 }
171- }
172- }
173-
174- private static void readFloatArray (FloatBlock block , int position , int dimensions , float [] scratch ) {
175- for (int i = 0 ; i < dimensions ; i ++) {
176- scratch [i ] = block .getFloat (position + i );
185+ } finally {
186+ left .finish ();
187+ right .finish ();
177188 }
178189 }
179190
@@ -192,4 +203,171 @@ public void close() {
192203 Releasables .close (left , right );
193204 }
194205 }
206+
207+ interface VectorValueProvider extends Releasable {
208+
209+ void eval (Page page );
210+
211+ float [] getVector (int position );
212+
213+ int getDimensions ();
214+
215+ void finish ();
216+
217+ long baseRamBytesUsed ();
218+ }
219+
220+ interface VectorValueProviderFactory {
221+ VectorValueProvider build (DriverContext context );
222+ }
223+
224+ private static class ConstantVectorProvider implements VectorValueProvider {
225+
226+ record Factory (List <Float > vector ) implements VectorValueProviderFactory {
227+ public VectorValueProvider build (DriverContext context ) {
228+ return new ConstantVectorProvider (vector );
229+ }
230+
231+ @ Override
232+ public String toString () {
233+ return ConstantVectorProvider .class .getSimpleName () + "[vector=" + Arrays .toString (vector .toArray ()) + "]" ;
234+ }
235+ }
236+
237+ private static final long BASE_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ConstantVectorProvider .class );
238+
239+ private final float [] vector ;
240+
241+ ConstantVectorProvider (List <Float > vector ) {
242+ assert vector != null ;
243+ this .vector = new float [vector .size ()];
244+ for (int i = 0 ; i < vector .size (); i ++) {
245+ this .vector [i ] = vector .get (i );
246+ }
247+ }
248+
249+ @ Override
250+ public void eval (Page page ) {
251+ // no-op
252+ }
253+
254+ @ Override
255+ public float [] getVector (int position ) {
256+ return vector ;
257+ }
258+
259+ @ Override
260+ public int getDimensions () {
261+ return vector .length ;
262+ }
263+
264+ @ Override
265+ public void finish () {
266+ // no-op
267+ }
268+
269+ @ Override
270+ public void close () {
271+ // no-op
272+ }
273+
274+ @ Override
275+ public long baseRamBytesUsed () {
276+ return BASE_RAM_BYTES_USED + RamUsageEstimator .shallowSizeOf (vector );
277+ }
278+
279+ @ Override
280+ public String toString () {
281+ return this .getClass ().getSimpleName () + "[vector=" + Arrays .toString (vector ) + "]" ;
282+ }
283+ }
284+
285+ private static class ExpressionVectorProvider implements VectorValueProvider {
286+
287+ record Factory (EvalOperator .ExpressionEvaluator .Factory expressionEvaluatorFactory ) implements VectorValueProviderFactory {
288+ public VectorValueProvider build (DriverContext context ) {
289+ return new ExpressionVectorProvider (expressionEvaluatorFactory .get (context ));
290+ }
291+
292+ @ Override
293+ public String toString () {
294+ return ExpressionVectorProvider .class .getSimpleName () + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]" ;
295+ }
296+ }
297+
298+ private static final long BASE_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ExpressionVectorProvider .class );
299+
300+ private final EvalOperator .ExpressionEvaluator expressionEvaluator ;
301+ private FloatBlock block ;
302+ private float [] scratch ;
303+
304+ ExpressionVectorProvider (EvalOperator .ExpressionEvaluator expressionEvaluator ) {
305+ assert expressionEvaluator != null ;
306+ this .expressionEvaluator = expressionEvaluator ;
307+ }
308+
309+ @ Override
310+ public void eval (Page page ) {
311+ block = (FloatBlock ) expressionEvaluator .eval (page );
312+ }
313+
314+ @ Override
315+ public float [] getVector (int position ) {
316+ if (block .isNull (position )) {
317+ return null ;
318+ }
319+ if (scratch == null ) {
320+ int dims = block .getValueCount (position );
321+ if (dims > 0 ) {
322+ scratch = new float [dims ];
323+ }
324+ }
325+ if (scratch != null ) {
326+ readFloatArray (block , block .getFirstValueIndex (position ), scratch );
327+ }
328+ return scratch ;
329+ }
330+
331+ @ Override
332+ public int getDimensions () {
333+ for (int p = 0 ; p < block .getPositionCount (); p ++) {
334+ int dims = block .getValueCount (p );
335+ if (dims > 0 ) {
336+ return dims ;
337+ }
338+ }
339+ return 0 ;
340+ }
341+
342+ @ Override
343+ public void finish () {
344+ if (block != null ) {
345+ block .close ();
346+ block = null ;
347+ scratch = null ;
348+ }
349+ }
350+
351+ @ Override
352+ public long baseRamBytesUsed () {
353+ return BASE_RAM_BYTES_USED + expressionEvaluator .baseRamBytesUsed () + (block == null ? 0 : block .ramBytesUsed ())
354+ + (scratch == null ? 0 : RamUsageEstimator .shallowSizeOf (scratch ));
355+ }
356+
357+ @ Override
358+ public void close () {
359+ Releasables .close (expressionEvaluator );
360+ }
361+
362+ private static void readFloatArray (FloatBlock block , int firstValueIndex , float [] scratch ) {
363+ for (int i = 0 ; i < scratch .length ; i ++) {
364+ scratch [i ] = block .getFloat (firstValueIndex + i );
365+ }
366+ }
367+
368+ @ Override
369+ public String toString () {
370+ return this .getClass ().getSimpleName () + "[expressionEvaluator=[" + expressionEvaluator + "]]" ;
371+ }
372+ }
195373}
0 commit comments