diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/to_dense_vector.md b/docs/reference/query-languages/esql/_snippets/functions/description/to_dense_vector.md new file mode 100644 index 0000000000000..15bc7760b7803 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/to_dense_vector.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Converts a multi-valued input of numbers, or a hexadecimal string, to a dense_vector. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/to_dense_vector.md b/docs/reference/query-languages/esql/_snippets/functions/examples/to_dense_vector.md new file mode 100644 index 0000000000000..f202aeeff6dc9 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/to_dense_vector.md @@ -0,0 +1,15 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Example** + +```esql +row ints = [1, 2, 3] +| eval vector = to_dense_vector(ints) +| keep vector +``` + +| vector:dense_vector | +| --- | +| [1.0, 2.0, 3.0] | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/to_dense_vector.md b/docs/reference/query-languages/esql/_snippets/functions/layout/to_dense_vector.md new file mode 100644 index 0000000000000..a5eaef0deed19 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/to_dense_vector.md @@ -0,0 +1,23 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. 
+ +## `TO_DENSE_VECTOR` [esql-to_dense_vector] + +**Syntax** + +:::{image} ../../../images/functions/to_dense_vector.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/to_dense_vector.md +::: + +:::{include} ../description/to_dense_vector.md +::: + +:::{include} ../types/to_dense_vector.md +::: + +:::{include} ../examples/to_dense_vector.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/to_dense_vector.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/to_dense_vector.md new file mode 100644 index 0000000000000..f68b97a694bf9 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/to_dense_vector.md @@ -0,0 +1,7 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`field` +: multi-valued input of numbers or hexadecimal string to convert. + diff --git a/docs/reference/query-languages/esql/images/functions/to_dense_vector.svg b/docs/reference/query-languages/esql/images/functions/to_dense_vector.svg new file mode 100644 index 0000000000000..54304ee44b11f --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/to_dense_vector.svg @@ -0,0 +1 @@ +TO_DENSE_VECTOR(field) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/to_dense_vector.json b/docs/reference/query-languages/esql/kibana/definition/functions/to_dense_vector.json new file mode 100644 index 0000000000000..932937bf10c6c --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/to_dense_vector.json @@ -0,0 +1,12 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. 
See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "to_dense_vector", + "description" : "Converts a multi-valued input of numbers, or a hexadecimal string, to a dense_vector.", + "signatures" : [ ], + "examples" : [ + "row ints = [1, 2, 3]\n| eval vector = to_dense_vector(ints)\n| keep vector" + ], + "preview" : false, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/to_dense_vector.md b/docs/reference/query-languages/esql/kibana/docs/functions/to_dense_vector.md new file mode 100644 index 0000000000000..309d975be8bfc --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/to_dense_vector.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### TO DENSE VECTOR +Converts a multi-valued input of numbers, or a hexadecimal string, to a dense_vector. + +```esql +row ints = [1, 2, 3] +| eval vector = to_dense_vector(ints) +| keep vector +``` diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec index eed0328da6060..c8a24d84ce72a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector.csv-spec @@ -45,3 +45,59 @@ id:l | new_vector:dense_vector 2 | [9.0, 8.0, 7.0] 3 | [0.054, 0.032, 0.012] ; + +convertIntsToDenseVector +required_capability: dense_vector_field_type +required_capability: to_dense_vector_function + +// tag::to_dense_vector-ints[] +row ints = [1, 2, 3] +| eval vector = to_dense_vector(ints) +| keep vector +// end::to_dense_vector-ints[] +; + +// tag::to_dense_vector-ints-result[] +vector:dense_vector +[1.0, 2.0, 3.0] +// end::to_dense_vector-ints-result[] +; + +convertLongsToDenseVector +required_capability: dense_vector_field_type +required_capability: 
to_dense_vector_function + +row longs = [5013792, 2147483647, 501379200000] +| eval vector = to_dense_vector(longs) +| keep vector +; + +vector:dense_vector +[5013792.0, 2147483647.0, 501379200000.0] +; + +convertDoublesToDenseVector +required_capability: dense_vector_field_type +required_capability: to_dense_vector_function + +row doubles = [123.4, 567.8, 901.2] +| eval vector = to_dense_vector(doubles) +| keep vector +; + +vector:dense_vector +[123.4, 567.8, 901.2] +; + +convertHexStringToDenseVector +required_capability: dense_vector_field_type +required_capability: to_dense_vector_function + +row hex_str = "0102030405060708090a0b0c0d0e0f" +| eval vector = to_dense_vector(hex_str) +| keep vector +; + +vector:dense_vector + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index e65d65f414cd1..eadee9266f307 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -1,7 +1,3 @@ -# TODO Most tests explicitly set k. 
Until knn function uses LIMIT as k, we need to explicitly set it to all values -# in the dataset to avoid test failures due to docs allocation in different shards, which can impact results for a -# top-n query at the shard level - knnSearch required_capability: knn_function_v5 @@ -410,3 +406,52 @@ host:keyword | semantic_text_dense_field:text "host1" | live long and prosper ; + +knnWithCasting +required_capability: knn_function_v5 +required_capability: to_dense_vector_function + +from colors metadata _score +| eval query = to_dense_vector([0, 120, 0]) +| where knn(rgb_vector, query) +| sort _score desc, color asc +| keep color, rgb_vector +| limit 10 +; + +color:text | rgb_vector:dense_vector +green | [0.0, 128.0, 0.0] +black | [0.0, 0.0, 0.0] +olive | [128.0, 128.0, 0.0] +teal | [0.0, 128.0, 128.0] +lime | [0.0, 255.0, 0.0] +sienna | [160.0, 82.0, 45.0] +maroon | [128.0, 0.0, 0.0] +navy | [0.0, 0.0, 128.0] +gray | [128.0, 128.0, 128.0] +chartreuse | [127.0, 255.0, 0.0] +; + +knnWithHexStringCasting +required_capability: knn_function_v5 +required_capability: to_dense_vector_function + +from colors metadata _score +| where knn(rgb_vector, "007800") +| sort _score desc, color asc +| keep color, rgb_vector +| limit 10 +; + +color:text | rgb_vector:dense_vector +green | [0.0, 128.0, 0.0] +black | [0.0, 0.0, 0.0] +olive | [128.0, 128.0, 0.0] +teal | [0.0, 128.0, 128.0] +lime | [0.0, 255.0, 0.0] +sienna | [160.0, 82.0, 45.0] +maroon | [128.0, 0.0, 0.0] +navy | [0.0, 0.0, 128.0] +gray | [128.0, 128.0, 128.0] +chartreuse | [127.0, 255.0, 0.0] +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec index 46d80609a06bf..451368deec934 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec @@ -90,17 
+90,40 @@ total_null:long 59 ; -# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector -similarityWithRow-Ignore +similarityWithRow required_capability: cosine_vector_similarity_function +required_capability: to_dense_vector_function -row vector = [1, 2, 3] +row vector = to_dense_vector([1, 2, 3]) | eval similarity = round(v_cosine(vector, [0, 1, 2]), 3) +; + +vector: dense_vector | similarity:double +[1.0, 2.0, 3.0] | 0.978 +; + +similarityWithVectorField +required_capability: cosine_vector_similarity_function +required_capability: to_dense_vector_function + +from colors +| where color != "black" +| eval query = to_dense_vector([0, 255, 255]) +| eval similarity = v_cosine(rgb_vector, query) | sort similarity desc, color asc | limit 10 | keep color, similarity ; -similarity:double -0.978 +color:text | similarity:double +cyan | 1.0 +teal | 1.0 +turquoise | 0.9890533685684204 +aqua marine | 0.964962363243103 +azure | 0.916246771812439 +lavender | 0.9136701822280884 +mint cream | 0.9122757911682129 +honeydew | 0.9122424125671387 +gainsboro | 0.9082483053207397 +gray | 0.9082483053207397 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec index b6d32b5ae651b..3297ae84db5ff 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-dot-product.csv-spec @@ -88,17 +88,39 @@ total_null:long ; -# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector -similarityWithRow-Ignore +similarityWithRow required_capability: dot_product_vector_similarity_function +required_capability: to_dense_vector_function -row vector = [1, 2, 3] +row vector = to_dense_vector([1, 2, 3]) | eval similarity = round(v_dot_product(vector, [0, 1, 2]), 3) +; + +vector: dense_vector | 
similarity:double +[1.0, 2.0, 3.0] | 4.5 +; + +similarityWithVectorField +required_capability: dot_product_vector_similarity_function +required_capability: to_dense_vector_function + +from colors +| eval query = to_dense_vector([0, 255, 255]) +| eval similarity = v_dot_product(rgb_vector, query) | sort similarity desc, color asc | limit 10 | keep color, similarity ; -similarity:double -0.978 +color:text | similarity:double +azure | 65025.5 +cyan | 65025.5 +white | 65025.5 +mint cream | 64388.0 +snow | 63750.5 +honeydew | 63113.0 +ivory | 63113.0 +sea shell | 61583.0 +lavender | 61200.5 +old lace | 60563.0 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-hamming.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-hamming.csv-spec index a7e8815139567..37630c94e62e0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-hamming.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-hamming.csv-spec @@ -87,17 +87,39 @@ total_null:long 59 ; -# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector -similarityWithRow-Ignore +similarityWithRow required_capability: hamming_vector_similarity_function +required_capability: to_dense_vector_function -row vector = [1, 2, 3] +row vector = to_dense_vector([1, 2, 3]) | eval similarity = round(v_hamming(vector, [0, 1, 2]), 3) +; + +vector: dense_vector | similarity:double +[1.0, 2.0, 3.0] | 4.0 +; + +similarityWithVectorField +required_capability: hamming_vector_similarity_function +required_capability: to_dense_vector_function + +from colors +| eval query = to_dense_vector([0, 255, 255]) +| eval similarity = v_hamming(rgb_vector, query) | sort similarity desc, color asc | limit 10 | keep color, similarity ; - -similarity:double -0.978 + +color:text | similarity:double +red | 24.0 +orange | 20.0 +gold | 18.0 +indigo | 18.0 +bisque | 17.0 +maroon | 17.0 +pink | 17.0 +salmon | 17.0 +black | 16.0 +firebrick | 
16.0 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec index 53f550dd4fe1f..148d9d0da85a9 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec @@ -87,17 +87,39 @@ total_null:long 59 ; -# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector -similarityWithRow-Ignore +similarityWithRow required_capability: l1_norm_vector_similarity_function +required_capability: to_dense_vector_function -row vector = [1, 2, 3] +row vector = to_dense_vector([1, 2, 3]) | eval similarity = round(v_l1_norm(vector, [0, 1, 2]), 3) +; + +vector: dense_vector | similarity:double +[1.0, 2.0, 3.0] | 3.0 +; + +similarityWithVectorField +required_capability: l1_norm_vector_similarity_function +required_capability: to_dense_vector_function + +from colors +| eval query = to_dense_vector([0, 255, 255]) +| eval similarity = v_l1_norm(rgb_vector, query) | sort similarity desc, color asc | limit 10 | keep color, similarity ; - -similarity:double -0.978 + +color:text | similarity:double +red | 765.0 +crimson | 650.0 +maroon | 638.0 +firebrick | 620.0 +orange | 600.0 +tomato | 595.0 +brown | 591.0 +chocolate | 585.0 +coral | 558.0 +gold | 550.0 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec index 03a094ed93cad..d150c65e3b2fa 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec @@ -87,17 +87,39 @@ total_null:long 59 ; -# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector -similarityWithRow-Ignore +similarityWithRow required_capability: 
l2_norm_vector_similarity_function +required_capability: to_dense_vector_function -row vector = [1, 2, 3] +row vector = to_dense_vector([1, 2, 3]) | eval similarity = round(v_l2_norm(vector, [0, 1, 2]), 3) +; + +vector: dense_vector | similarity:double +[1.0, 2.0, 3.0] | 1.732 +; + +similarityWithVectorField +required_capability: l2_norm_vector_similarity_function +required_capability: to_dense_vector_function + +from colors +| eval query = to_dense_vector([0, 255, 255]) +| eval similarity = v_l2_norm(rgb_vector, query) | sort similarity desc, color asc | limit 10 | keep color, similarity ; - -similarity:double -0.978 + +color:text | similarity:double +red | 441.6729431152344 +maroon | 382.6669616699219 +crimson | 376.36419677734375 +orange | 371.68536376953125 +gold | 362.8360595703125 +black | 360.62445068359375 +magenta | 360.62445068359375 +yellow | 360.62445068359375 +firebrick | 359.67486572265625 +tomato | 351.0227966308594 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec index c670cb9ec678e..bb6d39735d8e4 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-magnitude.csv-spec @@ -85,3 +85,15 @@ row a = 1 magnitude:double null ; + +magnitudeWithRow +required_capability: magnitude_scalar_vector_function +required_capability: to_dense_vector_function + +row vector = to_dense_vector([1, 2, 3]) +| eval magnitude = round(v_magnitude(vector), 3) +; + +vector: dense_vector | magnitude:double +[1.0, 2.0, 3.0] | 3.742 +; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromDoubleEvaluator.java new file mode 100644 index 
0000000000000..a5fe8c25610ed --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromDoubleEvaluator.java @@ -0,0 +1,132 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.convert; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.DoubleVector; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Vector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link ToDenseVector}. + * This class is generated. Edit {@code ConvertEvaluatorImplementer} instead. 
+ */ +public final class ToDenseVectorFromDoubleEvaluator extends AbstractConvertFunction.AbstractEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ToDenseVectorFromDoubleEvaluator.class); + + private final EvalOperator.ExpressionEvaluator d; + + public ToDenseVectorFromDoubleEvaluator(Source source, EvalOperator.ExpressionEvaluator d, + DriverContext driverContext) { + super(driverContext, source); + this.d = d; + } + + @Override + public EvalOperator.ExpressionEvaluator next() { + return d; + } + + @Override + public Block evalVector(Vector v) { + DoubleVector vector = (DoubleVector) v; + int positionCount = v.getPositionCount(); + if (vector.isConstant()) { + return driverContext.blockFactory().newConstantFloatBlockWith(evalValue(vector, 0), positionCount); + } + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + builder.appendFloat(evalValue(vector, p)); + } + return builder.build(); + } + } + + private float evalValue(DoubleVector container, int index) { + double value = container.getDouble(index); + return ToDenseVector.fromDouble(value); + } + + @Override + public Block evalBlock(Block b) { + DoubleBlock block = (DoubleBlock) b; + int positionCount = block.getPositionCount(); + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int valueCount = block.getValueCount(p); + int start = block.getFirstValueIndex(p); + int end = start + valueCount; + boolean positionOpened = false; + boolean valuesAppended = false; + for (int i = start; i < end; i++) { + float value = evalValue(block, i); + if (positionOpened == false && valueCount > 1) { + builder.beginPositionEntry(); + positionOpened = true; + } + builder.appendFloat(value); + valuesAppended = true; + } + if (valuesAppended == false) { + builder.appendNull(); + } else if 
(positionOpened) { + builder.endPositionEntry(); + } + } + return builder.build(); + } + } + + private float evalValue(DoubleBlock container, int index) { + double value = container.getDouble(index); + return ToDenseVector.fromDouble(value); + } + + @Override + public String toString() { + return "ToDenseVectorFromDoubleEvaluator[" + "d=" + d + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(d); + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += d.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory d; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory d) { + this.source = source; + this.d = d; + } + + @Override + public ToDenseVectorFromDoubleEvaluator get(DriverContext context) { + return new ToDenseVectorFromDoubleEvaluator(source, d.get(context), context); + } + + @Override + public String toString() { + return "ToDenseVectorFromDoubleEvaluator[" + "d=" + d + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromIntEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromIntEvaluator.java new file mode 100644 index 0000000000000..ff6ccb6e86917 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromIntEvaluator.java @@ -0,0 +1,132 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. 
+package org.elasticsearch.xpack.esql.expression.function.scalar.convert; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Vector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link ToDenseVector}. + * This class is generated. Edit {@code ConvertEvaluatorImplementer} instead. + */ +public final class ToDenseVectorFromIntEvaluator extends AbstractConvertFunction.AbstractEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ToDenseVectorFromIntEvaluator.class); + + private final EvalOperator.ExpressionEvaluator i; + + public ToDenseVectorFromIntEvaluator(Source source, EvalOperator.ExpressionEvaluator i, + DriverContext driverContext) { + super(driverContext, source); + this.i = i; + } + + @Override + public EvalOperator.ExpressionEvaluator next() { + return i; + } + + @Override + public Block evalVector(Vector v) { + IntVector vector = (IntVector) v; + int positionCount = v.getPositionCount(); + if (vector.isConstant()) { + return driverContext.blockFactory().newConstantFloatBlockWith(evalValue(vector, 0), positionCount); + } + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + builder.appendFloat(evalValue(vector, p)); + } + return builder.build(); + } + } + + private float evalValue(IntVector container, int index) { + int value = container.getInt(index); + return ToDenseVector.fromInt(value); + } + 
+ @Override + public Block evalBlock(Block b) { + IntBlock block = (IntBlock) b; + int positionCount = block.getPositionCount(); + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int valueCount = block.getValueCount(p); + int start = block.getFirstValueIndex(p); + int end = start + valueCount; + boolean positionOpened = false; + boolean valuesAppended = false; + for (int i = start; i < end; i++) { + float value = evalValue(block, i); + if (positionOpened == false && valueCount > 1) { + builder.beginPositionEntry(); + positionOpened = true; + } + builder.appendFloat(value); + valuesAppended = true; + } + if (valuesAppended == false) { + builder.appendNull(); + } else if (positionOpened) { + builder.endPositionEntry(); + } + } + return builder.build(); + } + } + + private float evalValue(IntBlock container, int index) { + int value = container.getInt(index); + return ToDenseVector.fromInt(value); + } + + @Override + public String toString() { + return "ToDenseVectorFromIntEvaluator[" + "i=" + i + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(i); + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += i.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory i; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory i) { + this.source = source; + this.i = i; + } + + @Override + public ToDenseVectorFromIntEvaluator get(DriverContext context) { + return new ToDenseVectorFromIntEvaluator(source, i.get(context), context); + } + + @Override + public String toString() { + return "ToDenseVectorFromIntEvaluator[" + "i=" + i + "]"; + } + } +} diff --git 
a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromLongEvaluator.java new file mode 100644 index 0000000000000..4ca69984ff540 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromLongEvaluator.java @@ -0,0 +1,132 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.convert; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Vector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link ToDenseVector}. + * This class is generated. Edit {@code ConvertEvaluatorImplementer} instead. 
+ */ +public final class ToDenseVectorFromLongEvaluator extends AbstractConvertFunction.AbstractEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ToDenseVectorFromLongEvaluator.class); + + private final EvalOperator.ExpressionEvaluator l; + + public ToDenseVectorFromLongEvaluator(Source source, EvalOperator.ExpressionEvaluator l, + DriverContext driverContext) { + super(driverContext, source); + this.l = l; + } + + @Override + public EvalOperator.ExpressionEvaluator next() { + return l; + } + + @Override + public Block evalVector(Vector v) { + LongVector vector = (LongVector) v; + int positionCount = v.getPositionCount(); + if (vector.isConstant()) { + return driverContext.blockFactory().newConstantFloatBlockWith(evalValue(vector, 0), positionCount); + } + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + builder.appendFloat(evalValue(vector, p)); + } + return builder.build(); + } + } + + private float evalValue(LongVector container, int index) { + long value = container.getLong(index); + return ToDenseVector.fromLong(value); + } + + @Override + public Block evalBlock(Block b) { + LongBlock block = (LongBlock) b; + int positionCount = block.getPositionCount(); + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount)) { + for (int p = 0; p < positionCount; p++) { + int valueCount = block.getValueCount(p); + int start = block.getFirstValueIndex(p); + int end = start + valueCount; + boolean positionOpened = false; + boolean valuesAppended = false; + for (int i = start; i < end; i++) { + float value = evalValue(block, i); + if (positionOpened == false && valueCount > 1) { + builder.beginPositionEntry(); + positionOpened = true; + } + builder.appendFloat(value); + valuesAppended = true; + } + if (valuesAppended == false) { + builder.appendNull(); + } else if (positionOpened) { + 
builder.endPositionEntry(); + } + } + return builder.build(); + } + } + + private float evalValue(LongBlock container, int index) { + long value = container.getLong(index); + return ToDenseVector.fromLong(value); + } + + @Override + public String toString() { + return "ToDenseVectorFromLongEvaluator[" + "l=" + l + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(l); + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += l.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory l; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory l) { + this.source = source; + this.l = l; + } + + @Override + public ToDenseVectorFromLongEvaluator get(DriverContext context) { + return new ToDenseVectorFromLongEvaluator(source, l.get(context), context); + } + + @Override + public String toString() { + return "ToDenseVectorFromLongEvaluator[" + "l=" + l + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 3d3c0d3b286f1..3bee4b70ab912 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1467,7 +1467,12 @@ public enum Cap { /** * Support for the Present function */ - FN_PRESENT; + FN_PRESENT, + + /** + * TO_DENSE_VECTOR function. 
+ */ + TO_DENSE_VECTOR_FUNCTION(Build.current().isSnapshot()); private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 5e447ede71be9..55ec36630d509 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -78,6 +78,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToAggregateMetricDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; @@ -157,6 +158,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.type.DataType.GEO_POINT; @@ -1462,7 +1464,7 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func return processIn(in); } if (f instanceof VectorFunction) { - return processVectorFunction(f); + return processVectorFunction(f, registry); } if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude 
AggregateFunction until it is needed return processScalarOrGroupingFunction(f, registry); @@ -1675,22 +1677,28 @@ private static Expression castStringLiteral(Expression from, DataType target) { } @SuppressWarnings("unchecked") - private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) { + private static Expression processVectorFunction( + org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction, + EsqlFunctionRegistry registry + ) { + // Perform implicit casting for dense_vector from numeric and keyword values List args = vectorFunction.arguments(); + List targetDataTypes = registry.getDataTypeForStringLiteralConversion(vectorFunction.getClass()); List newArgs = new ArrayList<>(); - for (Expression arg : args) { - if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) { - Object folded = arg.fold(FoldContext.small() /* TODO remove me */); - if (folded instanceof List) { - // Convert to floats so blocks are created accordingly - List floatVector; - if (arg.dataType() == FLOAT) { - floatVector = (List) folded; - } else { - floatVector = ((List) folded).stream().map(Number::floatValue).collect(Collectors.toList()); + for (int i = 0; i < args.size(); i++) { + Expression arg = args.get(i); + if (targetDataTypes.get(i) == DENSE_VECTOR && arg.resolved()) { + var dataType = arg.dataType(); + if (dataType == KEYWORD) { + if (arg.foldable()) { + Expression exp = castStringLiteral(arg, DENSE_VECTOR); + if (exp != arg) { + newArgs.add(exp); + continue; + } } - Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR); - newArgs.add(denseVector); + } else if (dataType.isNumeric()) { + newArgs.add(new ToDenseVector(vectorFunction.source(), arg)); continue; } } @@ -1699,7 +1707,6 @@ private static Expression processVectorFunction(org.elasticsearch.xpack.esql.cor return vectorFunction.replaceChildren(newArgs); } - } /** diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 20de89a53780d..16a38671db62c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; @@ -22,6 +23,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDegrees; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToGeoPoint; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToGeoShape; @@ -205,6 +207,9 @@ public static List unaryScalars() { entries.add(ToDatetime.ENTRY); entries.add(ToDateNanos.ENTRY); entries.add(ToDegrees.ENTRY); + if (EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()) { + entries.add(ToDenseVector.ENTRY); + } entries.add(ToDouble.ENTRY); entries.add(ToGeoShape.ENTRY); entries.add(ToCartesianShape.ENTRY); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 4f6c87eb3ec77..79f610bc9ad98 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -71,6 +71,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatePeriod; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDegrees; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToGeoPoint; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToGeoShape; @@ -217,6 +218,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_SHAPE; import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.GEOHASH; import static org.elasticsearch.xpack.esql.core.type.DataType.GEOHEX; @@ -255,6 +257,7 @@ public class EsqlFunctionRegistry { GEOTILE, BOOLEAN, UNSIGNED_LONG, + DENSE_VECTOR, UNSUPPORTED ); DATA_TYPE_CASTING_PRIORITY = new HashMap<>(); @@ -517,6 +520,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), def(Score.class, uni(Score::new), Score.NAME), def(Term.class, bi(Term::new), "term"), + def(ToDenseVector.class, ToDenseVector::new, "to_dense_vector"), 
def(Knn.class, tri(Knn::new), "knn"), def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"), def(DotProduct.class, DotProduct::new, "v_dot_product"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVector.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVector.java new file mode 100644 index 0000000000000..f70c0a59b2ece --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVector.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.convert; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.compute.ann.ConvertEvaluator; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; +import static 
org.elasticsearch.xpack.esql.core.type.DataType.LONG; + +/** + * Converts a multi-valued input of numbers, or a hexadecimal string, to a dense_vector. + */ +public class ToDenseVector extends AbstractConvertFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "ToDenseVector", + ToDenseVector::new + ); + + private static final Map EVALUATORS = Map.ofEntries( + Map.entry(DENSE_VECTOR, (source, fieldEval) -> fieldEval), + Map.entry(LONG, ToDenseVectorFromLongEvaluator.Factory::new), + Map.entry(INTEGER, ToDenseVectorFromIntEvaluator.Factory::new), + Map.entry(DOUBLE, ToDenseVectorFromDoubleEvaluator.Factory::new), + Map.entry(KEYWORD, ToDenseVectorFromStringEvaluator.Factory::new) + ); + + @FunctionInfo( + returnType = "dense_vector", + description = "Converts a multi-valued input of numbers, or a hexadecimal string, to a dense_vector.", + examples = @Example(file = "dense_vector", tag = "to_dense_vector-ints") + ) + public ToDenseVector( + Source source, + @Param( + name = "field", + type = { "double", "long", "integer", "keyword" }, + description = "multi-valued input of numbers or hexadecimal string to convert." 
+ ) Expression field + ) { + super(source, field); + } + + private ToDenseVector(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected Map factories() { + return EVALUATORS; + } + + @Override + public DataType dataType() { + return DENSE_VECTOR; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new ToDenseVector(source(), newChildren.get(0)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, ToDenseVector::new, field()); + } + + @ConvertEvaluator(extraName = "FromLong") + static float fromLong(long l) { + return l; + } + + @ConvertEvaluator(extraName = "FromInt") + static float fromInt(int i) { + return i; + } + + @ConvertEvaluator(extraName = "FromDouble") + static float fromDouble(double d) { + return (float) d; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromStringEvaluator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromStringEvaluator.java new file mode 100644 index 0000000000000..9470e744099b2 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorFromStringEvaluator.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.convert; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.FloatBlock; +import org.elasticsearch.compute.data.Vector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.HexFormat; + +/** + * String evaluator for to_dense_vector function. Converts a hexadecimal string to a dense_vector of bytes. + * Cannot be automatically generated as it generates multivalues for a single hex string, representing the dense_vector byte array. + */ +class ToDenseVectorFromStringEvaluator extends AbstractConvertFunction.AbstractEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ToDenseVectorFromStringEvaluator.class); + + private final EvalOperator.ExpressionEvaluator field; + + ToDenseVectorFromStringEvaluator(Source source, EvalOperator.ExpressionEvaluator field, DriverContext driverContext) { + super(driverContext, source); + this.field = field; + } + + @Override + protected EvalOperator.ExpressionEvaluator next() { + return field; + } + + @Override + protected Block evalVector(Vector v) { + return evalBlock(v.asBlock()); + } + + @Override + public Block evalBlock(Block b) { + BytesRefBlock block = (BytesRefBlock) b; + int positionCount = block.getPositionCount(); + int dimensions = 0; + BytesRef scratch = new BytesRef(); + try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount * dimensions)) { + for (int p = 0; p < positionCount; p++) { + if (block.isNull(p)) { + builder.appendNull(); + } else { + scratch = block.getBytesRef(p, scratch); + try { + byte[] bytes 
= HexFormat.of().parseHex(scratch.utf8ToString()); + if (bytes.length == 0) { + builder.appendNull(); + continue; + } + if (dimensions == 0) { + dimensions = bytes.length; + } else { + if (bytes.length != dimensions) { + throw new IllegalArgumentException( + "All dense_vector must have the same number of dimensions. Expected: " + + dimensions + + ", found: " + + bytes.length + ); + } + } + builder.beginPositionEntry(); + for (byte value : bytes) { + builder.appendFloat(value); + } + builder.endPositionEntry(); + } catch (IllegalArgumentException e) { + registerException(e); + builder.appendNull(); + } + } + } + return builder.build(); + } + } + + @Override + public String toString() { + return "ToDenseVectorFromStringEvaluator[s=" + field + ']'; + } + + @Override + public long baseRamBytesUsed() { + return BASE_RAM_BYTES_USED + field.baseRamBytesUsed(); + } + + @Override + public void close() { + Releasables.closeExpectNoException(field); + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + private final EvalOperator.ExpressionEvaluator.Factory field; + + Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field) { + this.source = source; + this.field = field; + } + + @Override + public EvalOperator.ExpressionEvaluator get(DriverContext context) { + return new ToDenseVectorFromStringEvaluator(source, field.get(context), context); + } + + @Override + public String toString() { + return "ToDenseVectorFromStringEvaluator[s=" + field + ']'; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java index dc0be7a29fee0..ca983caf5615f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java @@ -9,7 +9,7 @@ /** * Marker interface for vector functions. Makes possible to do implicit casting - * from multi values to dense_vector field types, so parameters are actually + * from multi values and hex strings to dense_vector field types, so parameters are actually * processed as dense_vectors in vector functions */ public interface VectorFunction {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java index e3a6b08cac1fd..446ba48288b84 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java @@ -28,6 +28,7 @@ import org.elasticsearch.search.aggregations.bucket.geogrid.GeoTileUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -47,6 +48,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatePeriod; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToGeoPoint; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToGeoShape; @@ -74,13 +76,16 @@ 
import java.time.ZoneId; import java.time.temporal.ChronoField; import java.time.temporal.TemporalAmount; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HexFormat; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; -import static java.util.Map.entry; import static org.elasticsearch.xpack.esql.core.type.DataType.AGGREGATE_METRIC_DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; import static org.elasticsearch.xpack.esql.core.type.DataType.CARTESIAN_POINT; @@ -88,6 +93,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.GEOHASH; import static org.elasticsearch.xpack.esql.core.type.DataType.GEOHEX; @@ -127,30 +133,38 @@ public class EsqlDataTypeConverter { public static final DateFormatter HOUR_MINUTE_SECOND = DateFormatter.forPattern("strict_hour_minute_second_fraction"); - private static final Map> TYPE_TO_CONVERTER_FUNCTION = Map.ofEntries( - entry(AGGREGATE_METRIC_DOUBLE, ToAggregateMetricDouble::new), - entry(BOOLEAN, ToBoolean::new), - entry(CARTESIAN_POINT, ToCartesianPoint::new), - entry(CARTESIAN_SHAPE, ToCartesianShape::new), - entry(DATETIME, ToDatetime::new), - entry(DATE_NANOS, ToDateNanos::new), + private static final Map> TYPE_TO_CONVERTER_FUNCTION; + + static { + Map> typeToConverter = new HashMap<>(); + typeToConverter.put(AGGREGATE_METRIC_DOUBLE, ToAggregateMetricDouble::new); + typeToConverter.put(BOOLEAN, ToBoolean::new); + typeToConverter.put(CARTESIAN_POINT, ToCartesianPoint::new); + 
typeToConverter.put(CARTESIAN_SHAPE, ToCartesianShape::new); + typeToConverter.put(DATETIME, ToDatetime::new); + typeToConverter.put(DATE_NANOS, ToDateNanos::new); // ToDegrees, typeless - entry(DOUBLE, ToDouble::new), - entry(GEO_POINT, ToGeoPoint::new), - entry(GEO_SHAPE, ToGeoShape::new), - entry(GEOHASH, ToGeohash::new), - entry(GEOTILE, ToGeotile::new), - entry(GEOHEX, ToGeohex::new), - entry(INTEGER, ToInteger::new), - entry(IP, ToIpLeadingZerosRejected::new), - entry(LONG, ToLong::new), + typeToConverter.put(DOUBLE, ToDouble::new); + typeToConverter.put(GEO_POINT, ToGeoPoint::new); + typeToConverter.put(GEO_SHAPE, ToGeoShape::new); + typeToConverter.put(GEOHASH, ToGeohash::new); + typeToConverter.put(GEOTILE, ToGeotile::new); + typeToConverter.put(GEOHEX, ToGeohex::new); + typeToConverter.put(INTEGER, ToInteger::new); + typeToConverter.put(IP, ToIpLeadingZerosRejected::new); + typeToConverter.put(LONG, ToLong::new); // ToRadians, typeless - entry(KEYWORD, ToString::new), - entry(UNSIGNED_LONG, ToUnsignedLong::new), - entry(VERSION, ToVersion::new), - entry(DATE_PERIOD, ToDatePeriod::new), - entry(TIME_DURATION, ToTimeDuration::new) - ); + typeToConverter.put(KEYWORD, ToString::new); + typeToConverter.put(UNSIGNED_LONG, ToUnsignedLong::new); + typeToConverter.put(VERSION, ToVersion::new); + typeToConverter.put(DATE_PERIOD, ToDatePeriod::new); + typeToConverter.put(TIME_DURATION, ToTimeDuration::new); + + if (EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()) { + typeToConverter.put(DENSE_VECTOR, ToDenseVector::new); + } + TYPE_TO_CONVERTER_FUNCTION = Collections.unmodifiableMap(typeToConverter); + } public enum INTERVALS { // TIME_DURATION, @@ -272,6 +286,9 @@ public static Converter converterFor(DataType from, DataType to) { if (to == DataType.DATE_PERIOD) { return EsqlConverter.STRING_TO_DATE_PERIOD; } + if (to == DENSE_VECTOR) { + return EsqlConverter.STRING_TO_DENSE_VECTOR; + } } Converter converter = DataTypeConverter.converterFor(from, to); if 
(converter != null) { @@ -732,6 +749,19 @@ public static boolean unsignedLongToBoolean(long number) { return n instanceof BigInteger || n.longValue() != 0; } + public static List stringToDenseVector(String field) { + try { + byte[] bytes = HexFormat.of().parseHex(field); + List vector = new ArrayList<>(bytes.length); + for (byte value : bytes) { + vector.add((float) value); + } + return vector; + } catch (NumberFormatException e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s is not a valid hex string: %s", field, e.getMessage())); + } + } + public static long booleanToUnsignedLong(boolean number) { return number ? ONE_AS_UNSIGNED_LONG : ZERO_AS_UNSIGNED_LONG; } @@ -827,7 +857,8 @@ public enum EsqlConverter implements Converter { STRING_TO_SPATIAL(x -> EsqlDataTypeConverter.stringToSpatial(BytesRefs.toString(x))), STRING_TO_GEOHASH(x -> Geohash.longEncode(BytesRefs.toString(x))), STRING_TO_GEOTILE(x -> GeoTileUtils.longEncode(BytesRefs.toString(x))), - STRING_TO_GEOHEX(x -> H3.stringToH3(BytesRefs.toString(x))); + STRING_TO_GEOHEX(x -> H3.stringToH3(BytesRefs.toString(x))), + STRING_TO_DENSE_VECTOR(x -> EsqlDataTypeConverter.stringToDenseVector(BytesRefs.toString(x))); private static final String NAME = "esql-converter"; private final Function converter; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index efaf7a934607c..d1dbdb9466614 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.grouping.TBucket; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDateNanos; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; +import 
org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDenseVector; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString; @@ -2349,21 +2350,40 @@ public void testImplicitCasting() { public void testDenseVectorImplicitCastingKnn() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()); + assumeTrue("dense vector casting must be enabled", EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()); - checkDenseVectorCastingKnn("float_vector"); - + if (EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()) { + checkDenseVectorCastingHexKnn("float_vector"); + checkDenseVectorCastingKnn("float_vector"); + } if (EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE_BYTE_ELEMENTS.isEnabled()) { checkDenseVectorCastingKnn("byte_vector"); + checkDenseVectorCastingHexKnn("byte_vector"); + checkDenseVectorEvalCastingKnn("byte_vector"); } - if (EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE_BIT_ELEMENTS.isEnabled()) { checkDenseVectorCastingKnn("bit_vector"); + checkDenseVectorCastingHexKnn("bit_vector"); + checkDenseVectorEvalCastingKnn("bit_vector"); } } private static void checkDenseVectorCastingKnn(String fieldName) { var plan = analyze(String.format(Locale.ROOT, """ - from test | where knn(%s, [0.342, 0.164, 0.234]) + from test | where knn(%s, [0, 1, 2]) + """, fieldName), "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + var conversion = as(knn.query(), ToDenseVector.class); + var literal = as(conversion.field(), Literal.class); + assertThat(literal.value(), equalTo(List.of(0, 1, 2))); + } + + private static void 
checkDenseVectorCastingHexKnn(String fieldName) { + var plan = analyze(String.format(Locale.ROOT, """ + from test | where knn(%s, "000102") """, fieldName), "mapping-dense_vector.json"); var limit = as(plan, Limit.class); @@ -2371,50 +2391,65 @@ private static void checkDenseVectorCastingKnn(String fieldName) { var knn = as(filter.condition(), Knn.class); var queryVector = as(knn.query(), Literal.class); assertEquals(DataType.DENSE_VECTOR, queryVector.dataType()); - assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f))); + assertThat(queryVector.value(), equalTo(List.of(0.0f, 1.0f, 2.0f))); + } + + private static void checkDenseVectorEvalCastingKnn(String fieldName) { + var plan = analyze(String.format(Locale.ROOT, """ + from test | eval query = to_dense_vector([0, 1, 2]) | where knn(%s, query) + """, fieldName), "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + var queryVector = as(knn.query(), ReferenceAttribute.class); + assertEquals(DataType.DENSE_VECTOR, queryVector.dataType()); + assertThat(queryVector.name(), is("query")); } public void testDenseVectorImplicitCastingSimilarityFunctions() { + assumeTrue("dense vector casting must be enabled", EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()); + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( "v_cosine(float_vector, [0.342, 0.164, 0.234])", - List.of(0.342f, 0.164f, 0.234f) + List.of(0.342, 0.164, 0.234) ); - checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(byte_vector, [1, 2, 3])", List.of(1, 2, 3)); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( "v_dot_product(float_vector, [0.342, 0.164, 
0.234])", - List.of(0.342f, 0.164f, 0.234f) + List.of(0.342, 0.164, 0.234) ); - checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(byte_vector, [1, 2, 3])", List.of(1, 2, 3)); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( "v_l1_norm(float_vector, [0.342, 0.164, 0.234])", - List.of(0.342f, 0.164f, 0.234f) + List.of(0.342, 0.164, 0.234) ); - checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1, 2, 3)); } if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( "v_l2_norm(float_vector, [0.342, 0.164, 0.234])", - List.of(0.342f, 0.164f, 0.234f) + List.of(0.342, 0.164, 0.234) ); - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [1, 2, 3])", List.of(1, 2, 3)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(byte_vector, [1, 2, 3])", List.of(1, 2, 3)); if (EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE_BIT_ELEMENTS.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(bit_vector, [1, 2])", List.of(1f, 2f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(bit_vector, [1, 2])", List.of(1, 2)); } } if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { checkDenseVectorImplicitCastingSimilarityFunction( "v_hamming(byte_vector, [0.342, 0.164, 0.234])", - List.of(0.342f, 0.164f, 0.234f) + List.of(0.342, 0.164, 0.234) ); - 
checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(byte_vector, [1, 2, 3])", List.of(1f, 2f, 3f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(byte_vector, [1, 2, 3])", List.of(1, 2, 3)); if (EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE_BIT_ELEMENTS.isEnabled()) { - checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(bit_vector, [1, 2])", List.of(1f, 2f)); + checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(bit_vector, [1, 2])", List.of(1, 2)); } } } @@ -2431,35 +2466,82 @@ private void checkDenseVectorImplicitCastingSimilarityFunction(String similarity var similarity = as(alias.child(), VectorSimilarityFunction.class); var left = as(similarity.left(), FieldAttribute.class); assertThat(List.of("float_vector", "byte_vector", "bit_vector"), hasItem(left.name())); - var right = as(similarity.right(), Literal.class); - assertThat(right.dataType(), is(DENSE_VECTOR)); - assertThat(right.value(), equalTo(expectedElems)); + var right = as(similarity.right(), ToDenseVector.class); + var literal = as(right.field(), Literal.class); + assertThat(literal.value(), equalTo(expectedElems)); } - public void testNoDenseVectorFailsSimilarityFunction() { + public void testDenseVectorEvalCastingSimilarityFunctions() { + assumeTrue("dense vector casting must be enabled", EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()); + if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkNoDenseVectorFailsSimilarityFunction("v_cosine([0, 1, 2], 0.342)"); + checkDenseVectorEvalCastingSimilarityFunction("v_cosine(float_vector, query)"); + checkDenseVectorEvalCastingSimilarityFunction("v_cosine(byte_vector, query)"); + } + if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkDenseVectorEvalCastingSimilarityFunction("v_dot_product(float_vector, query)"); + checkDenseVectorEvalCastingSimilarityFunction("v_dot_product(byte_vector, query)"); + } + if 
(EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkDenseVectorEvalCastingSimilarityFunction("v_l1_norm(float_vector, query)"); + checkDenseVectorEvalCastingSimilarityFunction("v_l1_norm(byte_vector, query)"); + } + if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(float_vector, query)"); + checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(float_vector, query)"); + } + if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { + checkDenseVectorEvalCastingSimilarityFunction("v_hamming(byte_vector, query)"); + checkDenseVectorEvalCastingSimilarityFunction("v_hamming(byte_vector, query)"); + } + } + + private void checkDenseVectorEvalCastingSimilarityFunction(String similarityFunction) { + var plan = analyze(String.format(Locale.ROOT, """ + from test | eval query = to_dense_vector([0.342, 0.164, 0.234]) | eval similarity = %s + """, similarityFunction), "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var eval = as(limit.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertEquals("similarity", alias.name()); + var similarity = as(alias.child(), VectorSimilarityFunction.class); + var left = as(similarity.left(), FieldAttribute.class); + assertThat(List.of("float_vector", "byte_vector"), hasItem(left.name())); + var right = as(similarity.right(), ReferenceAttribute.class); + assertThat(right.dataType(), is(DENSE_VECTOR)); + assertThat(right.name(), is("query")); + } + + public void testVectorFunctionHexImplicitCastingError() { + assumeTrue("dense vector casting must be enabled", EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()); + + if (EsqlCapabilities.Cap.KNN_FUNCTION_V5.isEnabled()) { + checkVectorFunctionHexImplicitCastingError("where knn(float_vector, \"notcorrect\")"); } if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - 
checkNoDenseVectorFailsSimilarityFunction("v_dot_product([0, 1, 2], 0.342)"); + checkVectorFunctionHexImplicitCastingError("eval s = v_dot_product(\"notcorrect\", 0.342)"); } if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkNoDenseVectorFailsSimilarityFunction("v_l1_norm([0, 1, 2], 0.342)"); + checkVectorFunctionHexImplicitCastingError("eval s = v_l1_norm(\"notcorrect\", 0.342)"); } if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkNoDenseVectorFailsSimilarityFunction("v_l2_norm([0, 1, 2], 0.342)"); + checkVectorFunctionHexImplicitCastingError("eval s = v_l2_norm(\"notcorrect\", 0.342)"); } if (EsqlCapabilities.Cap.HAMMING_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { - checkNoDenseVectorFailsSimilarityFunction("v_hamming([0, 1, 2], 0.342)"); + checkVectorFunctionHexImplicitCastingError("eval s = v_hamming(\"notcorrect\", 0.342)"); } } - private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) { - var query = String.format(Locale.ROOT, "row a = 1 | eval similarity = %s", similarityFunction); - VerificationException error = expectThrows(VerificationException.class, () -> analyze(query)); + private void checkVectorFunctionHexImplicitCastingError(String clause) { + var query = "from test | " + clause; + VerificationException error = expectThrows(VerificationException.class, () -> analyze(query, "mapping-dense_vector.json")); assertThat( error.getMessage(), - containsString("second argument of [" + similarityFunction + "] must be" + " [dense_vector], found value [0.342] type [double]") + containsString( + "Cannot convert string [notcorrect] to [DENSE_VECTOR], " + + "error [notcorrect is not a valid hex string: not a hexadecimal digit: \"n\" = 110]" + ) ); } @@ -2475,20 +2557,9 @@ public void testMagnitudePlanWithDenseVectorImplicitCasting() { var alias = as(eval.fields().get(0), Alias.class); assertEquals("scalar", alias.name()); var scalar = as(alias.child(), Magnitude.class); - var 
child = as(scalar.field(), Literal.class); - assertThat(child.dataType(), is(DENSE_VECTOR)); - assertThat(child.value(), equalTo(List.of(1.0f, 2.0f, 3.0f))); - } - - public void testNoDenseVectorFailsForMagnitude() { - assumeTrue("v_magnitude not available", EsqlCapabilities.Cap.MAGNITUDE_SCALAR_VECTOR_FUNCTION.isEnabled()); - - var query = String.format(Locale.ROOT, "row a = 1 | eval scalar = v_magnitude(0.342)"); - VerificationException error = expectThrows(VerificationException.class, () -> analyze(query)); - assertThat( - error.getMessage(), - containsString("first argument of [v_magnitude(0.342)] must be [dense_vector], found value [0.342] type [double]") - ); + var child = as(scalar.field(), ToDenseVector.class); + var literal = as(child.field(), Literal.class); + assertThat(literal.value(), equalTo(List.of(1, 2, 3))); } public void testRateRequiresCounterTypes() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorTests.java new file mode 100644 index 0000000000000..e4e153d25bf8f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVectorTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.convert; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.List; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class ToDenseVectorTests extends AbstractScalarFunctionTestCase { + + @BeforeClass + public static void checkCapability() { + assumeTrue("To_DenseVector function capability", EsqlCapabilities.Cap.TO_DENSE_VECTOR_FUNCTION.isEnabled()); + } + + public ToDenseVectorTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable<Object[]> parameters() { + List<TestCaseSupplier> suppliers = new ArrayList<>(); + + suppliers.add(new TestCaseSupplier("int", List.of(DataType.INTEGER), () -> { + List<Integer> data = Arrays.asList(randomArray(2, 10, Integer[]::new, ESTestCase::randomInt)); + return new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(data, DataType.INTEGER, "int")), + evaluatorName("Int", "i"), + DataType.DENSE_VECTOR, + equalTo(data.stream().map(Number::floatValue).toList()) + ); + })); + + suppliers.add(new TestCaseSupplier("long", List.of(DataType.LONG), () -> { + List<Long> data = Arrays.asList(randomArray(2, 10, Long[]::new,
ESTestCase::randomLong)); + return new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(data, DataType.LONG, "long")), + evaluatorName("Long", "l"), + DataType.DENSE_VECTOR, + equalTo(data.stream().map(Number::floatValue).toList()) + ); + })); + + suppliers.add(new TestCaseSupplier("double", List.of(DataType.DOUBLE), () -> { + List<Double> data = Arrays.asList(randomArray(2, 10, Double[]::new, ESTestCase::randomDouble)); + return new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(data, DataType.DOUBLE, "double")), + evaluatorName("Double", "d"), + DataType.DENSE_VECTOR, + equalTo(data.stream().map(Number::floatValue).toList()) + ); + })); + + suppliers.add(new TestCaseSupplier("keyword", List.of(DataType.KEYWORD), () -> { + byte[] bytes = randomByteArrayOfLength(randomIntBetween(2, 20)); + String data = HexFormat.of().formatHex(bytes); + List<Float> expected = new ArrayList<>(bytes.length); + for (int i = 0; i < bytes.length; i++) { + expected.add((float) bytes[i]); + } + return new TestCaseSupplier.TestCase( + List.of(new TestCaseSupplier.TypedData(new BytesRef(data), DataType.KEYWORD, "keyword")), + evaluatorName("String", "s"), + DataType.DENSE_VECTOR, + is(expected) + ); + })); + + return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers); + } + + private static String evaluatorName(String inner, String next) { + String read = "Attribute[channel=0]"; + return "ToDenseVectorFrom" + inner + "Evaluator[" + next + "=" + read + "]"; + } + + @Override + protected Expression build(Source source, List<Expression> args) { + return new ToDenseVector(source, args.get(0)); + } +} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 605baef08a4e5..30efc81236a4c 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ 
b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -129,7 +129,7 @@ setup: - match: {esql.functions.coalesce: $functions_coalesce} - gt: {esql.functions.categorize: $functions_categorize} # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 174} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 175} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version": - requires: