Skip to content

Commit fc88621

Browse files
committed
Allow using literals for dense_vectors, converting them internally to FloatBlocks
1 parent 58bd1c0 commit fc88621

File tree

5 files changed

+97
-32
lines changed

5 files changed

+97
-32
lines changed

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/BlockUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ public static void appendValue(Block.Builder builder, Object val, ElementType ty
213213
case LONG -> ((LongBlock.Builder) builder).appendLong((Long) val);
214214
case INT -> ((IntBlock.Builder) builder).appendInt((Integer) val);
215215
case BYTES_REF -> ((BytesRefBlock.Builder) builder).appendBytesRef(toBytesRef(val));
216-
case FLOAT -> ((FloatBlock.Builder) builder).appendFloat((Float) val);
216+
case FLOAT -> ((FloatBlock.Builder) builder).appendFloat(((Number) val).floatValue());
217217
case DOUBLE -> ((DoubleBlock.Builder) builder).appendDouble((Double) val);
218218
case BOOLEAN -> ((BooleanBlock.Builder) builder).appendBoolean((Boolean) val);
219219
default -> throw new UnsupportedOperationException("unsupported element type [" + type + "]");

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

Lines changed: 86 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
package org.elasticsearch.xpack.esql.vector;
99

10+
import com.carrotsearch.randomizedtesting.annotations.Name;
11+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
12+
1013
import org.apache.lucene.index.VectorSimilarityFunction;
1114
import org.elasticsearch.action.index.IndexRequestBuilder;
1215
import org.elasticsearch.cluster.metadata.IndexMetadata;
@@ -20,67 +23,124 @@
2023

2124
import java.io.IOException;
2225
import java.util.ArrayList;
26+
import java.util.Arrays;
2327
import java.util.List;
28+
import java.util.Locale;
2429

2530
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
2631

2732
public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
2833

34+
@ParametersFactory
35+
public static Iterable<Object[]> parameters() throws Exception {
36+
List<Object[]> params = new ArrayList<>();
37+
38+
params.add(new Object[] { "v_cosine_similarity", VectorSimilarityFunction.COSINE });
39+
40+
return params;
41+
}
42+
43+
private final String functionName;
44+
private final VectorSimilarityFunction similarityFunction;
45+
private int numDims;
46+
47+
public VectorSimilarityFunctionsIT(
48+
@Name("functionName") String functionName,
49+
@Name("similarityFunction") VectorSimilarityFunction similarityFunction
50+
) {
51+
this.functionName = functionName;
52+
this.similarityFunction = similarityFunction;
53+
}
54+
2955
@SuppressWarnings("unchecked")
30-
public void testCosineSimilarity() {
31-
var query = """
56+
public void testCosineSimilarityBetweenVectors() {
57+
var query = String.format(Locale.ROOT, """
3258
FROM test
33-
| EVAL similarity = v_cosine_similarity(left_vector, right_vector)
34-
| KEEP id, left_vector, right_vector, similarity
35-
""";
59+
| EVAL similarity = %s(left_vector, right_vector)
60+
| KEEP left_vector, right_vector, similarity
61+
""", functionName);
3662

3763
try (var resp = run(query)) {
3864
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
3965
valuesList.forEach(values -> {
40-
List<Float> leftVector = (List<Float>) values.get(1);
41-
float[] leftScratch = new float[leftVector.size()];
42-
for (int i = 0; i < leftVector.size(); i++) {
43-
leftScratch[i] = leftVector.get(i);
44-
}
45-
List<Float> rightVector = (List<Float>) values.get(2);
46-
float[] rightScratch = new float[rightVector.size()];
47-
for (int i = 0; i < rightVector.size(); i++) {
48-
rightScratch[i] = rightVector.get(i);
49-
}
50-
Double similarity = (Double) values.get(3);
66+
float[] left = readVector((List<Float>) values.get(0));
67+
float[] right = readVector((List<Float>) values.get(1));
68+
Double similarity = (Double) values.get(2);
69+
5170
assertNotNull(similarity);
71+
float expectedSimilarity = similarityFunction.compare(left, right);
72+
assertEquals(expectedSimilarity, similarity, 0.0001);
73+
});
74+
}
75+
}
76+
77+
@SuppressWarnings("unchecked")
78+
public void testCosineSimilarityBetweenConstantVectorAndField() {
79+
var randomVector = randomVectorArray();
80+
var query = String.format(Locale.ROOT, """
81+
FROM test
82+
| EVAL similarity = %s(left_vector, %s)
83+
| KEEP left_vector, similarity
84+
""", functionName, Arrays.toString(randomVector));
85+
86+
try (var resp = run(query)) {
87+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
88+
valuesList.forEach(values -> {
89+
float[] left = readVector((List<Float>) values.get(0));
90+
Double similarity = (Double) values.get(1);
5291

53-
float expectedSimilarity = VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
92+
assertNotNull(similarity);
93+
float expectedSimilarity = similarityFunction.compare(left, randomVector);
5494
assertEquals(expectedSimilarity, similarity, 0.0001);
5595
});
5696
}
5797
}
5898

99+
private static float[] readVector(List<Float> leftVector) {
100+
float[] leftScratch = new float[leftVector.size()];
101+
for (int i = 0; i < leftVector.size(); i++) {
102+
leftScratch[i] = leftVector.get(i);
103+
}
104+
return leftScratch;
105+
}
106+
59107
@Before
60108
public void setup() throws IOException {
61109
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
62110

63111
createIndexWithDenseVector("test");
64112

65-
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
113+
numDims = randomIntBetween(32, 64) * 2; // min 64, even number
66114
int numDocs = randomIntBetween(10, 100);
67115
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
68116
for (int i = 0; i < numDocs; i++) {
69-
List<Float> leftVector = new ArrayList<>(numDims);
70-
for (int j = 0; j < numDims; j++) {
71-
leftVector.add(randomFloat());
72-
}
73-
List<Float> rightVector = new ArrayList<>(numDims);
74-
for (int j = 0; j < numDims; j++) {
75-
rightVector.add(randomFloat());
76-
}
117+
List<Float> leftVector = randomVector();
118+
List<Float> rightVector = randomVector();
77119
docs[i] = prepareIndex("test").setId("" + i)
78120
.setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
79121
}
80122

81123
indexRandom(true, docs);
82124
}
83125

126+
private List<Float> randomVector() {
127+
assert numDims != 0 : "numDims must be set before calling randomVector()";
128+
List<Float> vector = new ArrayList<>(numDims);
129+
for (int j = 0; j < numDims; j++) {
130+
vector.add(randomFloat());
131+
}
132+
return vector;
133+
}
134+
135+
private float[] randomVectorArray() {
136+
assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
137+
float[] vector = new float[numDims];
138+
for (int j = 0; j < numDims; j++) {
139+
vector[j] = randomFloat();
140+
}
141+
return vector;
142+
}
143+
84144
private void createIndexWithDenseVector(String indexName) throws IOException {
85145
var client = client().admin().indices();
86146
XContentBuilder mapping = XContentFactory.jsonBuilder()

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,15 +1399,15 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func
13991399
if (f instanceof In in) {
14001400
return processIn(in);
14011401
}
1402+
if (f instanceof VectorFunction) {
1403+
return processVectorFunction(f);
1404+
}
14021405
if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
14031406
return processScalarOrGroupingFunction(f, registry);
14041407
}
14051408
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
14061409
return processBinaryOperator((BinaryOperator) f);
14071410
}
1408-
if (f instanceof VectorFunction vectorFunction) {
1409-
return processVectorFunction(f);
1410-
}
14111411
return f;
14121412
}
14131413

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.xpack.esql.core.expression.Expression;
2525
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
2626
import org.elasticsearch.xpack.esql.core.expression.Literal;
27+
import org.elasticsearch.xpack.esql.core.type.DataType;
2728
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
2829
import org.elasticsearch.xpack.esql.evaluator.mapper.ExpressionMapper;
2930
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
@@ -248,7 +249,11 @@ private static Block block(Literal lit, BlockFactory blockFactory, int positions
248249
if (multiValue.isEmpty()) {
249250
return blockFactory.newConstantNullBlock(positions);
250251
}
251-
var wrapper = BlockUtils.wrapperFor(blockFactory, ElementType.fromJava(multiValue.get(0).getClass()), positions);
252+
// dense_vector create internally float values, even if they are specified as doubles
253+
ElementType elementType = lit.dataType() == DataType.DENSE_VECTOR
254+
? ElementType.FLOAT
255+
: ElementType.fromJava(multiValue.get(0).getClass());
256+
var wrapper = BlockUtils.wrapperFor(blockFactory, elementType, positions);
252257
for (int i = 0; i < positions; i++) {
253258
wrapper.accept(multiValue);
254259
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
2424

25-
public class CosineSimilarity extends org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction {
25+
public class CosineSimilarity extends VectorSimilarityFunction {
2626

2727
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
2828
Expression.class,

0 commit comments

Comments
 (0)