Skip to content

Commit 9e0c372

Browse files
author
elasticsearchmachine
committed
Fix merge with main
1 parent dbbe45f commit 9e0c372

File tree

1 file changed

+220
-42
lines changed

1 file changed

+220
-42
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 220 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,23 @@
1616
import org.elasticsearch.compute.data.Page;
1717
import org.elasticsearch.compute.operator.DriverContext;
1818
import org.elasticsearch.compute.operator.EvalOperator;
19+
import org.elasticsearch.core.Releasable;
1920
import org.elasticsearch.core.Releasables;
2021
import org.elasticsearch.xpack.esql.EsqlClientException;
2122
import org.elasticsearch.xpack.esql.core.expression.Expression;
2223
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
2324
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
25+
import org.elasticsearch.xpack.esql.core.expression.Literal;
2426
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
2527
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
2628
import org.elasticsearch.xpack.esql.core.tree.Source;
2729
import org.elasticsearch.xpack.esql.core.type.DataType;
2830
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
2931

3032
import java.io.IOException;
33+
import java.util.ArrayList;
34+
import java.util.Arrays;
35+
import java.util.List;
3136

3237
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
3338
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
@@ -79,23 +84,13 @@ public Object fold(FoldContext ctx) {
7984
return EvaluatorMapper.super.fold(source(), ctx);
8085
}
8186

82-
@Override
83-
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
84-
return new SimilarityEvaluatorFactory(
85-
toEvaluator.apply(left()),
86-
toEvaluator.apply(right()),
87-
getSimilarityFunction(),
88-
getClass().getSimpleName() + "Evaluator"
89-
);
90-
}
91-
9287
@Override
9388
public PushableOptions pushableOptions() {
9489
// Can only push if one of the arguments is a literal, and the other a field name
9590
return left().foldable() && left().dataType() != NULL && right() instanceof FieldAttribute
9691
|| right().foldable() && right().dataType() != NULL && left() instanceof FieldAttribute
97-
? PushableOptions.PREFERRED
98-
: PushableOptions.NOT_SUPPORTED;
92+
? PushableOptions.PREFERRED
93+
: PushableOptions.NOT_SUPPORTED;
9994
}
10095

10196
@Override
@@ -107,14 +102,38 @@ public final String asScript() {
107102

108103
protected abstract String scriptFunctionName();
109104

105+
@Override
106+
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
107+
VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory(left(), toEvaluator);
108+
VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory(right(), toEvaluator);
109+
return new SimilarityEvaluatorFactory(
110+
leftVectorProviderFactory,
111+
rightVectorProviderFactory,
112+
getSimilarityFunction(),
113+
getClass().getSimpleName() + "Evaluator"
114+
);
115+
}
116+
117+
@SuppressWarnings("unchecked")
118+
private static VectorValueProviderFactory getVectorValueProviderFactory(
119+
Expression expression,
120+
EvaluatorMapper.ToEvaluator toEvaluator
121+
) {
122+
if (expression instanceof Literal) {
123+
return new ConstantVectorProvider.Factory((ArrayList<Float>) ((Literal) expression).value());
124+
} else {
125+
return new ExpressionVectorProvider.Factory(toEvaluator.apply(expression));
126+
}
127+
}
128+
110129
/**
111130
* Returns the similarity function to be used for evaluating the similarity between two vectors.
112131
*/
113132
protected abstract SimilarityEvaluatorFunction getSimilarityFunction();
114133

115134
private record SimilarityEvaluatorFactory(
116-
EvalOperator.ExpressionEvaluator.Factory left,
117-
EvalOperator.ExpressionEvaluator.Factory right,
135+
VectorValueProviderFactory leftVectorProviderFactory,
136+
VectorValueProviderFactory rightVectorProviderFactory,
118137
SimilarityEvaluatorFunction similarityFunction,
119138
String evaluatorName
120139
) implements EvalOperator.ExpressionEvaluator.Factory {
@@ -123,8 +142,8 @@ private record SimilarityEvaluatorFactory(
123142
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
124143
// TODO check whether to use this custom evaluator or reuse / define an existing one
125144
return new SimilarityEvaluator(
126-
left.get(context),
127-
right.get(context),
145+
leftVectorProviderFactory.build(context),
146+
rightVectorProviderFactory.build(context),
128147
similarityFunction,
129148
evaluatorName,
130149
context.blockFactory()
@@ -133,13 +152,13 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {
133152

134153
@Override
135154
public String toString() {
136-
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
155+
return evaluatorName() + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]";
137156
}
138157
}
139158

140159
private record SimilarityEvaluator(
141-
EvalOperator.ExpressionEvaluator left,
142-
EvalOperator.ExpressionEvaluator right,
160+
VectorValueProvider left,
161+
VectorValueProvider right,
143162
SimilarityEvaluatorFunction similarityFunction,
144163
String evaluatorName,
145164
BlockFactory blockFactory
@@ -149,26 +168,23 @@ private record SimilarityEvaluator(
149168

150169
@Override
151170
public Block eval(Page page) {
152-
try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) {
171+
try {
172+
left.eval(page);
173+
right.eval(page);
174+
175+
int dimensions = left.getDimensions();
153176
int positionCount = page.getPositionCount();
154-
int dimensions = 0;
155-
// Get the first non-empty vector to calculate the dimension
156-
for (int p = 0; p < positionCount; p++) {
157-
if (leftBlock.getValueCount(p) != 0) {
158-
dimensions = leftBlock.getValueCount(p);
159-
break;
160-
}
161-
}
162177
if (dimensions == 0) {
163178
return blockFactory.newConstantNullBlock(positionCount);
164179
}
165180

166-
float[] leftScratch = new float[dimensions];
167-
float[] rightScratch = new float[dimensions];
168-
try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount * dimensions)) {
181+
try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) {
169182
for (int p = 0; p < positionCount; p++) {
170-
int dimsLeft = leftBlock.getValueCount(p);
171-
int dimsRight = rightBlock.getValueCount(p);
183+
float[] leftVector = left.getVector(p);
184+
float[] rightVector = right.getVector(p);
185+
186+
int dimsLeft = leftVector == null ? 0 : leftVector.length;
187+
int dimsRight = rightVector == null ? 0 : rightVector.length;
172188

173189
if (dimsLeft == 0 || dimsRight == 0) {
174190
// A null value on the left or right vector. Similarity is null
@@ -181,19 +197,14 @@ public Block eval(Page page) {
181197
dimsRight
182198
);
183199
}
184-
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
185-
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
186-
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
200+
float result = similarityFunction.calculateSimilarity(leftVector, rightVector);
187201
builder.appendDouble(result);
188202
}
189203
return builder.build();
190204
}
191-
}
192-
}
193-
194-
private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
195-
for (int i = 0; i < dimensions; i++) {
196-
scratch[i] = block.getFloat(position + i);
205+
} finally {
206+
left.finish();
207+
right.finish();
197208
}
198209
}
199210

@@ -212,4 +223,171 @@ public void close() {
212223
Releasables.close(left, right);
213224
}
214225
}
226+
227+
interface VectorValueProvider extends Releasable {
228+
229+
void eval(Page page);
230+
231+
float[] getVector(int position);
232+
233+
int getDimensions();
234+
235+
void finish();
236+
237+
long baseRamBytesUsed();
238+
}
239+
240+
interface VectorValueProviderFactory {
241+
VectorValueProvider build(DriverContext context);
242+
}
243+
244+
private static class ConstantVectorProvider implements VectorValueProvider {
245+
246+
record Factory(List<Float> vector) implements VectorValueProviderFactory {
247+
public VectorValueProvider build(DriverContext context) {
248+
return new ConstantVectorProvider(vector);
249+
}
250+
251+
@Override
252+
public String toString() {
253+
return ConstantVectorProvider.class.getSimpleName() + "[vector=" + Arrays.toString(vector.toArray()) + "]";
254+
}
255+
}
256+
257+
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantVectorProvider.class);
258+
259+
private final float[] vector;
260+
261+
ConstantVectorProvider(List<Float> vector) {
262+
assert vector != null;
263+
this.vector = new float[vector.size()];
264+
for (int i = 0; i < vector.size(); i++) {
265+
this.vector[i] = vector.get(i);
266+
}
267+
}
268+
269+
@Override
270+
public void eval(Page page) {
271+
// no-op
272+
}
273+
274+
@Override
275+
public float[] getVector(int position) {
276+
return vector;
277+
}
278+
279+
@Override
280+
public int getDimensions() {
281+
return vector.length;
282+
}
283+
284+
@Override
285+
public void finish() {
286+
// no-op
287+
}
288+
289+
@Override
290+
public void close() {
291+
// no-op
292+
}
293+
294+
@Override
295+
public long baseRamBytesUsed() {
296+
return BASE_RAM_BYTES_USED + RamUsageEstimator.shallowSizeOf(vector);
297+
}
298+
299+
@Override
300+
public String toString() {
301+
return this.getClass().getSimpleName() + "[vector=" + Arrays.toString(vector) + "]";
302+
}
303+
}
304+
305+
private static class ExpressionVectorProvider implements VectorValueProvider {
306+
307+
record Factory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) implements VectorValueProviderFactory {
308+
public VectorValueProvider build(DriverContext context) {
309+
return new ExpressionVectorProvider(expressionEvaluatorFactory.get(context));
310+
}
311+
312+
@Override
313+
public String toString() {
314+
return ExpressionVectorProvider.class.getSimpleName() + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]";
315+
}
316+
}
317+
318+
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ExpressionVectorProvider.class);
319+
320+
private final EvalOperator.ExpressionEvaluator expressionEvaluator;
321+
private FloatBlock block;
322+
private float[] scratch;
323+
324+
ExpressionVectorProvider(EvalOperator.ExpressionEvaluator expressionEvaluator) {
325+
assert expressionEvaluator != null;
326+
this.expressionEvaluator = expressionEvaluator;
327+
}
328+
329+
@Override
330+
public void eval(Page page) {
331+
block = (FloatBlock) expressionEvaluator.eval(page);
332+
}
333+
334+
@Override
335+
public float[] getVector(int position) {
336+
if (block.isNull(position)) {
337+
return null;
338+
}
339+
if (scratch == null) {
340+
int dims = block.getValueCount(position);
341+
if (dims > 0) {
342+
scratch = new float[dims];
343+
}
344+
}
345+
if (scratch != null) {
346+
readFloatArray(block, block.getFirstValueIndex(position), scratch);
347+
}
348+
return scratch;
349+
}
350+
351+
@Override
352+
public int getDimensions() {
353+
for (int p = 0; p < block.getPositionCount(); p++) {
354+
int dims = block.getValueCount(p);
355+
if (dims > 0) {
356+
return dims;
357+
}
358+
}
359+
return 0;
360+
}
361+
362+
@Override
363+
public void finish() {
364+
if (block != null) {
365+
block.close();
366+
block = null;
367+
scratch = null;
368+
}
369+
}
370+
371+
@Override
372+
public long baseRamBytesUsed() {
373+
return BASE_RAM_BYTES_USED + expressionEvaluator.baseRamBytesUsed() + (block == null ? 0 : block.ramBytesUsed())
374+
+ (scratch == null ? 0 : RamUsageEstimator.shallowSizeOf(scratch));
375+
}
376+
377+
@Override
378+
public void close() {
379+
Releasables.close(expressionEvaluator);
380+
}
381+
382+
private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) {
383+
for (int i = 0; i < scratch.length; i++) {
384+
scratch[i] = block.getFloat(firstValueIndex + i);
385+
}
386+
}
387+
388+
@Override
389+
public String toString() {
390+
return this.getClass().getSimpleName() + "[expressionEvaluator=[" + expressionEvaluator + "]]";
391+
}
392+
}
215393
}

0 commit comments

Comments
 (0)