From 21fe40d71e6009ea1645cd000e819a4364950af1 Mon Sep 17 00:00:00 2001 From: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com> Date: Tue, 10 Jun 2025 11:11:42 +0200 Subject: [PATCH 1/2] [8.19] ES|QL - kNN function initial support (#127322) --- .../esql/functions/kibana/definition/knn.json | 13 + .../esql/functions/kibana/docs/knn.md | 10 + .../esql/images/functions/knn.svg | 1 + .../vectors/DenseVectorFieldMapper.java | 5 + .../compute/lucene/LuceneQueryEvaluator.java | 2 +- .../xpack/esql/qa/rest/RestEsqlTestCase.java | 2 +- .../xpack/esql/CsvTestsDataLoader.java | 4 +- .../src/main/resources/data/colors.csv | 60 ++++ .../src/main/resources/knn-function.csv-spec | 285 ++++++++++++++++++ .../src/main/resources/mapping-all-types.json | 3 + .../src/main/resources/mapping-colors.json | 20 ++ .../xpack/esql/DenseVectorFieldTypeIT.java | 65 +++- .../xpack/esql/plugin/KnnFunctionIT.java | 156 ++++++++++ .../xpack/esql/action/EsqlCapabilities.java | 7 +- .../xpack/esql/analysis/Analyzer.java | 23 ++ .../esql/expression/ExpressionWritables.java | 10 + .../function/EsqlFunctionRegistry.java | 4 +- .../function/fulltext/FullTextFunction.java | 69 ++++- .../esql/expression/function/vector/Knn.java | 285 ++++++++++++++++++ .../function/vector/VectorFunction.java | 15 + .../xpack/esql/querydsl/query/KnnQuery.java | 84 ++++++ .../elasticsearch/xpack/esql/CsvTests.java | 4 + .../xpack/esql/SerializationTestUtils.java | 2 + .../xpack/esql/analysis/AnalyzerTests.java | 17 ++ .../xpack/esql/analysis/VerifierTests.java | 30 ++ .../function/fulltext/KnnTests.java | 132 ++++++++ .../function/fulltext/MatchTests.java | 2 +- .../LocalPhysicalPlanOptimizerTests.java | 25 ++ 28 files changed, 1312 insertions(+), 23 deletions(-) create mode 100644 docs/reference/esql/functions/kibana/definition/knn.json create mode 100644 docs/reference/esql/functions/kibana/docs/knn.md create mode 100644 docs/reference/query-languages/esql/images/functions/knn.svg create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json create mode 100644 x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java diff --git a/docs/reference/esql/functions/kibana/definition/knn.json b/docs/reference/esql/functions/kibana/definition/knn.json new file mode 100644 index 0000000000000..48d3e582eec58 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/knn.json @@ -0,0 +1,13 @@ +{ + "comment" : "This is generated by ESQL’s AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "knn", + "description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.", + "signatures" : [ ], + "examples" : [ + "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc", + "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/esql/functions/kibana/docs/knn.md b/docs/reference/esql/functions/kibana/docs/knn.md new file mode 100644 index 0000000000000..45d1f294ea0a8 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/knn.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +### KNN +Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors. + +```esql +from colors metadata _score +| where knn(rgb_vector, [0, 120, 0]) +| sort _score desc +``` diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg new file mode 100644 index 0000000000000..75a104a7cdcfa --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/knn.svg @@ -0,0 +1 @@ +KNN(field,query,options) \ No newline at end of file diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0e12fc9af9243..61fc655cd10f2 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2354,6 +2354,11 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { return null; } + if (dims == null) { + // No data has been indexed yet + return BlockLoader.CONSTANT_NULLS; + } + if (indexed) { return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java index c3676822fea64..1f27fd7981b5f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java @@ -111,7 +111,7 @@ private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException int min = docs.docs().getInt(0); int max = docs.docs().getInt(docs.getPositionCount() - 1); int length = max - min + 1; - try (T scoreBuilder = createVectorBuilder(blockFactory, length)) { + try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) { if (length == docs.getPositionCount() && length > 1) { return segmentState.scoreDense(scoreBuilder, min, max); } diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java index 5958f995f126d..43cf7eac61e98 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java @@ -1008,7 +1008,7 @@ public void testMultipleBatchesWithLookupJoin() throws IOException { var query = requestObjectBuilder().query(format(null, "from * | lookup join {} on integer {}", testIndexName(), sort)); Map result = runEsql(query); var columns = as(result.get("columns"), List.class); - assertEquals(21, columns.size()); + assertEquals(22, columns.size()); var values = as(result.get("values"), List.class); assertEquals(10, values.size()); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index 7a83c57fd2ebb..9b2160eec4398 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -121,6 +121,7 @@ public class CsvTestsDataLoader { private static final TestDataset SEMANTIC_TEXT = new TestDataset("semantic_text").withInferenceEndpoint(true); private static final TestDataset MV_TEXT = new TestDataset("mv_text"); private static final TestDataset DENSE_VECTOR = new TestDataset("dense_vector"); + private static final TestDataset COLORS = new TestDataset("colors"); public static final Map CSV_DATASET_MAP = Map.ofEntries( Map.entry(EMPLOYEES.indexName, EMPLOYEES), @@ -171,7 +172,8 @@ public class CsvTestsDataLoader { Map.entry(BOOKS.indexName, BOOKS), Map.entry(SEMANTIC_TEXT.indexName, SEMANTIC_TEXT), Map.entry(MV_TEXT.indexName, MV_TEXT), - Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR) + Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR), + Map.entry(COLORS.indexName, COLORS) ); private static final EnrichConfig LANGUAGES_ENRICH = new EnrichConfig("languages_policy", "enrich-policy-languages.json"); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv new file mode 100644 index 0000000000000..b82ef7087a54c --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv @@ -0,0 +1,60 @@ +color:text,hex_code:keyword,rgb_vector:dense_vector,primary:boolean +maroon, #800000, [128,0,0], false +brown, #A52A2A, [165,42,42], false +firebrick, #B22222, [178,34,34], false +crimson, #DC143C, [220,20,60], false +red, #FF0000, [255,0,0], true +tomato, #FF6347, [255,99,71], false +coral, #FF7F50, [255,127,80], false +salmon, #FA8072, [250,128,114], false +orange, #FFA500, [255,165,0], false +gold, #FFD700, [255,215,0], false +golden rod, #DAA520, [218,165,32], false +khaki, #F0E68C, [240,230,140], false +olive, #808000, [128,128,0], false +yellow, #FFFF00, [255,255,0], true +chartreuse, #7FFF00, [127,255,0], false +green, #008000, [0,128,0], true +lime, #00FF00, [0,255,0], false +teal, #008080, [0,128,128], false +cyan, #00FFFF, [0,255,255], true +turquoise, #40E0D0, [64,224,208], false +aqua marine, #7FFFD4, [127,255,212], false +navy, #000080, [0,0,128], false +blue, #0000FF, [0,0,255], true +indigo, #4B0082, [75,0,130], false +purple, #800080, [128,0,128], false +thistle, #D8BFD8, [216,191,216], false +plum, #DDA0DD, [221,160,221], false +violet, #EE82EE, [238,130,238], false +magenta, #FF00FF, [255,0,255], true +orchid, #DA70D6, [218,112,214], false +pink, #FFC0CB, [255,192,203], false +beige, #F5F5DC, [245,245,220], false +bisque, #FFE4C4, [255,228,196], false +wheat, #F5DEB3, [245,222,179], false +corn silk, #FFF8DC, [255,248,220], false +lemon chiffon, #FFFACD, [255,250,205], false +sienna, #A0522D, [160,82,45], false +chocolate, #D2691E, [210,105,30], false +peru, #CD853F, [205,133,63], false +burly wood, #DEB887, [222,184,135], false +tan, #D2B48C, [210,180,140], false +moccasin, #FFE4B5, [255,228,181], false +peach puff, #FFDAB9, [255,218,185], false +misty rose, #FFE4E1, [255,228,225], false +linen, #FAF0E6, [250,240,230], false +old lace, #FDF5E6, [253,245,230], false +papaya whip, #FFEFD5, [255,239,213], false +sea shell, #FFF5EE, [255,245,238], false +mint cream, #F5FFFA, [245,255,250], false +lavender, #E6E6FA, [230,230,250], false +honeydew, #F0FFF0, [240,255,240], false +ivory, #FFFFF0, [255,255,240], false +azure, #F0FFFF, [240,255,255], false +snow, #FFFAFA, [255,250,250], false +black, #000000, [0,0,0], true +gray, #808080, [128,128,128], true +silver, #C0C0C0, [192,192,192], false +gainsboro, #DCDCDC, [220,220,220], false +white, #FFFFFF, [255,255,255], true diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec new file mode 100644 index 0000000000000..5e65e6269e652 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -0,0 +1,285 @@ +# TODO Most tests explicitly set k. Until knn function uses LIMIT as k, we need to explicitly set it to all values +# in the dataset to avoid test failures due to docs allocation in different shards, which can impact results for a +# top-n query at the shard level + +knnSearch +required_capability: knn_function + +// tag::knn-function[] +from colors metadata _score +| where knn(rgb_vector, [0, 120, 0]) +| sort _score desc, color asc +// end::knn-function[] +| keep color, rgb_vector +| limit 10 +; + +// tag::knn-function-result[] +color:text | rgb_vector:dense_vector +green | [0.0, 128.0, 0.0] +black | [0.0, 0.0, 0.0] +olive | [128.0, 128.0, 0.0] +teal | [0.0, 128.0, 128.0] +lime | [0.0, 255.0, 0.0] +sienna | [160.0, 82.0, 45.0] +maroon | [128.0, 0.0, 0.0] +navy | [0.0, 0.0, 128.0] +gray | [128.0, 128.0, 128.0] +chartreuse | [127.0, 255.0, 0.0] +// end::knn-function-result[] +; + +knnSearchWithKOption +required_capability: knn_function + +// tag::knn-function-options[] +from colors metadata _score +| where knn(rgb_vector, [0,255,255], {"k": 4}) +| sort _score desc, color asc +// end::knn-function-options[] +| keep color, rgb_vector +| limit 4 +; + +color:text | rgb_vector:dense_vector +cyan | [0.0, 255.0, 255.0] +turquoise | [64.0, 224.0, 208.0] +aqua marine | [127.0, 255.0, 212.0] +teal | [0.0, 128.0, 128.0] +; + +knnSearchWithSimilarityOption +required_capability: knn_function + +from colors metadata _score +| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40}) +| sort _score desc, color asc +| keep color, rgb_vector +; + +color:text | rgb_vector:dense_vector +pink | [255.0, 192.0, 203.0] +peach puff | [255.0, 218.0, 185.0] +bisque | [255.0, 228.0, 196.0] +wheat | [245.0, 222.0, 179.0] + +; + +knnHybridSearch +required_capability: knn_function + +from colors metadata _score +| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140}) +| where primary == true +| sort _score desc, color asc +| keep color, rgb_vector +| limit 10 +; + +color:text | rgb_vector:dense_vector +blue | [0.0, 0.0, 255.0] +gray | [128.0, 128.0, 128.0] +cyan | [0.0, 255.0, 255.0] +magenta | [255.0, 0.0, 255.0] +green | [0.0, 128.0, 0.0] +white | [255.0, 255.0, 255.0] +black | [0.0, 0.0, 0.0] +red | [255.0, 0.0, 0.0] +yellow | [255.0, 255.0, 0.0] +; + +knnWithMultipleFunctions +required_capability: knn_function + +from colors metadata _score +| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive") +| sort _score desc, color asc +| keep color, rgb_vector +; + +color:text | rgb_vector:dense_vector +olive | [128.0, 128.0, 0.0] +; + +knnAfterKeep +required_capability: knn_function + +from colors metadata _score +| keep rgb_vector, color, _score +| where knn(rgb_vector, [128,255,0], {"k": 140}) +| sort _score desc, color asc +| keep rgb_vector +| limit 5 +; + +rgb_vector:dense_vector +[127.0, 255.0, 0.0] +[128.0, 128.0, 0.0] +[255.0, 255.0, 0.0] +[0.0, 255.0, 0.0] +[218.0, 165.0, 32.0] +; + +knnAfterDrop +required_capability: knn_function + +from colors metadata _score +| drop primary +| where knn(rgb_vector, [128,250,0], {"k": 140}) +| sort _score desc, color asc +| keep color, rgb_vector +| limit 5 +; + +color:text | rgb_vector: dense_vector +chartreuse | [127.0, 255.0, 0.0] +olive | [128.0, 128.0, 0.0] +yellow | [255.0, 255.0, 0.0] +golden rod | [218.0, 165.0, 32.0] +lime | [0.0, 255.0, 0.0] +; + +knnAfterEval +required_capability: knn_function + +from colors metadata _score +| eval composed_name = locate(color, " ") > 0 +| where knn(rgb_vector, [128,128,0], {"k": 140}) +| sort _score desc, color asc +| keep color, composed_name +| limit 5 +; + +color:text | composed_name:boolean +olive | false +sienna | false +chocolate | false +peru | false +golden rod | true +; + +knnWithConjunction +required_capability: knn_function + +# TODO We need kNN prefiltering here so we get more candidates that pass the filter +from colors metadata _score +| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*" +| sort _score desc, color asc +| keep color, hex_code, rgb_vector +| limit 10 +; + +color:text | hex_code:keyword | rgb_vector:dense_vector +ivory | #FFFFF0 | [255.0, 255.0, 240.0] +sea shell | #FFF5EE | [255.0, 245.0, 238.0] +snow | #FFFAFA | [255.0, 250.0, 250.0] +white | #FFFFFF | [255.0, 255.0, 255.0] +corn silk | #FFF8DC | [255.0, 248.0, 220.0] +lemon chiffon | #FFFACD | [255.0, 250.0, 205.0] +yellow | #FFFF00 | [255.0, 255.0, 0.0] +; + +knnWithDisjunctionAndFiltersConjunction +required_capability: knn_function + +# TODO We need kNN prefiltering here so we get more candidates that pass the filter +from colors metadata _score +| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true +| keep color, rgb_vector, _score +| sort _score desc, color asc +| drop _score +| limit 10 +; + +color:text | rgb_vector:dense_vector +cyan | [0.0, 255.0, 255.0] +blue | [0.0, 0.0, 255.0] +magenta | [255.0, 0.0, 255.0] +gray | [128.0, 128.0, 128.0] +white | [255.0, 255.0, 255.0] +green | [0.0, 128.0, 0.0] +black | [0.0, 0.0, 0.0] +red | [255.0, 0.0, 0.0] +yellow | [255.0, 255.0, 0.0] +; + +knnWithNonPushableConjunction +required_capability: knn_function + +from colors metadata _score +| eval composed_name = locate(color, " ") > 0 +| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false +| sort _score desc, color asc +| keep color, composed_name +| limit 10 +; + +color:text | composed_name:boolean +olive | false +sienna | false +chocolate | false +peru | false +brown | false +firebrick | false +chartreuse | false +gray | false +green | false +maroon | false +; + +testKnnWithNonPushableDisjunctions +required_capability: knn_function + +from colors metadata _score +| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10 +| sort _score desc, color asc +| keep color +; + +color:text +olive +aqua marine +lemon chiffon +papaya whip +; + +testKnnWithNonPushableDisjunctionsOnComplexExpressions +required_capability: knn_function + +from colors metadata _score +| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false) +| sort _score desc, color asc +| keep color, primary +; + +color:text | primary:boolean +olive | false +purple | false +indigo | false +; + +testKnnInStatsNonPushable +required_capability: knn_function + +from colors +| where length(color) < 10 +| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) +; + +c: long +50 +; + +testKnnInStatsWithGrouping +required_capability: knn_function +required_capability: full_text_functions_in_stats_where + +from colors +| where length(color) < 10 +| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary +; + +c: long | primary: boolean +41 | false +9 | true +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json index 17348adb6af4f..a7ef2f4840709 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json @@ -63,6 +63,9 @@ "semantic_text": { "type": "semantic_text", "inference_id": "foo_inference_id" + }, + "dense_vector": { + "type": "dense_vector" } } } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json new file mode 100644 index 0000000000000..24c4102e428f8 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json @@ -0,0 +1,20 @@ +{ + "properties": { + "color": { + "type": "text" + }, + "hex_code": { + "type": "keyword" + }, + "primary": { + "type": "boolean" + }, + "rgb_vector": { + "type": "dense_vector", + "similarity": "l2_norm", + "index_options": { + "type": "hnsw" + } + } + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 12631fdeaed5b..a130b026cd88a 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -16,6 +16,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.junit.Before; import java.io.IOException; @@ -127,9 +128,57 @@ public void testRetrieveDenseVectorFieldData() { } } + public void testNonIndexedDenseVectorField() throws IOException { + createIndexWithDenseVector("no_dense_vectors"); + + int numDocs = randomIntBetween(10, 100); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + for (int i = 0; i < numDocs; i++) { + docs[i] = prepareIndex("no_dense_vectors").setId("" + i).setSource("id", String.valueOf(i)); + } + + indexRandom(true, docs); + + var query = """ + FROM no_dense_vectors + | KEEP id, vector + """; + + try (var resp = run(query)) { + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(numDocs, valuesList.size()); + valuesList.forEach(value -> { + assertEquals(2, value.size()); + Integer id = (Integer) value.get(0); + assertNotNull(id); + Object vector = value.get(1); + assertNull(vector); + }); + } + } + @Before public void setup() throws IOException { - var indexName = "test"; + assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); + + createIndexWithDenseVector("test"); + + int numDims = randomIntBetween(32, 64) * 2; // min 64, even number + int numDocs = randomIntBetween(10, 100); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + for (int i = 0; i < numDocs; i++) { + List vector = new ArrayList<>(numDims); + for (int j = 0; j < numDims; j++) { + vector.add(randomFloat()); + } + docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); + indexedVectors.put(i, vector); + } + + indexRandom(true, docs); + } + + private void createIndexWithDenseVector(String indexName) throws IOException { var client = client().admin().indices(); XContentBuilder mapping = XContentFactory.jsonBuilder() .startObject() @@ -159,19 +208,5 @@ public void setup() throws IOException { .setMapping(mapping) .setSettings(settingsBuilder.build()); assertAcked(CreateRequest); - - int numDims = randomIntBetween(32, 64) * 2; // min 64, even number - int numDocs = randomIntBetween(10, 100); - IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); - for (int j = 0; j < numDims; j++) { - vector.add(randomFloat()); - } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); - indexedVectors.put(i, vector); - } - - indexRandom(true, docs); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java new file mode 100644 index 0000000000000..b5780cae49a4a --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; + +public class KnnFunctionIT extends AbstractEsqlIntegTestCase { + + private final Map> indexedVectors = new HashMap<>(); + private int numDocs; + private int numDims; + + public void testKnnDefaults() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s) + | KEEP id, floats, _score, vector + | SORT _score DESC + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); + for (int i = 0; i < valuesList.size(); i++) { + List row = valuesList.get(i); + // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases + assertEquals(i, row.get(0)); + @SuppressWarnings("unchecked") + // Vectors should be the same + List floats = (List) row.get(1); + for (int j = 0; j < floats.size(); j++) { + assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f); + } + var score = (Double) row.get(2); + assertNotNull(score); + assertTrue(score > 0.0); + } + } + } + + public void testKnnOptions() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, {"k": 5}) + | KEEP id, floats, _score, vector + | SORT _score DESC + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(5, valuesList.size()); + } + } + + public void testKnnNonPushedDown() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, {"k": 5}) OR id > 10 + | KEEP id, floats, _score, vector + | SORT _score DESC + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + // K = 5, 1 more for every id > 10 + assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); + } + } + + @Before + public void setup() throws IOException { + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + + var indexName = "test"; + var client = client().admin().indices(); + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("id") + .field("type", "integer") + .endObject() + .startObject("vector") + .field("type", "dense_vector") + .field("similarity", "l2_norm") + .endObject() + .startObject("floats") + .field("type", "float") + .endObject() + .endObject() + .endObject(); + + Settings.Builder settingsBuilder = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1); + + var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); + assertAcked(createRequest); + + numDocs = randomIntBetween(10, 20); + numDims = randomIntBetween(3, 10); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + float value = 0.0f; + for (int i = 0; i < numDocs; i++) { + List vector = new ArrayList<>(numDims); + for (int j = 0; j < numDims; j++) { + vector.add(value++); + } + docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + indexedVectors.put(i, vector); + } + + indexRandom(true, docs); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 5c58d30bc452f..6db522acf19ca 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -925,7 +925,12 @@ public enum Cap { /** * Allow lookup join on mixed numeric fields, among byte, short, int, long, half_float, scaled_float, float and double. */ - LOOKUP_JOIN_ON_MIXED_NUMERIC_FIELDS; + LOOKUP_JOIN_ON_MIXED_NUMERIC_FIELDS, + + /** + * Support knn function + */ + KNN_FUNCTION(Build.current().isSnapshot()); private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 9aaa659f181b0..d0a479f476b6c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -62,6 +62,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; @@ -1108,6 +1109,9 @@ private static Expression cast(org.elasticsearch.xpack.esql.core.expression.func if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) { return processBinaryOperator((BinaryOperator) f); } + if (f instanceof VectorFunction vectorFunction) { + return processVectorFunction(f); + } return f; } @@ -1307,6 +1311,25 @@ private static Expression castStringLiteral(Expression from, DataType target) { return unresolvedAttribute(from, target.toString(), e); } } + + private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) { + List args = vectorFunction.arguments(); + List newArgs = new ArrayList<>(); + for (Expression arg : args) { + if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) { + Object folded = arg.fold(FoldContext.small() /* TODO remove me */); + if (folded instanceof List) { + Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR); + newArgs.add(denseVector); + continue; + } + } + newArgs.add(arg); + } + + return vectorFunction.replaceChildren(newArgs); + } + } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 4246cf32cf595..be20a0f6a35f9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; @@ -81,6 +82,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull; @@ -113,6 +115,7 @@ public static List getNamedWriteables() { entries.addAll(binaryComparisons()); entries.addAll(fullText()); entries.addAll(unaryScalars()); + entries.addAll(vector()); return entries; } @@ -248,4 +251,11 @@ private static List binaryComparisons() { private static List fullText() { return FullTextWritables.getNamedWriteables(); } + + private static List vector() { + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + return List.of(Knn.ENTRY); + } + return List.of(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 61e8a02de0443..69752b9985263 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -161,6 +161,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -454,7 +455,8 @@ private static FunctionDefinition[][] snapshotFunctions() { // This is an experimental function and can be removed without notice. def(Delay.class, Delay::new, "delay"), def(Rate.class, Rate::withUnresolvedTimestamp, "rate"), - def(Term.class, bi(Term::new), "term") } }; + def(Term.class, bi(Term::new), "term"), + def(Knn.class, tri(Knn::new), "knn") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index 4d534768cbeae..7e2e3d459d477 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.expression.function.fulltext; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.ShardConfig; import org.elasticsearch.compute.lucene.LuceneQueryExpressionEvaluator; @@ -17,15 +18,23 @@ import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.EntryExpression; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypeConverter; +import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; @@ -42,12 +51,15 @@ import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Predicate; +import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; @@ -134,7 +146,7 @@ public String functionType() { @Override public int hashCode() { - return Objects.hash(super.hashCode(), queryBuilder); + return Objects.hash(super.hashCode(), query, queryBuilder); } @Override @@ -143,7 +155,7 @@ public boolean equals(Object obj) { return false; } - return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder); + return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder) && Objects.equals(query, ((FullTextFunction) obj).query); } @Override @@ -321,4 +333,57 @@ public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { } return new LuceneQueryScoreEvaluator.Factory(shardConfigs); } + + protected static void populateOptionsMap( + final MapExpression options, + final Map optionsMap, + final TypeResolutions.ParamOrdinal paramOrdinal, + final String sourceText, + final Map allowedOptions + ) throws InvalidArgumentException { + for (EntryExpression entry : options.entryExpressions()) { + Expression optionExpr = entry.key(); + Expression valueExpr = entry.value(); + TypeResolution resolution = isFoldable(optionExpr, sourceText, paramOrdinal).and( + isFoldable(valueExpr, sourceText, paramOrdinal) + ); + if (resolution.unresolved()) { + throw new InvalidArgumentException(resolution.message()); + } + Object optionExprLiteral = ((Literal) optionExpr).value(); + Object valueExprLiteral = ((Literal) valueExpr).value(); + String optionName = optionExprLiteral instanceof BytesRef br ? br.utf8ToString() : optionExprLiteral.toString(); + String optionValue = valueExprLiteral instanceof BytesRef br ? br.utf8ToString() : valueExprLiteral.toString(); + // validate the optionExpr is supported + DataType dataType = allowedOptions.get(optionName); + if (dataType == null) { + throw new InvalidArgumentException( + format(null, "Invalid option [{}] in [{}], expected one of {}", optionName, sourceText, allowedOptions.keySet()) + ); + } + try { + optionsMap.put(optionName, DataTypeConverter.convert(optionValue, dataType)); + } catch (InvalidArgumentException e) { + throw new InvalidArgumentException(format(null, "Invalid option [{}] in [{}], {}", optionName, sourceText, e.getMessage())); + } + } + } + + public static String getNameFromFieldAttribute(FieldAttribute fieldAttribute) { + String fieldName = fieldAttribute.name(); + if (fieldAttribute.field() instanceof MultiTypeEsField multiTypeEsField) { + // If we have multiple field types, we allow the query to be done, but getting the underlying field name + fieldName = multiTypeEsField.getName(); + } + return fieldName; + } + + public static FieldAttribute fieldAsFieldAttribute(Expression field) { + Expression fieldExpression = field; + // Field may be converted to other data type (field_name :: data_type), so we need to check the original field + if (fieldExpression instanceof AbstractConvertFunction convertFunction) { + fieldExpression = convertFunction.field(); + } + return fieldExpression instanceof FieldAttribute fieldAttribute ? fieldAttribute : null; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java new file mode 100644 index 0000000000000..c79868b73fb0e --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -0,0 +1,285 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.Check; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; +import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static java.util.Map.entry; +import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; + +public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); + + private final Expression field; + private final Expression options; + + public static final Map ALLOWED_OPTIONS = Map.ofEntries( + entry(K_FIELD.getPreferredName(), INTEGER), + entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), + entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT), + entry(BOOST_FIELD.getPreferredName(), FLOAT), + entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT) + ); + + @FunctionInfo( + returnType = "boolean", + preview = true, + description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. " + + "knn function finds nearest vectors through approximate search on indexed dense_vectors.", + examples = { @Example(file = "knn-function", tag = "knn-function"), @Example(file = "knn-function", tag = "knn-function-options"), } + ) + public Knn( + Source source, + @Param(name = "field", type = { "dense_vector" }, description = "Field that the query will target.") Expression field, + @Param( + name = "query", + type = { "dense_vector" }, + description = "Vector value to find top nearest neighbours for." + ) Expression query, + @MapParam( + name = "options", + params = { + @MapParam.MapParamEntry( + name = "boost", + type = "float", + valueHint = { "2.5" }, + description = "Floating point number used to decrease or increase the relevance scores of the query." + + "Defaults to 1.0." + ), + @MapParam.MapParamEntry( + name = "k", + type = "integer", + valueHint = { "10" }, + description = "The number of nearest neighbors to return from each shard. " + + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " + + "This value must be less than or equal to num_candidates. Defaults to 10." + ), + @MapParam.MapParamEntry( + name = "num_candidates", + type = "integer", + valueHint = { "10" }, + description = "The number of nearest neighbor candidates to consider per shard while doing knn search. " + + "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. " + + "Defaults to 1.5 * k" + ), + @MapParam.MapParamEntry( + name = "similarity", + type = "double", + valueHint = { "0.01" }, + description = "The minimum similarity required for a document to be considered a match. " + + "The similarity value calculated relates to the raw similarity used, not the document score." + ), + @MapParam.MapParamEntry( + name = "rescore_oversample", + type = "double", + valueHint = { "3.5" }, + description = "Applies the specified oversampling for rescoring quantized vectors. " + + "See [oversampling and rescoring quantized vectors]" + + "(docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details." + ), }, + description = "(Optional) kNN additional options as <>." + + " See <> for more information.", + optional = true + ) Expression options + ) { + this(source, field, query, options, null); + } + + private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) { + super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder); + this.field = field; + this.options = options; + } + + public Expression field() { + return field; + } + + public Expression options() { + return options; + } + + @Override + public DataType dataType() { + return DataType.BOOLEAN; + } + + @Override + protected TypeResolution resolveParams() { + return resolveField().and(resolveQuery()).and(resolveOptions()); + } + + private TypeResolution resolveField() { + return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector")); + } + + private TypeResolution resolveQuery() { + return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and( + isNotNullAndFoldable(query(), sourceText(), SECOND) + ); + } + + private TypeResolution resolveOptions() { + if (options() != null) { + TypeResolution resolution = isNotNull(options(), sourceText(), THIRD); + if (resolution.unresolved()) { + return resolution; + } + // MapExpression does not have a DataType associated with it + resolution = isMapExpression(options(), sourceText(), THIRD); + if (resolution.unresolved()) { + return resolution; + } + + try { + knnQueryOptions(); + } catch (InvalidArgumentException e) { + return new TypeResolution(e.getMessage()); + } + } + return TypeResolution.TYPE_RESOLVED; + } + + private Map knnQueryOptions() throws InvalidArgumentException { + if (options() == null) { + return Map.of(); + } + + Map matchOptions = new HashMap<>(); + populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS); + return matchOptions; + } + + @Override + protected Query translate(TranslatorHandler handler) { + var fieldAttribute = fieldAsFieldAttribute(field()); + + Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); + String fieldName = getNameFromFieldAttribute(fieldAttribute); + @SuppressWarnings("unchecked") + List queryFolded = (List) query().fold(FoldContext.small() /* TODO remove me */); + float[] queryAsFloats = new float[queryFolded.size()]; + for (int i = 0; i < queryFolded.size(); i++) { + queryAsFloats[i] = queryFolded.get(i).floatValue(); + } + + return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions()); + } + + @Override + public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { + return new Knn(source(), field(), query(), options(), queryBuilder); + } + + private Map queryOptions() throws InvalidArgumentException { + if (options() == null) { + return Map.of(); + } + + Map options = new HashMap<>(); + populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS); + return options; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new Knn( + source(), + newChildren.get(0), + newChildren.get(1), + newChildren.size() > 2 ? newChildren.get(2) : null, + queryBuilder() + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Knn::new, field(), query(), options()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + private static Knn readFrom(StreamInput in) throws IOException { + Source source = Source.readFrom((PlanStreamInput) in); + Expression field = in.readNamedWriteable(Expression.class); + Expression query = in.readNamedWriteable(Expression.class); + QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); + + return new Knn(source, field, query, null, queryBuilder); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field()); + out.writeNamedWriteable(query()); + out.writeOptionalNamedWriteable(queryBuilder()); + } + + @Override + public boolean equals(Object o) { + // Knn does not serialize options, as they get included in the query builder. We need to override equals and hashcode to + // ignore options when comparing two Knn functions + if (o == null || getClass() != o.getClass()) return false; + Knn knn = (Knn) o; + return Objects.equals(field(), knn.field()) + && Objects.equals(query(), knn.query()) + && Objects.equals(queryBuilder(), knn.queryBuilder()); + } + + @Override + public int hashCode() { + return Objects.hash(field(), query(), queryBuilder()); + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java new file mode 100644 index 0000000000000..dc0be7a29fee0 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +/** + * Marker interface for vector functions. Makes possible to do implicit casting + * from multi values to dense_vector field types, so parameters are actually + * processed as dense_vectors in vector functions + */ +public interface VectorFunction {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java new file mode 100644 index 0000000000000..aa0e896dfc013 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.querydsl.query; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; + +public class KnnQuery extends Query { + + private final String field; + private final float[] query; + private final Map options; + + public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; + + public KnnQuery(Source source, String field, float[] query, Map options) { + super(source); + assert options != null; + this.field = field; + this.query = query; + this.options = options; + } + + @Override + protected QueryBuilder asBuilder() { + Integer k = (Integer) options.get(K_FIELD.getPreferredName()); + Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName()); + RescoreVectorBuilder rescoreVectorBuilder = null; + Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD); + if (oversample != null) { + rescoreVectorBuilder = new RescoreVectorBuilder(oversample); + } + Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName()); + + KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity); + Number boost = (Number) options.get(BOOST_FIELD.getPreferredName()); + if (boost != null) { + queryBuilder.boost(boost.floatValue()); + } + return queryBuilder; + } + + @Override + protected String innerToString() { + return "knn(" + field + ", " + Arrays.toString(query) + " options={" + options + "}))"; + } + + @Override + public boolean equals(Object o) { + if (super.equals(o) == false) return false; + + KnnQuery knnQuery = (KnnQuery) o; + return Objects.equals(field, knnQuery.field) + && Objects.deepEquals(query, knnQuery.query) + && Objects.equals(options, knnQuery.options); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options); + } + + @Override + public boolean scorable() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 30756957416f3..d44dcfdf8f72b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -273,6 +273,10 @@ public final void test() throws Throwable { "can't use KQL function in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KQL_FUNCTION.capabilityName()) ); + assumeFalse( + "can't use KNN function in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName()) + ); assumeFalse( "lookup join disabled for csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.JOIN_LOOKUP_V12.capabilityName()) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java index 8e396e4753f09..e55a1b039258e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.WildcardQueryBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.test.EqualsHashCodeTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.expression.ExpressionWritables; @@ -111,6 +112,7 @@ public static NamedWriteableRegistry writableRegistry() { entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, WildcardQueryBuilder.NAME, WildcardQueryBuilder::new)); entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RegexpQueryBuilder.NAME, RegexpQueryBuilder::new)); entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, ExistsQueryBuilder.NAME, ExistsQueryBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KnnVectorQueryBuilder.NAME, KnnVectorQueryBuilder::new)); entries.add(SingleValueQuery.ENTRY); entries.addAll(ExpressionWritables.getNamedWriteables()); entries.addAll(PlanWritables.getNamedWriteables()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 99356d88cec5a..88b7c8800aeec 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDatetime; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.index.EsIndex; @@ -2332,6 +2333,22 @@ public void testImplicitCasting() { assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]")); } + public void testDenseVectorImplicitCasting() { + Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); + + var plan = analyze(""" + from test | where knn(vector, [0.342, 0.164, 0.234]) + """, "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + var field = knn.field(); + var queryVector = as(knn.query(), Literal.class); + assertEquals(DataType.DENSE_VECTOR, queryVector.dataType()); + assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234))); + } + public void testRateRequiresCounterTypes() { assumeTrue("rate requires snapshot builds", Build.current().isSnapshot()); Analyzer analyzer = analyzer(tsdbIndexResolution()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index f4a24773342d9..8d5e1cfa03c1f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.core.type.UnsupportedEsField; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.parser.EsqlParser; @@ -1229,6 +1230,9 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])"); + } } private void checkFieldBasedFunctionNotAllowedAfterCommands(String functionName, String functionType, String functionInvocation) { @@ -1355,6 +1359,9 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("Term", "term(title, \"Meditation\")", "function"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function"); + } } private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, String functionInvocation, String functionType) @@ -1387,6 +1394,9 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])"); + } } private void checkWithFullTextFunctionsDisjunctions(String functionInvocation) { @@ -1445,6 +1455,9 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function"); + } } private void checkFullTextFunctionsWithNonBooleanFunctions(String functionName, String functionInvocation, String functionType) { @@ -1509,6 +1522,9 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])"); + } } private void testFullTextFunctionTargetsExistingField(String functionInvocation) throws Exception { @@ -2027,6 +2043,9 @@ public void testLookupJoinDataTypeMismatch() { public void testFullTextFunctionOptions() { checkOptionDataTypes(Match.ALLOWED_OPTIONS, "FROM test | WHERE match(title, \"Jean\", {\"%s\": %s})"); checkOptionDataTypes(QueryString.ALLOWED_OPTIONS, "FROM test | WHERE QSTR(\"title: Jean\", {\"%s\": %s})"); + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})"); + } } /** @@ -2102,6 +2121,10 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first"); + checkFullTextFunctionNullArgs("knn(vector, null)", "second"); + } } private void checkFullTextFunctionNullArgs(String functionInvocation, String argOrdinal) throws Exception { @@ -2118,6 +2141,9 @@ public void testFullTextFunctionsConstantQuery() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsConstantQuery("term(title, tags)", "second"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second"); + } } private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception { @@ -2132,6 +2158,10 @@ public void testFullTextFunctionsInStats() { checkFullTextFunctionsInStats("title : \"Meditation\""); checkFullTextFunctionsInStats("qstr(\"title: Meditation\")"); checkFullTextFunctionsInStats("kql(\"title: Meditation\")"); + checkFullTextFunctionsInStats("match_phrase(title, \"Meditation\")"); + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])"); + } } private void checkFullTextFunctionsInStats(String functionInvocation) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java new file mode 100644 index 0000000000000..c2bc381e2663c --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.fulltext; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize; +import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; +import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED; +import static org.elasticsearch.xpack.esql.planner.TranslatorHandler.TRANSLATOR_HANDLER; +import static org.hamcrest.Matchers.equalTo; + +public class KnnTests extends AbstractFunctionTestCase { + + public KnnTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + return parameterSuppliersFromTypedData(addFunctionNamedParams(testCaseSuppliers())); + } + + @Before + public void checkCapability() { + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + } + + private static List testCaseSuppliers() { + List suppliers = new ArrayList<>(); + + suppliers.add( + TestCaseSupplier.testCaseSupplier( + new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR), + new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true), + (d1, d2) -> equalTo("string"), + BOOLEAN, + (o1, o2) -> true + ) + ); + + return suppliers; + } + + private static List randomDenseVector() { + int dimensions = randomIntBetween(64, 128); + List vector = new ArrayList<>(); + for (int i = 0; i < dimensions; i++) { + vector.add(randomFloat()); + } + return vector; + } + + /** + * Adds function named parameters to all the test case suppliers provided + */ + private static List addFunctionNamedParams(List suppliers) { + // TODO get to a common class with MatchTests + List result = new ArrayList<>(); + for (TestCaseSupplier supplier : suppliers) { + List dataTypes = new ArrayList<>(supplier.types()); + dataTypes.add(UNSUPPORTED); + result.add(new TestCaseSupplier(supplier.name() + ", options", dataTypes, () -> { + List values = new ArrayList<>(supplier.get().getData()); + values.add( + new TestCaseSupplier.TypedData( + new MapExpression(Source.EMPTY, List.of(new Literal(Source.EMPTY, randomAlphaOfLength(10), KEYWORD))), + UNSUPPORTED, + "options" + ).forceLiteral() + ); + + return new TestCaseSupplier.TestCase(values, equalTo("KnnEvaluator"), BOOLEAN, equalTo(true)); + })); + } + return result; + } + + @Override + protected Expression build(Source source, List args) { + Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null); + // We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and + // thus test the serialization methods. But we can only do this if the parameters make sense . + if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { + QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, knn).toQueryBuilder(); + knn = (Knn) knn.replaceQueryBuilder(queryBuilder); + } + return knn; + } + + /** + * Copy of the overridden method that doesn't check for children size, as the {@code options} child isn't serialized in Match. + */ + @Override + protected Expression serializeDeserializeExpression(Expression expression) { + Expression newExpression = serializeDeserialize( + expression, + PlanStreamOutput::writeNamedWriteable, + in -> in.readNamedWriteable(Expression.class), + testCase.getConfiguration() // The configuration query should be == to the source text of the function for this to work + ); + // Fields use synthetic sources, which can't be serialized. So we use the originals instead. + return newExpression.replaceChildren(expression.children()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java index 6993f7583dd02..301cbd6844afe 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java @@ -82,7 +82,7 @@ protected Expression build(Source source, List args) { // thus test the serialization methods. But we can only do this if the parameters make sense . if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, match).toQueryBuilder(); - match.replaceQueryBuilder(queryBuilder); + match = (Match) match.replaceQueryBuilder(queryBuilder); } return match; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 9465d6e8665ce..a1b5b994850c3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -25,6 +25,8 @@ import org.elasticsearch.index.query.QueryStringQueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -1251,6 +1253,29 @@ public void testQStrOptionsPushDown() { assertThat(expectedQStrQuery.toString(), is(planStr.get())); } + public void testKnnOptionsPushDown() { + String query = """ + from test + | where KNN(dense_vector, [0.1, 0.2, 0.3], + { "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) + """; + var analyzer = makeAnalyzer("mapping-all-types.json", new EnrichResolution()); + var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer); + + AtomicReference planStr = new AtomicReference<>(); + plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); + + var expectedQuery = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0.1f, 0.2f, 0.3f }, + 5, + 10, + new RescoreVectorBuilder(7), + 0.001f + ).boost(3.5f); + assertThat(expectedQuery.toString(), is(planStr.get())); + } + /** * Expecting * LimitExec[1000[INTEGER]] From dd447d333fa3ac04d62a897f6b1f32ea88468ee3 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 10 Jun 2025 17:48:44 +0200 Subject: [PATCH 2/2] Fix tests --- .../org/elasticsearch/xpack/esql/analysis/VerifierTests.java | 1 - .../yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 8d5e1cfa03c1f..0677eb01b5231 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -2158,7 +2158,6 @@ public void testFullTextFunctionsInStats() { checkFullTextFunctionsInStats("title : \"Meditation\""); checkFullTextFunctionsInStats("qstr(\"title: Meditation\")"); checkFullTextFunctionsInStats("kql(\"title: Meditation\")"); - checkFullTextFunctionsInStats("match_phrase(title, \"Meditation\")"); if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])"); } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 0b43ee1cc2150..071b44b65e98c 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -101,7 +101,7 @@ setup: - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 145} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 146} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version":