Skip to content

Commit 7383dc9

Browse files
committed
Add tests for casting using ToDenseVector
1 parent 8889ab5 commit 7383dc9

File tree

10 files changed

+267
-44
lines changed

10 files changed

+267
-44
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# TODO Most tests explicitly set k. Until knn function uses LIMIT as k, we need to explicitly set it to all values
2-
# in the dataset to avoid test failures due to docs allocation in different shards, which can impact results for a
3-
# top-n query at the shard level
4-
51
knnSearch
62
required_capability: knn_function_v3
73

@@ -306,3 +302,28 @@ c: long | primary: boolean
306302
41 | false
307303
9 | true
308304
;
305+
306+
knnWithCasting
307+
required_capability: knn_function_v3
308+
required_capability: to_dense_vector_function
309+
310+
from colors metadata _score
311+
| eval query = [0, 120, 0]
312+
| where knn(rgb_vector, query, 10)
313+
| sort _score desc, color asc
314+
| keep color, rgb_vector
315+
| limit 10
316+
;
317+
318+
color:text | rgb_vector:dense_vector
319+
green | [0.0, 128.0, 0.0]
320+
black | [0.0, 0.0, 0.0]
321+
olive | [128.0, 128.0, 0.0]
322+
teal | [0.0, 128.0, 128.0]
323+
lime | [0.0, 255.0, 0.0]
324+
sienna | [160.0, 82.0, 45.0]
325+
maroon | [128.0, 0.0, 0.0]
326+
navy | [0.0, 0.0, 128.0]
327+
gray | [128.0, 128.0, 128.0]
328+
chartreuse | [127.0, 255.0, 0.0]
329+
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,40 @@ total_null:long
9090
59
9191
;
9292

93-
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
94-
similarityWithRow-Ignore
93+
similarityWithRow
9594
required_capability: cosine_vector_similarity_function
95+
required_capability: to_dense_vector_function
9696

9797
row vector = [1, 2, 3]
9898
| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
99+
;
100+
101+
vector: integer | similarity:double
102+
[1, 2, 3] | 0.978
103+
;
104+
105+
similarityWithVectorField
106+
required_capability: cosine_vector_similarity_function
107+
required_capability: to_dense_vector_function
108+
109+
from colors
110+
| where color != "black"
111+
| eval query = [0, 255, 255]
112+
| eval similarity = v_cosine(rgb_vector, query)
99113
| sort similarity desc, color asc
100114
| limit 10
101115
| keep color, similarity
102116
;
103117

104-
similarity:double
105-
0.978
118+
color:text | similarity:double
119+
cyan | 1.0
120+
teal | 1.0
121+
turquoise | 0.9890533685684204
122+
aqua marine | 0.964962363243103
123+
azure | 0.916246771812439
124+
lavender | 0.9136701822280884
125+
mint cream | 0.9122757911682129
126+
honeydew | 0.9122424125671387
127+
gainsboro | 0.9082483053207397
128+
gray | 0.9082483053207397
106129
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,39 @@ total_null:long
8888
;
8989

9090

91-
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
92-
similarityWithRow-Ignore
91+
similarityWithRow
9392
required_capability: dot_product_vector_similarity_function
93+
required_capability: to_dense_vector_function
9494

9595
row vector = [1, 2, 3]
9696
| eval similarity = round(v_dot_product(vector, [0, 1, 2]), 3)
97+
;
98+
99+
vector: integer | similarity:double
100+
[1, 2, 3] | 4.5
101+
;
102+
103+
similarityWithVectorField
104+
required_capability: dot_product_vector_similarity_function
105+
required_capability: to_dense_vector_function
106+
107+
from colors
108+
| eval query = [0, 255, 255]
109+
| eval similarity = v_dot_product(rgb_vector, query)
97110
| sort similarity desc, color asc
98111
| limit 10
99112
| keep color, similarity
100113
;
101114

102-
similarity:double
103-
0.978
115+
color:text | similarity:double
116+
azure | 65025.5
117+
cyan | 65025.5
118+
white | 65025.5
119+
mint cream | 64388.0
120+
snow | 63750.5
121+
honeydew | 63113.0
122+
ivory | 63113.0
123+
sea shell | 61583.0
124+
lavender | 61200.5
125+
old lace | 60563.0
104126
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-hamming.csv-spec

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,39 @@ total_null:long
8787
59
8888
;
8989

90-
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
91-
similarityWithRow-Ignore
90+
similarityWithRow
9291
required_capability: hamming_vector_similarity_function
92+
required_capability: to_dense_vector_function
9393

9494
row vector = [1, 2, 3]
9595
| eval similarity = round(v_hamming(vector, [0, 1, 2]), 3)
96+
;
97+
98+
vector: integer | similarity:double
99+
[1, 2, 3] | 4.0
100+
;
101+
102+
similarityWithVectorField
103+
required_capability: hamming_vector_similarity_function
104+
required_capability: to_dense_vector_function
105+
106+
from colors
107+
| eval query = [0, 255, 255]
108+
| eval similarity = v_hamming(rgb_vector, query)
96109
| sort similarity desc, color asc
97110
| limit 10
98111
| keep color, similarity
99112
;
100-
101-
similarity:double
102-
0.978
113+
114+
color:text | similarity:double
115+
red | 24.0
116+
orange | 20.0
117+
gold | 18.0
118+
indigo | 18.0
119+
bisque | 17.0
120+
maroon | 17.0
121+
pink | 17.0
122+
salmon | 17.0
123+
black | 16.0
124+
firebrick | 16.0
103125
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,39 @@ total_null:long
8787
59
8888
;
8989

90-
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
91-
similarityWithRow-Ignore
90+
similarityWithRow
9291
required_capability: l1_norm_vector_similarity_function
92+
required_capability: to_dense_vector_function
9393

9494
row vector = [1, 2, 3]
9595
| eval similarity = round(v_l1_norm(vector, [0, 1, 2]), 3)
96+
;
97+
98+
vector: integer | similarity:double
99+
[1, 2, 3] | 3.0
100+
;
101+
102+
similarityWithVectorField
103+
required_capability: l1_norm_vector_similarity_function
104+
required_capability: to_dense_vector_function
105+
106+
from colors
107+
| eval query = [0, 255, 255]
108+
| eval similarity = v_l1_norm(rgb_vector, query)
96109
| sort similarity desc, color asc
97110
| limit 10
98111
| keep color, similarity
99112
;
100-
101-
similarity:double
102-
0.978
113+
114+
color:text | similarity:double
115+
red | 765.0
116+
crimson | 650.0
117+
maroon | 638.0
118+
firebrick | 620.0
119+
orange | 600.0
120+
tomato | 595.0
121+
brown | 591.0
122+
chocolate | 585.0
123+
coral | 558.0
124+
gold | 550.0
103125
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,39 @@ total_null:long
8787
59
8888
;
8989

90-
# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
91-
similarityWithRow-Ignore
90+
similarityWithRow
9291
required_capability: l2_norm_vector_similarity_function
92+
required_capability: to_dense_vector_function
9393

9494
row vector = [1, 2, 3]
9595
| eval similarity = round(v_l2_norm(vector, [0, 1, 2]), 3)
96+
;
97+
98+
vector: integer | similarity:double
99+
[1, 2, 3] | 1.732
100+
;
101+
102+
similarityWithVectorField
103+
required_capability: l2_norm_vector_similarity_function
104+
required_capability: to_dense_vector_function
105+
106+
from colors
107+
| eval query = [0, 255, 255]
108+
| eval similarity = v_l2_norm(rgb_vector, query)
96109
| sort similarity desc, color asc
97110
| limit 10
98111
| keep color, similarity
99112
;
100-
101-
similarity:double
102-
0.978
113+
114+
color:text | similarity:double
115+
red | 441.6729431152344
116+
maroon | 382.6669616699219
117+
crimson | 376.36419677734375
118+
orange | 371.68536376953125
119+
gold | 362.8360595703125
120+
black | 360.62445068359375
121+
magenta | 360.62445068359375
122+
yellow | 360.62445068359375
123+
firebrick | 359.67486572265625
124+
tomato | 351.0227966308594
103125
;

x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,15 @@ row a = 1
8585
magnitude:double
8686
null
8787
;
88+
89+
magnitudeWithRow
90+
required_capability: magnitude_scalar_vector_function
91+
required_capability: to_dense_vector_function
92+
93+
row vector = [1, 2, 3]
94+
| eval magnitude = round(v_magnitude(vector), 3)
95+
;
96+
97+
vector: integer | magnitude:double
98+
[1, 2, 3] | 3.742
99+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1416,7 +1416,12 @@ public enum Cap {
14161416
/**
14171417
* URL encoding function.
14181418
*/
1419-
URL_ENCODE(Build.current().isSnapshot());
1419+
URL_ENCODE(Build.current().isSnapshot()),
1420+
1421+
/**
1422+
* TO_DENSE_VECTOR function.
1423+
*/
1424+
TO_DENSE_VECTOR_FUNCTION(Build.current().isSnapshot());
14201425

14211426
private final boolean enabled;
14221427

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
7777
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble;
7878
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos;
79+
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector;
7980
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
8081
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
8182
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
@@ -168,6 +169,7 @@
168169
import static org.elasticsearch.xpack.esql.core.type.DataType.TIME_DURATION;
169170
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
170171
import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION;
172+
import static org.elasticsearch.xpack.esql.core.type.DataType.isString;
171173
import static org.elasticsearch.xpack.esql.core.type.DataType.isTemporalAmount;
172174
import static org.elasticsearch.xpack.esql.telemetry.FeatureMetric.LIMIT;
173175
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.maybeParseTemporalAmount;
@@ -1668,18 +1670,24 @@ private static Expression processVectorFunction(org.elasticsearch.xpack.esql.cor
16681670
List<Expression> args = vectorFunction.arguments();
16691671
List<Expression> newArgs = new ArrayList<>();
16701672
for (Expression arg : args) {
1671-
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
1672-
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
1673-
if (folded instanceof List) {
1674-
// Convert to floats so blocks are created accordingly
1675-
List<Float> floatVector;
1676-
if (arg.dataType() == FLOAT) {
1677-
floatVector = (List<Float>) folded;
1678-
} else {
1679-
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
1673+
if (arg.resolved()) {
1674+
if (arg.foldable() && arg.dataType().isNumeric()) {
1675+
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
1676+
if (folded instanceof List) {
1677+
// Convert to floats so blocks are created accordingly
1678+
List<Float> floatVector;
1679+
if (arg.dataType() == FLOAT) {
1680+
floatVector = (List<Float>) folded;
1681+
} else {
1682+
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
1683+
}
1684+
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
1685+
newArgs.add(denseVector);
1686+
continue;
16801687
}
1681-
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
1682-
newArgs.add(denseVector);
1688+
} else if (arg.dataType().isNumeric() || isString(arg.dataType())) {
1689+
// Not a foldable numeric argument: wrap numeric or string arguments in ToDenseVector so they are cast to dense_vector
1690+
newArgs.add(new ToDenseVector(arg.source(), arg));
16831691
continue;
16841692
}
16851693
}

0 commit comments

Comments
 (0)