Skip to content

Commit 3ab9a72

Browse files
committed
Add checks for different number of dimensions
1 parent 7517557 commit 3ab9a72

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.common.settings.Settings;
1717
import org.elasticsearch.xcontent.XContentBuilder;
1818
import org.elasticsearch.xcontent.XContentFactory;
19+
import org.elasticsearch.xpack.esql.EsqlClientException;
1920
import org.elasticsearch.xpack.esql.EsqlTestUtils;
2021
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
2122
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
@@ -96,6 +97,18 @@ public void testSimilarityBetweenConstantVectorAndField() {
9697
}
9798
}
9899

100+
public void testDifferentDimensions() {
101+
var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
102+
var query = String.format(Locale.ROOT, """
103+
FROM test
104+
| EVAL similarity = %s(left_vector, %s)
105+
| KEEP left_vector, similarity
106+
""", functionName, Arrays.toString(randomVector));
107+
108+
EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); });
109+
assertTrue(iae.getMessage().contains("Vectors must have the same dimensions"));
110+
}
111+
99112
@SuppressWarnings("unchecked")
100113
public void testSimilarityBetweenConstantVectors() {
101114
var vectorLeft = randomVectorArray();
@@ -155,8 +168,12 @@ private List<Float> randomVector() {
155168

156169
private float[] randomVectorArray() {
157170
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
158-
float[] vector = new float[numDims];
159-
for (int j = 0; j < numDims; j++) {
171+
return randomVectorArray(numDims);
172+
}
173+
174+
private static float[] randomVectorArray(int dimensions) {
175+
float[] vector = new float[dimensions];
176+
for (int j = 0; j < dimensions; j++) {
160177
vector[j] = randomFloat();
161178
}
162179
return vector;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.compute.data.Page;
1515
import org.elasticsearch.compute.operator.DriverContext;
1616
import org.elasticsearch.compute.operator.EvalOperator;
17+
import org.elasticsearch.xpack.esql.EsqlClientException;
1718
import org.elasticsearch.xpack.esql.core.expression.Expression;
1819
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
1920
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -124,8 +125,13 @@ public Block eval(Page page) {
124125

125126
int dimensions = leftBlock.getValueCount(0);
126127
int dimsRight = rightBlock.getValueCount(0);
127-
assert dimensions == dimsRight
128-
: "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight;
128+
if (dimensions != dimsRight) {
129+
throw new EsqlClientException(
130+
"Vectors must have the same dimensions; first vector has {}, and second has {}",
131+
dimensions,
132+
dimsRight
133+
);
134+
}
129135
float[] leftScratch = new float[dimensions];
130136
float[] rightScratch = new float[dimensions];
131137
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {

0 commit comments

Comments
 (0)