16
16
import org .elasticsearch .compute .data .Page ;
17
17
import org .elasticsearch .compute .operator .DriverContext ;
18
18
import org .elasticsearch .compute .operator .EvalOperator ;
19
+ import org .elasticsearch .core .Releasable ;
19
20
import org .elasticsearch .core .Releasables ;
20
21
import org .elasticsearch .xpack .esql .EsqlClientException ;
21
22
import org .elasticsearch .xpack .esql .core .expression .Expression ;
22
23
import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
24
+ import org .elasticsearch .xpack .esql .core .expression .Literal ;
23
25
import org .elasticsearch .xpack .esql .core .expression .TypeResolutions ;
24
26
import org .elasticsearch .xpack .esql .core .expression .function .scalar .BinaryScalarFunction ;
25
27
import org .elasticsearch .xpack .esql .core .tree .Source ;
26
28
import org .elasticsearch .xpack .esql .core .type .DataType ;
27
29
import org .elasticsearch .xpack .esql .evaluator .mapper .EvaluatorMapper ;
28
30
29
31
import java .io .IOException ;
32
+ import java .util .ArrayList ;
33
+ import java .util .Arrays ;
34
+ import java .util .List ;
30
35
31
36
import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .FIRST ;
32
37
import static org .elasticsearch .xpack .esql .core .expression .TypeResolutions .ParamOrdinal .SECOND ;
@@ -79,22 +84,36 @@ public Object fold(FoldContext ctx) {
79
84
80
85
@ Override
81
86
public final EvalOperator .ExpressionEvaluator .Factory toEvaluator (EvaluatorMapper .ToEvaluator toEvaluator ) {
87
+ VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory (left (), toEvaluator );
88
+ VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory (right (), toEvaluator );
82
89
return new SimilarityEvaluatorFactory (
83
- toEvaluator . apply ( left ()) ,
84
- toEvaluator . apply ( right ()) ,
90
+ leftVectorProviderFactory ,
91
+ rightVectorProviderFactory ,
85
92
getSimilarityFunction (),
86
93
getClass ().getSimpleName () + "Evaluator"
87
94
);
88
95
}
89
96
97
+ @ SuppressWarnings ("unchecked" )
98
+ private static VectorValueProviderFactory getVectorValueProviderFactory (
99
+ Expression expression ,
100
+ EvaluatorMapper .ToEvaluator toEvaluator
101
+ ) {
102
+ if (expression instanceof Literal ) {
103
+ return new ConstantVectorProvider .Factory ((ArrayList <Float >) ((Literal ) expression ).value ());
104
+ } else {
105
+ return new ExpressionVectorProvider .Factory (toEvaluator .apply (expression ));
106
+ }
107
+ }
108
+
90
109
/**
91
110
* Returns the similarity function to be used for evaluating the similarity between two vectors.
92
111
*/
93
112
protected abstract SimilarityEvaluatorFunction getSimilarityFunction ();
94
113
95
114
private record SimilarityEvaluatorFactory (
96
- EvalOperator . ExpressionEvaluator . Factory left ,
97
- EvalOperator . ExpressionEvaluator . Factory right ,
115
+ VectorValueProviderFactory leftVectorProviderFactory ,
116
+ VectorValueProviderFactory rightVectorProviderFactory ,
98
117
SimilarityEvaluatorFunction similarityFunction ,
99
118
String evaluatorName
100
119
) implements EvalOperator .ExpressionEvaluator .Factory {
@@ -103,8 +122,8 @@ private record SimilarityEvaluatorFactory(
103
122
public EvalOperator .ExpressionEvaluator get (DriverContext context ) {
104
123
// TODO check whether to use this custom evaluator or reuse / define an existing one
105
124
return new SimilarityEvaluator (
106
- left . get (context ),
107
- right . get (context ),
125
+ leftVectorProviderFactory . build (context ),
126
+ rightVectorProviderFactory . build (context ),
108
127
similarityFunction ,
109
128
evaluatorName ,
110
129
context .blockFactory ()
@@ -113,13 +132,13 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {
113
132
114
133
@ Override
115
134
public String toString () {
116
- return evaluatorName () + "[left=" + left + ", right=" + right + "]" ;
135
+ return evaluatorName () + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]" ;
117
136
}
118
137
}
119
138
120
139
private record SimilarityEvaluator (
121
- EvalOperator . ExpressionEvaluator left ,
122
- EvalOperator . ExpressionEvaluator right ,
140
+ VectorValueProvider left ,
141
+ VectorValueProvider right ,
123
142
SimilarityEvaluatorFunction similarityFunction ,
124
143
String evaluatorName ,
125
144
BlockFactory blockFactory
@@ -129,26 +148,23 @@ private record SimilarityEvaluator(
129
148
130
149
@ Override
131
150
public Block eval (Page page ) {
132
- try (FloatBlock leftBlock = (FloatBlock ) left .eval (page ); FloatBlock rightBlock = (FloatBlock ) right .eval (page )) {
151
+ try {
152
+ left .eval (page );
153
+ right .eval (page );
154
+
155
+ int dimensions = left .getDimensions ();
133
156
int positionCount = page .getPositionCount ();
134
- int dimensions = 0 ;
135
- // Get the first non-empty vector to calculate the dimension
136
- for (int p = 0 ; p < positionCount ; p ++) {
137
- if (leftBlock .getValueCount (p ) != 0 ) {
138
- dimensions = leftBlock .getValueCount (p );
139
- break ;
140
- }
141
- }
142
157
if (dimensions == 0 ) {
143
158
return blockFactory .newConstantNullBlock (positionCount );
144
159
}
145
160
146
- float [] leftScratch = new float [dimensions ];
147
- float [] rightScratch = new float [dimensions ];
148
- try (DoubleBlock .Builder builder = blockFactory .newDoubleBlockBuilder (positionCount * dimensions )) {
161
+ try (DoubleBlock .Builder builder = blockFactory .newDoubleBlockBuilder (positionCount )) {
149
162
for (int p = 0 ; p < positionCount ; p ++) {
150
- int dimsLeft = leftBlock .getValueCount (p );
151
- int dimsRight = rightBlock .getValueCount (p );
163
+ float [] leftVector = left .getVector (p );
164
+ float [] rightVector = right .getVector (p );
165
+
166
+ int dimsLeft = leftVector == null ? 0 : leftVector .length ;
167
+ int dimsRight = rightVector == null ? 0 : rightVector .length ;
152
168
153
169
if (dimsLeft == 0 || dimsRight == 0 ) {
154
170
// A null value on the left or right vector. Similarity is null
@@ -161,19 +177,14 @@ public Block eval(Page page) {
161
177
dimsRight
162
178
);
163
179
}
164
- readFloatArray (leftBlock , leftBlock .getFirstValueIndex (p ), dimensions , leftScratch );
165
- readFloatArray (rightBlock , rightBlock .getFirstValueIndex (p ), dimensions , rightScratch );
166
- float result = similarityFunction .calculateSimilarity (leftScratch , rightScratch );
180
+ float result = similarityFunction .calculateSimilarity (leftVector , rightVector );
167
181
builder .appendDouble (result );
168
182
}
169
183
return builder .build ();
170
184
}
171
- }
172
- }
173
-
174
- private static void readFloatArray (FloatBlock block , int position , int dimensions , float [] scratch ) {
175
- for (int i = 0 ; i < dimensions ; i ++) {
176
- scratch [i ] = block .getFloat (position + i );
185
+ } finally {
186
+ left .finish ();
187
+ right .finish ();
177
188
}
178
189
}
179
190
@@ -192,4 +203,171 @@ public void close() {
192
203
Releasables .close (left , right );
193
204
}
194
205
}
206
+
207
+ interface VectorValueProvider extends Releasable {
208
+
209
+ void eval (Page page );
210
+
211
+ float [] getVector (int position );
212
+
213
+ int getDimensions ();
214
+
215
+ void finish ();
216
+
217
+ long baseRamBytesUsed ();
218
+ }
219
+
220
+ interface VectorValueProviderFactory {
221
+ VectorValueProvider build (DriverContext context );
222
+ }
223
+
224
+ private static class ConstantVectorProvider implements VectorValueProvider {
225
+
226
+ record Factory (List <Float > vector ) implements VectorValueProviderFactory {
227
+ public VectorValueProvider build (DriverContext context ) {
228
+ return new ConstantVectorProvider (vector );
229
+ }
230
+
231
+ @ Override
232
+ public String toString () {
233
+ return ConstantVectorProvider .class .getSimpleName () + "[vector=" + Arrays .toString (vector .toArray ()) + "]" ;
234
+ }
235
+ }
236
+
237
+ private static final long BASE_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ConstantVectorProvider .class );
238
+
239
+ private final float [] vector ;
240
+
241
+ ConstantVectorProvider (List <Float > vector ) {
242
+ assert vector != null ;
243
+ this .vector = new float [vector .size ()];
244
+ for (int i = 0 ; i < vector .size (); i ++) {
245
+ this .vector [i ] = vector .get (i );
246
+ }
247
+ }
248
+
249
+ @ Override
250
+ public void eval (Page page ) {
251
+ // no-op
252
+ }
253
+
254
+ @ Override
255
+ public float [] getVector (int position ) {
256
+ return vector ;
257
+ }
258
+
259
+ @ Override
260
+ public int getDimensions () {
261
+ return vector .length ;
262
+ }
263
+
264
+ @ Override
265
+ public void finish () {
266
+ // no-op
267
+ }
268
+
269
+ @ Override
270
+ public void close () {
271
+ // no-op
272
+ }
273
+
274
+ @ Override
275
+ public long baseRamBytesUsed () {
276
+ return BASE_RAM_BYTES_USED + RamUsageEstimator .shallowSizeOf (vector );
277
+ }
278
+
279
+ @ Override
280
+ public String toString () {
281
+ return this .getClass ().getSimpleName () + "[vector=" + Arrays .toString (vector ) + "]" ;
282
+ }
283
+ }
284
+
285
+ private static class ExpressionVectorProvider implements VectorValueProvider {
286
+
287
+ record Factory (EvalOperator .ExpressionEvaluator .Factory expressionEvaluatorFactory ) implements VectorValueProviderFactory {
288
+ public VectorValueProvider build (DriverContext context ) {
289
+ return new ExpressionVectorProvider (expressionEvaluatorFactory .get (context ));
290
+ }
291
+
292
+ @ Override
293
+ public String toString () {
294
+ return ExpressionVectorProvider .class .getSimpleName () + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]" ;
295
+ }
296
+ }
297
+
298
+ private static final long BASE_RAM_BYTES_USED = RamUsageEstimator .shallowSizeOfInstance (ExpressionVectorProvider .class );
299
+
300
+ private final EvalOperator .ExpressionEvaluator expressionEvaluator ;
301
+ private FloatBlock block ;
302
+ private float [] scratch ;
303
+
304
+ ExpressionVectorProvider (EvalOperator .ExpressionEvaluator expressionEvaluator ) {
305
+ assert expressionEvaluator != null ;
306
+ this .expressionEvaluator = expressionEvaluator ;
307
+ }
308
+
309
+ @ Override
310
+ public void eval (Page page ) {
311
+ block = (FloatBlock ) expressionEvaluator .eval (page );
312
+ }
313
+
314
+ @ Override
315
+ public float [] getVector (int position ) {
316
+ if (block .isNull (position )) {
317
+ return null ;
318
+ }
319
+ if (scratch == null ) {
320
+ int dims = block .getValueCount (position );
321
+ if (dims > 0 ) {
322
+ scratch = new float [dims ];
323
+ }
324
+ }
325
+ if (scratch != null ) {
326
+ readFloatArray (block , block .getFirstValueIndex (position ), scratch );
327
+ }
328
+ return scratch ;
329
+ }
330
+
331
+ @ Override
332
+ public int getDimensions () {
333
+ for (int p = 0 ; p < block .getPositionCount (); p ++) {
334
+ int dims = block .getValueCount (p );
335
+ if (dims > 0 ) {
336
+ return dims ;
337
+ }
338
+ }
339
+ return 0 ;
340
+ }
341
+
342
+ @ Override
343
+ public void finish () {
344
+ if (block != null ) {
345
+ block .close ();
346
+ block = null ;
347
+ scratch = null ;
348
+ }
349
+ }
350
+
351
+ @ Override
352
+ public long baseRamBytesUsed () {
353
+ return BASE_RAM_BYTES_USED + expressionEvaluator .baseRamBytesUsed () + (block == null ? 0 : block .ramBytesUsed ())
354
+ + (scratch == null ? 0 : RamUsageEstimator .shallowSizeOf (scratch ));
355
+ }
356
+
357
+ @ Override
358
+ public void close () {
359
+ Releasables .close (expressionEvaluator );
360
+ }
361
+
362
+ private static void readFloatArray (FloatBlock block , int firstValueIndex , float [] scratch ) {
363
+ for (int i = 0 ; i < scratch .length ; i ++) {
364
+ scratch [i ] = block .getFloat (firstValueIndex + i );
365
+ }
366
+ }
367
+
368
+ @ Override
369
+ public String toString () {
370
+ return this .getClass ().getSimpleName () + "[expressionEvaluator=[" + expressionEvaluator + "]]" ;
371
+ }
372
+ }
195
373
}
0 commit comments