Skip to content

Commit abf225c

Browse files
authored
[ES|QL] Optimize vector similarity when one vector param is constant (#135602)
1 parent 7f7e4a7 commit abf225c

File tree

7 files changed

+293
-33
lines changed

7 files changed

+293
-33
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 210 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -16,17 +16,22 @@
1616
import org.elasticsearch.compute.data.Page;
1717
import org.elasticsearch.compute.operator.DriverContext;
1818
import org.elasticsearch.compute.operator.EvalOperator;
19+
import org.elasticsearch.core.Releasable;
1920
import org.elasticsearch.core.Releasables;
2021
import org.elasticsearch.xpack.esql.EsqlClientException;
2122
import org.elasticsearch.xpack.esql.core.expression.Expression;
2223
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
24+
import org.elasticsearch.xpack.esql.core.expression.Literal;
2325
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
2426
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
2527
import org.elasticsearch.xpack.esql.core.tree.Source;
2628
import org.elasticsearch.xpack.esql.core.type.DataType;
2729
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
2830

2931
import java.io.IOException;
32+
import java.util.ArrayList;
33+
import java.util.Arrays;
34+
import java.util.List;
3035

3136
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
3237
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
@@ -79,22 +84,36 @@ public Object fold(FoldContext ctx) {
7984

8085
/**
 * Builds the evaluator factory for this similarity function.
 * <p>
 * In the added ("+") lines of this diff each operand is first converted into a
 * {@code VectorValueProviderFactory} via {@code getVectorValueProviderFactory}, and those two
 * factories are passed to {@code SimilarityEvaluatorFactory}. The removed ("-") lines show the
 * previous behavior, where {@code toEvaluator.apply(...)} was passed directly for both operands.
 * The evaluator name is this class's simple name suffixed with {@code "Evaluator"}.
 *
 * @param toEvaluator mapper used to turn operand expressions into evaluator factories
 * @return the factory producing the similarity evaluator
 */
@Override
8186
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
87+
// Added: wrap the left operand in a vector value provider factory.
VectorValueProviderFactory leftVectorProviderFactory = getVectorValueProviderFactory(left(), toEvaluator);
88+
// Added: wrap the right operand in a vector value provider factory.
VectorValueProviderFactory rightVectorProviderFactory = getVectorValueProviderFactory(right(), toEvaluator);
8289
return new SimilarityEvaluatorFactory(
83-
// Removed: the left operand used to be mapped to an evaluator factory directly.
toEvaluator.apply(left()),
84-
// Removed: the right operand used to be mapped to an evaluator factory directly.
toEvaluator.apply(right()),
90+
leftVectorProviderFactory,
91+
rightVectorProviderFactory,
8592
getSimilarityFunction(),
8693
getClass().getSimpleName() + "Evaluator"
8794
);
8895
}
8996

97+
@SuppressWarnings("unchecked")
98+
private static VectorValueProviderFactory getVectorValueProviderFactory(
99+
Expression expression,
100+
EvaluatorMapper.ToEvaluator toEvaluator
101+
) {
102+
if (expression instanceof Literal) {
103+
return new ConstantVectorProvider.Factory((ArrayList<Float>) ((Literal) expression).value());
104+
} else {
105+
return new ExpressionVectorProvider.Factory(toEvaluator.apply(expression));
106+
}
107+
}
108+
90109
/**
91110
* Returns the similarity function to be used for evaluating the similarity between two vectors.
92111
*/
93112
protected abstract SimilarityEvaluatorFunction getSimilarityFunction();
94113

95114
private record SimilarityEvaluatorFactory(
96-
EvalOperator.ExpressionEvaluator.Factory left,
97-
EvalOperator.ExpressionEvaluator.Factory right,
115+
VectorValueProviderFactory leftVectorProviderFactory,
116+
VectorValueProviderFactory rightVectorProviderFactory,
98117
SimilarityEvaluatorFunction similarityFunction,
99118
String evaluatorName
100119
) implements EvalOperator.ExpressionEvaluator.Factory {
@@ -103,8 +122,8 @@ private record SimilarityEvaluatorFactory(
103122
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
104123
// TODO check whether to use this custom evaluator or reuse / define an existing one
105124
return new SimilarityEvaluator(
106-
left.get(context),
107-
right.get(context),
125+
leftVectorProviderFactory.build(context),
126+
rightVectorProviderFactory.build(context),
108127
similarityFunction,
109128
evaluatorName,
110129
context.blockFactory()
@@ -113,13 +132,13 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {
113132

114133
@Override
115134
public String toString() {
116-
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
135+
return evaluatorName() + "[left=" + leftVectorProviderFactory + ", right=" + rightVectorProviderFactory + "]";
117136
}
118137
}
119138

120139
private record SimilarityEvaluator(
121-
EvalOperator.ExpressionEvaluator left,
122-
EvalOperator.ExpressionEvaluator right,
140+
VectorValueProvider left,
141+
VectorValueProvider right,
123142
SimilarityEvaluatorFunction similarityFunction,
124143
String evaluatorName,
125144
BlockFactory blockFactory
@@ -129,26 +148,23 @@ private record SimilarityEvaluator(
129148

130149
@Override
131150
public Block eval(Page page) {
132-
try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) {
151+
try {
152+
left.eval(page);
153+
right.eval(page);
154+
155+
int dimensions = left.getDimensions();
133156
int positionCount = page.getPositionCount();
134-
int dimensions = 0;
135-
// Get the first non-empty vector to calculate the dimension
136-
for (int p = 0; p < positionCount; p++) {
137-
if (leftBlock.getValueCount(p) != 0) {
138-
dimensions = leftBlock.getValueCount(p);
139-
break;
140-
}
141-
}
142157
if (dimensions == 0) {
143158
return blockFactory.newConstantNullBlock(positionCount);
144159
}
145160

146-
float[] leftScratch = new float[dimensions];
147-
float[] rightScratch = new float[dimensions];
148-
try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount * dimensions)) {
161+
try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) {
149162
for (int p = 0; p < positionCount; p++) {
150-
int dimsLeft = leftBlock.getValueCount(p);
151-
int dimsRight = rightBlock.getValueCount(p);
163+
float[] leftVector = left.getVector(p);
164+
float[] rightVector = right.getVector(p);
165+
166+
int dimsLeft = leftVector == null ? 0 : leftVector.length;
167+
int dimsRight = rightVector == null ? 0 : rightVector.length;
152168

153169
if (dimsLeft == 0 || dimsRight == 0) {
154170
// A null value on the left or right vector. Similarity is null
@@ -161,19 +177,14 @@ public Block eval(Page page) {
161177
dimsRight
162178
);
163179
}
164-
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
165-
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
166-
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
180+
float result = similarityFunction.calculateSimilarity(leftVector, rightVector);
167181
builder.appendDouble(result);
168182
}
169183
return builder.build();
170184
}
171-
}
172-
}
173-
174-
private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
175-
for (int i = 0; i < dimensions; i++) {
176-
scratch[i] = block.getFloat(position + i);
185+
} finally {
186+
left.finish();
187+
right.finish();
177188
}
178189
}
179190

@@ -192,4 +203,171 @@ public void close() {
192203
Releasables.close(left, right);
193204
}
194205
}
206+
207+
interface VectorValueProvider extends Releasable {
208+
209+
void eval(Page page);
210+
211+
float[] getVector(int position);
212+
213+
int getDimensions();
214+
215+
void finish();
216+
217+
long baseRamBytesUsed();
218+
}
219+
220+
interface VectorValueProviderFactory {
221+
VectorValueProvider build(DriverContext context);
222+
}
223+
224+
private static class ConstantVectorProvider implements VectorValueProvider {
225+
226+
record Factory(List<Float> vector) implements VectorValueProviderFactory {
227+
public VectorValueProvider build(DriverContext context) {
228+
return new ConstantVectorProvider(vector);
229+
}
230+
231+
@Override
232+
public String toString() {
233+
return ConstantVectorProvider.class.getSimpleName() + "[vector=" + Arrays.toString(vector.toArray()) + "]";
234+
}
235+
}
236+
237+
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ConstantVectorProvider.class);
238+
239+
private final float[] vector;
240+
241+
ConstantVectorProvider(List<Float> vector) {
242+
assert vector != null;
243+
this.vector = new float[vector.size()];
244+
for (int i = 0; i < vector.size(); i++) {
245+
this.vector[i] = vector.get(i);
246+
}
247+
}
248+
249+
@Override
250+
public void eval(Page page) {
251+
// no-op
252+
}
253+
254+
@Override
255+
public float[] getVector(int position) {
256+
return vector;
257+
}
258+
259+
@Override
260+
public int getDimensions() {
261+
return vector.length;
262+
}
263+
264+
@Override
265+
public void finish() {
266+
// no-op
267+
}
268+
269+
@Override
270+
public void close() {
271+
// no-op
272+
}
273+
274+
@Override
275+
public long baseRamBytesUsed() {
276+
return BASE_RAM_BYTES_USED + RamUsageEstimator.shallowSizeOf(vector);
277+
}
278+
279+
@Override
280+
public String toString() {
281+
return this.getClass().getSimpleName() + "[vector=" + Arrays.toString(vector) + "]";
282+
}
283+
}
284+
285+
private static class ExpressionVectorProvider implements VectorValueProvider {
286+
287+
record Factory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) implements VectorValueProviderFactory {
288+
public VectorValueProvider build(DriverContext context) {
289+
return new ExpressionVectorProvider(expressionEvaluatorFactory.get(context));
290+
}
291+
292+
@Override
293+
public String toString() {
294+
return ExpressionVectorProvider.class.getSimpleName() + "[expressionEvaluator=[" + expressionEvaluatorFactory + "]]";
295+
}
296+
}
297+
298+
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ExpressionVectorProvider.class);
299+
300+
private final EvalOperator.ExpressionEvaluator expressionEvaluator;
301+
private FloatBlock block;
302+
private float[] scratch;
303+
304+
ExpressionVectorProvider(EvalOperator.ExpressionEvaluator expressionEvaluator) {
305+
assert expressionEvaluator != null;
306+
this.expressionEvaluator = expressionEvaluator;
307+
}
308+
309+
@Override
310+
public void eval(Page page) {
311+
block = (FloatBlock) expressionEvaluator.eval(page);
312+
}
313+
314+
@Override
315+
public float[] getVector(int position) {
316+
if (block.isNull(position)) {
317+
return null;
318+
}
319+
if (scratch == null) {
320+
int dims = block.getValueCount(position);
321+
if (dims > 0) {
322+
scratch = new float[dims];
323+
}
324+
}
325+
if (scratch != null) {
326+
readFloatArray(block, block.getFirstValueIndex(position), scratch);
327+
}
328+
return scratch;
329+
}
330+
331+
@Override
332+
public int getDimensions() {
333+
for (int p = 0; p < block.getPositionCount(); p++) {
334+
int dims = block.getValueCount(p);
335+
if (dims > 0) {
336+
return dims;
337+
}
338+
}
339+
return 0;
340+
}
341+
342+
@Override
343+
public void finish() {
344+
if (block != null) {
345+
block.close();
346+
block = null;
347+
scratch = null;
348+
}
349+
}
350+
351+
@Override
352+
public long baseRamBytesUsed() {
353+
return BASE_RAM_BYTES_USED + expressionEvaluator.baseRamBytesUsed() + (block == null ? 0 : block.ramBytesUsed())
354+
+ (scratch == null ? 0 : RamUsageEstimator.shallowSizeOf(scratch));
355+
}
356+
357+
@Override
358+
public void close() {
359+
Releasables.close(expressionEvaluator);
360+
}
361+
362+
private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) {
363+
for (int i = 0; i < scratch.length; i++) {
364+
scratch[i] = block.getFloat(firstValueIndex + i);
365+
}
366+
}
367+
368+
@Override
369+
public String toString() {
370+
return this.getClass().getSimpleName() + "[expressionEvaluator=[" + expressionEvaluator + "]]";
371+
}
372+
}
195373
}

0 commit comments

Comments
 (0)