-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ES|QL] Optimize vector similarity when one param is constant #135602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
7d9e975
abe7f01
989235e
aeebc6e
240dd53
4a759ba
e386b62
0301c5b
97d1958
cafafcb
9f12dbe
3509bba
1fc70a2
7b5b760
9d51231
504c92e
59cdda9
6da8fb4
42dcf15
c336ede
4640bb8
8d72e5c
c25f0bd
5d95e16
ee019f2
96c1a53
9f47ab2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,17 +16,21 @@ | |
| import org.elasticsearch.compute.data.Page; | ||
| import org.elasticsearch.compute.operator.DriverContext; | ||
| import org.elasticsearch.compute.operator.EvalOperator; | ||
| import org.elasticsearch.core.Releasable; | ||
| import org.elasticsearch.core.Releasables; | ||
| import org.elasticsearch.xpack.esql.EsqlClientException; | ||
| import org.elasticsearch.xpack.esql.core.expression.Expression; | ||
| import org.elasticsearch.xpack.esql.core.expression.FoldContext; | ||
| import org.elasticsearch.xpack.esql.core.expression.Literal; | ||
| import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; | ||
| import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; | ||
| import org.elasticsearch.xpack.esql.core.tree.Source; | ||
| import org.elasticsearch.xpack.esql.core.type.DataType; | ||
| import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.ArrayList; | ||
| import java.util.Arrays; | ||
|
|
||
| import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; | ||
| import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; | ||
|
|
@@ -78,10 +82,23 @@ public Object fold(FoldContext ctx) { | |
| } | ||
|
|
||
| @Override | ||
| @SuppressWarnings("unchecked") | ||
| public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { | ||
| VectorValueProvider.Builder leftVectorProviderBuilder = new VectorValueProvider.Builder(); | ||
| VectorValueProvider.Builder rightVectorProviderBuilder = new VectorValueProvider.Builder(); | ||
| if (left() instanceof Literal && left().dataType() == DENSE_VECTOR) { | ||
|
||
| leftVectorProviderBuilder.constantVector((ArrayList<Float>) ((Literal) left()).value()); | ||
| } else { | ||
| leftVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(left())); | ||
| } | ||
| if (right() instanceof Literal && right().dataType() == DENSE_VECTOR) { | ||
| rightVectorProviderBuilder.constantVector((ArrayList<Float>) ((Literal) right()).value()); | ||
| } else { | ||
| rightVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(right())); | ||
| } | ||
| return new SimilarityEvaluatorFactory( | ||
| toEvaluator.apply(left()), | ||
| toEvaluator.apply(right()), | ||
| leftVectorProviderBuilder, | ||
| rightVectorProviderBuilder, | ||
| getSimilarityFunction(), | ||
| getClass().getSimpleName() + "Evaluator" | ||
| ); | ||
|
|
@@ -93,8 +110,8 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe | |
| protected abstract SimilarityEvaluatorFunction getSimilarityFunction(); | ||
|
|
||
| private record SimilarityEvaluatorFactory( | ||
| EvalOperator.ExpressionEvaluator.Factory left, | ||
| EvalOperator.ExpressionEvaluator.Factory right, | ||
| VectorValueProvider.Builder leftVectorProviderBuilder, | ||
| VectorValueProvider.Builder rightVectorProviderBuilder, | ||
| SimilarityEvaluatorFunction similarityFunction, | ||
| String evaluatorName | ||
| ) implements EvalOperator.ExpressionEvaluator.Factory { | ||
|
|
@@ -103,8 +120,8 @@ private record SimilarityEvaluatorFactory( | |
| public EvalOperator.ExpressionEvaluator get(DriverContext context) { | ||
| // TODO check whether to use this custom evaluator or reuse / define an existing one | ||
| return new SimilarityEvaluator( | ||
| left.get(context), | ||
| right.get(context), | ||
| leftVectorProviderBuilder.build(context), | ||
| rightVectorProviderBuilder.build(context), | ||
| similarityFunction, | ||
| evaluatorName, | ||
| context.blockFactory() | ||
|
|
@@ -113,13 +130,150 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) { | |
|
|
||
| @Override | ||
| public String toString() { | ||
| return evaluatorName() + "[left=" + left + ", right=" + right + "]"; | ||
| return evaluatorName() + "[left=" + leftVectorProviderBuilder + ", right=" + rightVectorProviderBuilder + "]"; | ||
| } | ||
| } | ||
|
|
||
| private static class VectorValueProvider implements Releasable { | ||
|
|
||
| private static final class Builder { | ||
| private ArrayList<Float> constantVector; | ||
| private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory; | ||
|
|
||
| void constantVector(ArrayList<Float> constantVector) { | ||
| this.constantVector = constantVector; | ||
| } | ||
|
|
||
| void expressionEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) { | ||
| this.expressionEvaluatorFactory = expressionEvaluatorFactory; | ||
| } | ||
|
|
||
| VectorValueProvider build(DriverContext context) { | ||
| if (false == (constantVector == null ^ expressionEvaluatorFactory == null)) { | ||
| throw new IllegalArgumentException( | ||
| "One of [constantVector] or [expressionEvaluatorFactory] must be set, but not both." | ||
| ); | ||
| } | ||
| return new VectorValueProvider( | ||
| constantVector, | ||
| expressionEvaluatorFactory == null ? null : expressionEvaluatorFactory.get(context) | ||
| ); | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return "constantVector=" | ||
| + (constantVector == null ? "[null]" : constantVector) | ||
| + ", expressionEvaluator=[" | ||
| + expressionEvaluatorFactory | ||
| + "]"; | ||
| } | ||
| } | ||
|
|
||
| private final float[] constantVector; | ||
| private final EvalOperator.ExpressionEvaluator expressionEvaluator; | ||
| private FloatBlock block; | ||
| private float[] scratch; | ||
|
|
||
| VectorValueProvider(ArrayList<Float> constantVector, EvalOperator.ExpressionEvaluator expressionEvaluator) { | ||
|
||
| if (constantVector != null) { | ||
| this.constantVector = new float[constantVector.size()]; | ||
| for (int i = 0; i < constantVector.size(); i++) { | ||
| this.constantVector[i] = constantVector.get(i); | ||
| } | ||
| } else { | ||
| this.constantVector = null; | ||
| } | ||
| this.expressionEvaluator = expressionEvaluator; | ||
| } | ||
|
|
||
| private void eval(Page page) { | ||
| if (expressionEvaluator != null) { | ||
| block = (FloatBlock) expressionEvaluator.eval(page); | ||
| } | ||
| } | ||
|
|
||
| private float[] getVector(int position) { | ||
| if (constantVector != null) { | ||
| return constantVector; | ||
| } else if (block != null) { | ||
| if (block.isNull(position)) { | ||
| return null; | ||
| } | ||
| if (scratch == null) { | ||
| int dims = block.getValueCount(position); | ||
| if (dims > 0) { | ||
| scratch = new float[dims]; | ||
| } | ||
| } | ||
| if (scratch != null) { | ||
| readFloatArray(block, block.getFirstValueIndex(position), scratch); | ||
| } | ||
| return scratch; | ||
| } else { | ||
| throw new EsqlClientException( | ||
| "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null." | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) { | ||
| for (int i = 0; i < scratch.length; i++) { | ||
| scratch[i] = block.getFloat(firstValueIndex + i); | ||
| } | ||
| } | ||
|
|
||
| public int getDimensions() { | ||
| if (constantVector != null) { | ||
| return constantVector.length; | ||
| } else if (block != null) { | ||
| for (int p = 0; p < block.getPositionCount(); p++) { | ||
| int dims = block.getValueCount(p); | ||
| if (dims > 0) { | ||
| return dims; | ||
| } | ||
| } | ||
| return 0; | ||
| } else { | ||
| throw new EsqlClientException( | ||
| "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null." | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| public long baseRamBytesUsed() { | ||
| return (constantVector == null ? 0 : RamUsageEstimator.shallowSizeOf(constantVector)) + (expressionEvaluator == null | ||
| ? 0 | ||
| : expressionEvaluator.baseRamBytesUsed()); | ||
| } | ||
|
|
||
| // Once we're doing processing a block, ensure that we dereference it and reset scratch so that it can be | ||
| // filled by the next page through `eval` | ||
| void finishBlock() { | ||
| if (block != null) { | ||
| block.close(); | ||
| block = null; | ||
| scratch = null; | ||
| } | ||
| } | ||
|
|
||
| public void close() { | ||
| Releasables.close(expressionEvaluator); | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return "constantVector=" | ||
| + (constantVector == null ? "[null]" : Arrays.toString(constantVector)) | ||
| + ", expressionEvaluator=[" | ||
| + expressionEvaluator | ||
| + "]"; | ||
| } | ||
| } | ||
|
|
||
| private record SimilarityEvaluator( | ||
| EvalOperator.ExpressionEvaluator left, | ||
| EvalOperator.ExpressionEvaluator right, | ||
| VectorValueProvider left, | ||
| VectorValueProvider right, | ||
| SimilarityEvaluatorFunction similarityFunction, | ||
| String evaluatorName, | ||
| BlockFactory blockFactory | ||
|
|
@@ -129,26 +283,23 @@ private record SimilarityEvaluator( | |
|
|
||
| @Override | ||
| public Block eval(Page page) { | ||
| try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) { | ||
| try { | ||
| left.eval(page); | ||
| right.eval(page); | ||
|
|
||
| int dimensions = left.getDimensions(); | ||
| int positionCount = page.getPositionCount(); | ||
| int dimensions = 0; | ||
| // Get the first non-empty vector to calculate the dimension | ||
| for (int p = 0; p < positionCount; p++) { | ||
| if (leftBlock.getValueCount(p) != 0) { | ||
| dimensions = leftBlock.getValueCount(p); | ||
| break; | ||
| } | ||
| } | ||
| if (dimensions == 0) { | ||
| return blockFactory.newConstantNullBlock(positionCount); | ||
| } | ||
|
|
||
| float[] leftScratch = new float[dimensions]; | ||
| float[] rightScratch = new float[dimensions]; | ||
| try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount * dimensions)) { | ||
| try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) { | ||
| for (int p = 0; p < positionCount; p++) { | ||
| int dimsLeft = leftBlock.getValueCount(p); | ||
| int dimsRight = rightBlock.getValueCount(p); | ||
| float[] leftVector = left.getVector(p); | ||
| float[] rightVector = right.getVector(p); | ||
|
|
||
| int dimsLeft = leftVector == null ? 0 : leftVector.length; | ||
| int dimsRight = rightVector == null ? 0 : rightVector.length; | ||
|
|
||
| if (dimsLeft == 0 || dimsRight == 0) { | ||
| // A null value on the left or right vector. Similarity is null | ||
|
|
@@ -161,19 +312,14 @@ public Block eval(Page page) { | |
| dimsRight | ||
| ); | ||
| } | ||
| readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch); | ||
| readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch); | ||
| float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch); | ||
| float result = similarityFunction.calculateSimilarity(leftVector, rightVector); | ||
| builder.appendDouble(result); | ||
| } | ||
| return builder.build(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { | ||
| for (int i = 0; i < dimensions; i++) { | ||
| scratch[i] = block.getFloat(position + i); | ||
| } finally { | ||
| left.finishBlock(); | ||
| right.finishBlock(); | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😁 thanks for fixing!