Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7d9e975
iter
pmpailis Sep 26, 2025
abe7f01
iter
pmpailis Sep 26, 2025
989235e
iter
pmpailis Sep 26, 2025
aeebc6e
iter
pmpailis Sep 26, 2025
240dd53
iter
pmpailis Sep 26, 2025
4a759ba
iter
pmpailis Sep 26, 2025
e386b62
Merge remote-tracking branch 'origin/main' into pb_134210_optimize_ve…
pmpailis Sep 26, 2025
0301c5b
iter
pmpailis Sep 29, 2025
97d1958
Merge remote-tracking branch 'origin/main' into pb_134210_optimize_ve…
pmpailis Sep 29, 2025
cafafcb
[CI] Auto commit changes from spotless
Sep 29, 2025
9f12dbe
iter
pmpailis Sep 29, 2025
3509bba
Merge branch 'pb_134210_optimize_vector_similarity_when_constant' of …
pmpailis Sep 29, 2025
1fc70a2
checkstyle
pmpailis Sep 29, 2025
7b5b760
[CI] Auto commit changes from spotless
Sep 29, 2025
9d51231
checkstyle
pmpailis Sep 29, 2025
504c92e
checkstyle
pmpailis Sep 29, 2025
59cdda9
iter
pmpailis Sep 30, 2025
6da8fb4
iter
pmpailis Sep 30, 2025
42dcf15
iter
pmpailis Sep 30, 2025
c336ede
Merge branch 'main' into pb_134210_optimize_vector_similarity_when_co…
pmpailis Oct 1, 2025
4640bb8
iter
pmpailis Oct 1, 2025
8d72e5c
adding test cases for toString when one vector is a literal
pmpailis Oct 1, 2025
c25f0bd
[CI] Auto commit changes from spotless
Oct 1, 2025
5d95e16
Merge branch 'main' into pb_134210_optimize_vector_similarity_when_co…
pmpailis Oct 1, 2025
ee019f2
addressing PR comments - adding common factory interface
pmpailis Oct 2, 2025
96c1a53
Merge branch 'pb_134210_optimize_vector_similarity_when_constant' of …
pmpailis Oct 2, 2025
9f47ab2
extracting method for generating the appropriate factory
pmpailis Oct 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ public enum DataType implements Writeable {
),

/**
* Fields with this type are dense vectors, represented as an array of double values.
* Fields with this type are dense vectors, represented as an array of float values.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😁 thanks for fixing!

*/
DENSE_VECTOR(
builder().esType("dense_vector")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.EsqlClientException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;

import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
Expand Down Expand Up @@ -78,10 +82,23 @@ public Object fold(FoldContext ctx) {
}

@Override
@SuppressWarnings("unchecked")
public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
VectorValueProvider.Builder leftVectorProviderBuilder = new VectorValueProvider.Builder();
VectorValueProvider.Builder rightVectorProviderBuilder = new VectorValueProvider.Builder();
if (left() instanceof Literal && left().dataType() == DENSE_VECTOR) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can count on the arguments being type dense_vector, it is already checked as part of checkDenseVectorParam() and would have been caught at the Analyzer level

leftVectorProviderBuilder.constantVector((ArrayList<Float>) ((Literal) left()).value());
} else {
leftVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(left()));
}
if (right() instanceof Literal && right().dataType() == DENSE_VECTOR) {
rightVectorProviderBuilder.constantVector((ArrayList<Float>) ((Literal) right()).value());
} else {
rightVectorProviderBuilder.expressionEvaluatorFactory(toEvaluator.apply(right()));
}
return new SimilarityEvaluatorFactory(
toEvaluator.apply(left()),
toEvaluator.apply(right()),
leftVectorProviderBuilder,
rightVectorProviderBuilder,
getSimilarityFunction(),
getClass().getSimpleName() + "Evaluator"
);
Expand All @@ -93,8 +110,8 @@ public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMappe
protected abstract SimilarityEvaluatorFunction getSimilarityFunction();

private record SimilarityEvaluatorFactory(
EvalOperator.ExpressionEvaluator.Factory left,
EvalOperator.ExpressionEvaluator.Factory right,
VectorValueProvider.Builder leftVectorProviderBuilder,
VectorValueProvider.Builder rightVectorProviderBuilder,
SimilarityEvaluatorFunction similarityFunction,
String evaluatorName
) implements EvalOperator.ExpressionEvaluator.Factory {
Expand All @@ -103,8 +120,8 @@ private record SimilarityEvaluatorFactory(
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
// TODO check whether to use this custom evaluator or reuse / define an existing one
return new SimilarityEvaluator(
left.get(context),
right.get(context),
leftVectorProviderBuilder.build(context),
rightVectorProviderBuilder.build(context),
similarityFunction,
evaluatorName,
context.blockFactory()
Expand All @@ -113,13 +130,150 @@ public EvalOperator.ExpressionEvaluator get(DriverContext context) {

@Override
public String toString() {
return evaluatorName() + "[left=" + left + ", right=" + right + "]";
return evaluatorName() + "[left=" + leftVectorProviderBuilder + ", right=" + rightVectorProviderBuilder + "]";
}
}

private static class VectorValueProvider implements Releasable {

private static final class Builder {
private ArrayList<Float> constantVector;
private EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory;

void constantVector(ArrayList<Float> constantVector) {
this.constantVector = constantVector;
}

void expressionEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory expressionEvaluatorFactory) {
this.expressionEvaluatorFactory = expressionEvaluatorFactory;
}

VectorValueProvider build(DriverContext context) {
if (false == (constantVector == null ^ expressionEvaluatorFactory == null)) {
throw new IllegalArgumentException(
"One of [constantVector] or [expressionEvaluatorFactory] must be set, but not both."
);
}
return new VectorValueProvider(
constantVector,
expressionEvaluatorFactory == null ? null : expressionEvaluatorFactory.get(context)
);
}

@Override
public String toString() {
return "constantVector="
+ (constantVector == null ? "[null]" : constantVector)
+ ", expressionEvaluator=["
+ expressionEvaluatorFactory
+ "]";
}
}

// Primitive copy of the literal vector handed to the constructor; null when no literal was supplied.
private final float[] constantVector;
// Evaluator that yields the vector block for a page; the Builder sets exactly one of this and the constant vector.
private final EvalOperator.ExpressionEvaluator expressionEvaluator;
// Block returned by the latest eval(page) call; closed and cleared by finishBlock().
private FloatBlock block;
// Buffer that getVector() fills from the block; cleared by finishBlock().
private float[] scratch;

VectorValueProvider(ArrayList<Float> constantVector, EvalOperator.ExpressionEvaluator expressionEvaluator) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea to use a class to wrap the different ways of retrieving vectors!

Would it make sense to create two separate subclasses, so we don't need to keep checking for what has been provided?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a nice suggestion and will make things easier to follow, thanks! Will refactor to have 2 subclasses.

if (constantVector != null) {
this.constantVector = new float[constantVector.size()];
for (int i = 0; i < constantVector.size(); i++) {
this.constantVector[i] = constantVector.get(i);
}
} else {
this.constantVector = null;
}
this.expressionEvaluator = expressionEvaluator;
}

/**
 * Evaluates the wrapped expression against {@code page} and keeps the resulting block for
 * subsequent {@code getVector} / {@code getDimensions} calls. Does nothing when this provider
 * has no expression evaluator.
 */
private void eval(Page page) {
    if (expressionEvaluator == null) {
        return;
    }
    block = (FloatBlock) expressionEvaluator.eval(page);
}

/**
 * Returns the vector for the given position.
 * <p>
 * For a constant provider the same array is returned for every position. For a block-backed
 * provider the values are copied into an internal buffer that is reused between calls, so
 * callers must not retain the returned array across invocations.
 *
 * @param position the position in the current block
 * @return the vector values, or {@code null} when the position is null or holds no values
 * @throws EsqlClientException when neither a constant vector nor an evaluated block is available
 */
private float[] getVector(int position) {
    if (constantVector != null) {
        return constantVector;
    }
    if (block == null) {
        throw new EsqlClientException(
            "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null."
        );
    }
    if (block.isNull(position)) {
        return null;
    }
    int dims = block.getValueCount(position);
    if (dims == 0) {
        return null;
    }
    // Size the buffer from THIS position's value count. Reusing a buffer sized by an earlier
    // position would read the wrong number of floats and hide a dimension mismatch from the caller.
    if (scratch == null || scratch.length != dims) {
        scratch = new float[dims];
    }
    readFloatArray(block, block.getFirstValueIndex(position), scratch);
    return scratch;
}

/**
 * Fills {@code scratch} with {@code scratch.length} consecutive floats read from {@code block},
 * starting at value index {@code firstValueIndex}.
 */
private static void readFloatArray(FloatBlock block, int firstValueIndex, float[] scratch) {
    int valueIndex = firstValueIndex;
    for (int slot = 0; slot < scratch.length; slot++, valueIndex++) {
        scratch[slot] = block.getFloat(valueIndex);
    }
}

/**
 * Returns the vector dimension: the constant vector's length, or the value count of the first
 * position in the current block that holds at least one value.
 *
 * @return the dimension, or {@code 0} when no position in the block holds any value
 * @throws EsqlClientException when neither a constant vector nor an evaluated block is available
 */
public int getDimensions() {
    if (constantVector != null) {
        return constantVector.length;
    }
    if (block == null) {
        throw new EsqlClientException(
            "[" + getClass().getSimpleName() + "] not properly initialized. Both [constantVector] and [block] are null."
        );
    }
    int position = 0;
    while (position < block.getPositionCount()) {
        int valueCount = block.getValueCount(position++);
        if (valueCount > 0) {
            return valueCount;
        }
    }
    return 0;
}

/**
 * Sums the shallow size of the constant vector (if any) and the base RAM usage reported by the
 * expression evaluator (if any).
 */
public long baseRamBytesUsed() {
    long vectorBytes = 0;
    if (constantVector != null) {
        vectorBytes = RamUsageEstimator.shallowSizeOf(constantVector);
    }
    long evaluatorBytes = 0;
    if (expressionEvaluator != null) {
        evaluatorBytes = expressionEvaluator.baseRamBytesUsed();
    }
    return vectorBytes + evaluatorBytes;
}

// Once we're done processing a block, close and dereference it and drop the scratch buffer so
// that both can be populated again for the next page through `eval`.
void finishBlock() {
    if (block == null) {
        return;
    }
    block.close();
    block = null;
    scratch = null;
}

/**
 * Releases the wrapped expression evaluator. The per-page block is not released here; that is
 * done by {@code finishBlock()}.
 */
@Override
public void close() {
    Releasables.close(expressionEvaluator);
}

/** Describes which of the two value sources this provider wraps. */
@Override
public String toString() {
    final String vectorText = constantVector == null ? "[null]" : Arrays.toString(constantVector);
    return new StringBuilder("constantVector=").append(vectorText)
        .append(", expressionEvaluator=[")
        .append(expressionEvaluator)
        .append("]")
        .toString();
}
}

private record SimilarityEvaluator(
EvalOperator.ExpressionEvaluator left,
EvalOperator.ExpressionEvaluator right,
VectorValueProvider left,
VectorValueProvider right,
SimilarityEvaluatorFunction similarityFunction,
String evaluatorName,
BlockFactory blockFactory
Expand All @@ -129,26 +283,23 @@ private record SimilarityEvaluator(

@Override
public Block eval(Page page) {
try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) {
try {
left.eval(page);
right.eval(page);

int dimensions = left.getDimensions();
int positionCount = page.getPositionCount();
int dimensions = 0;
// Get the first non-empty vector to calculate the dimension
for (int p = 0; p < positionCount; p++) {
if (leftBlock.getValueCount(p) != 0) {
dimensions = leftBlock.getValueCount(p);
break;
}
}
if (dimensions == 0) {
return blockFactory.newConstantNullBlock(positionCount);
}

float[] leftScratch = new float[dimensions];
float[] rightScratch = new float[dimensions];
try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount * dimensions)) {
try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(positionCount)) {
for (int p = 0; p < positionCount; p++) {
int dimsLeft = leftBlock.getValueCount(p);
int dimsRight = rightBlock.getValueCount(p);
float[] leftVector = left.getVector(p);
float[] rightVector = right.getVector(p);

int dimsLeft = leftVector == null ? 0 : leftVector.length;
int dimsRight = rightVector == null ? 0 : rightVector.length;

if (dimsLeft == 0 || dimsRight == 0) {
// A null value on the left or right vector. Similarity is null
Expand All @@ -161,19 +312,14 @@ public Block eval(Page page) {
dimsRight
);
}
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
float result = similarityFunction.calculateSimilarity(leftVector, rightVector);
builder.appendDouble(result);
}
return builder.build();
}
}
}

private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
for (int i = 0; i < dimensions; i++) {
scratch[i] = block.getFloat(position + i);
} finally {
left.finishBlock();
right.finishBlock();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ protected static Iterable<Object[]> similarityParameters(
VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction
) {

final String evaluatorName = className + "Evaluator" + "[left=Attribute[channel=0], right=Attribute[channel=1]]";
final String evaluatorName = className
+ "Evaluator[left=constantVector=[null], expressionEvaluator=[Attribute[channel=0]], "
+ "right=constantVector=[null], expressionEvaluator=[Attribute[channel=1]]]";

List<TestCaseSupplier> suppliers = new ArrayList<>();

Expand Down