diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md b/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md index df15bde7deb55..9b6d20b551e7a 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/knn.md @@ -4,7 +4,7 @@ ```esql from colors metadata _score -| where knn(rgb_vector, [0, 120, 0], 10) +| where knn(rgb_vector, [0, 120, 0]) | sort _score desc, color asc ``` diff --git a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md index 1e87271707676..f38a8e8d84584 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md +++ b/docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md @@ -2,12 +2,12 @@ **Supported function named parameters** -`num_candidates` -: (integer) The number of nearest neighbor candidates to consider per shard while doing knn search. Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * k - `boost` : (float) Floating point number used to decrease or increase the relevance scores of the query.Defaults to 1.0. +`min_candidates` +: (integer) The minimum number of nearest neighbor candidates to consider per shard while doing knn search. KNN may use a higher number of candidates in case the query can't use a approximate results. Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * LIMIT used for the query. + `rescore_oversample` : (double) Applies the specified oversampling for rescoring quantized vectors. See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details. diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md index e33acabbd014f..fb1b98a1e8a7a 100644 --- a/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md @@ -8,9 +8,6 @@ `query` : Vector value to find top nearest neighbours for. -`k` -: The number of nearest neighbors to return from each shard. Elasticsearch collects k results from each shard, then merges them to find the global top results. This value must be less than or equal to num_candidates. - `options` : (Optional) kNN additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params). See [knn query](/reference/query-languages/query-dsl/query-dsl-match-query.md#query-dsl-knn-query) for more information. diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg index 6e20dbc217206..75a104a7cdcfa 100644 --- a/docs/reference/query-languages/esql/images/functions/knn.svg +++ b/docs/reference/query-languages/esql/images/functions/knn.svg @@ -1 +1 @@ -KNN(field,query,k,options) \ No newline at end of file +KNN(field,query,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json index d347891393dcf..f4b77305a200b 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json @@ -5,7 +5,7 @@ "description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.", "signatures" : [ ], "examples" : [ - "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc" + "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc, color asc" ], "preview" : true, "snapshot_only" : true diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md index f32319b080dbb..bea09b0bf50de 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md @@ -5,6 +5,6 @@ Finds the k nearest vectors to a query vector, as measured by a similarity metri ```esql from colors metadata _score -| where knn(rgb_vector, [0, 120, 0], 10) +| where knn(rgb_vector, [0, 120, 0]) | sort _score desc, color asc ``` diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java index c7f187c6c4a8f..a4561978bedff 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java @@ -60,10 +60,16 @@ protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards) } public Block executeQuery(Page page) { - // Lucene based operators retrieve DocVectors as first block - Block block = page.getBlock(0); - assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input"; - DocVector docs = (DocVector) block.asVector(); + // Search for DocVector block + Block docBlock = null; + for (int i = 0; i < page.getBlockCount(); i++) { + if (page.getBlock(i) instanceof DocBlock) { + docBlock = page.getBlock(i); + break; + } + } + assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock"; + DocVector docs = (DocVector) docBlock.asVector(); try { if (docs.singleSegmentNonDecreasing()) { return evalSingleSegmentNonDecreasing(docs); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java index 2afc885d71124..1c3d522fda5ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java @@ -9,7 +9,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.Page; @@ -46,9 +45,9 @@ public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int sco @Override protected Page process(Page page) { - assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount(); - assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector(); - assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector(); + assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition; + assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector + : "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector(); Block[] blocks = new Block[page.getBlockCount()]; for (int i = 0; i < page.getBlockCount(); i++) { diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index 2cad34e324fda..7a0e854f63f90 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,11 +3,11 @@ # top-n query at the shard level knnSearch -required_capability: knn_function_v3 +required_capability: knn_function_v4 // tag::knn-function[] from colors metadata _score -| where knn(rgb_vector, [0, 120, 0], 10) +| where knn(rgb_vector, [0, 120, 0]) | sort _score desc, color asc // end::knn-function[] | keep color, rgb_vector @@ -30,10 +30,10 @@ chartreuse | [127.0, 255.0, 0.0] ; knnSearchWithSimilarityOption -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) +| where knn(rgb_vector, [255,192,203], {"similarity": 40}) | sort _score desc, color asc | keep color, rgb_vector ; @@ -46,13 +46,14 @@ wheat | [245.0, 222.0, 179.0] ; knnHybridSearch -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10) +| where match(color, "blue") or knn(rgb_vector, [65,105,225]) | where primary == true | sort _score desc, color asc | keep color, rgb_vector +| limit 10 ; color:text | rgb_vector:dense_vector @@ -68,10 +69,10 @@ yellow | [255.0, 255.0, 0.0] ; knnWithPrefilter -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors -| where knn(rgb_vector, [120,180,0], 10) and (match(color, "olive") or match(color, "green")) +| where knn(rgb_vector, [120,180,0]) and (match(color, "olive") or match(color, "green")) | sort color asc | keep color ; @@ -82,10 +83,10 @@ olive ; knnWithNegatedPrefilter -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) +| where knn(rgb_vector, [128,128,0]) and not (match(color, "olive") or match(color, "chocolate")) | sort _score desc, color asc | keep color, rgb_vector | LIMIT 10 @@ -105,11 +106,11 @@ orange | [255.0, 165.0, 0.0] ; knnAfterKeep -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | keep rgb_vector, color, _score -| where knn(rgb_vector, [128,255,0], 140) +| where knn(rgb_vector, [128,255,0]) | sort _score desc, color asc | keep rgb_vector | limit 5 @@ -124,11 +125,11 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | drop primary -| where knn(rgb_vector, [128,250,0], 140) +| where knn(rgb_vector, [128,250,0]) | sort _score desc, color asc | keep color, rgb_vector | limit 5 @@ -143,11 +144,11 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], 140) +| where knn(rgb_vector, [128,128,0]) | sort _score desc, color asc | keep color, composed_name | limit 5 @@ -162,12 +163,13 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" +| where knn(rgb_vector, [255,255,238]) and hex_code like "#FFF*" | sort _score desc, color asc | keep color, hex_code, rgb_vector +| limit 10 ; color:text | hex_code:keyword | rgb_vector:dense_vector @@ -181,10 +183,10 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true +| where (knn(rgb_vector, [0,255,255]) or knn(rgb_vector, [128, 0, 255])) and primary == true | keep color, rgb_vector, _score | sort _score desc, color asc | drop _score @@ -204,10 +206,10 @@ yellow | [255.0, 255.0, 0.0] ; knnWithNegationsAndFiltersConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) +| where (knn(rgb_vector, [0,255,255]) and not(primary == true and match(color, "blue"))) | sort _score desc, color asc | keep color, rgb_vector | limit 10 @@ -227,11 +229,11 @@ azure | [240.0, 255.0, 255.0] ; knnWithNonPushableConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], 140) and composed_name == false +| where knn(rgb_vector, [128,128,0], {"min_candidates": 100}) and composed_name == false | sort _score desc, color asc | keep color, composed_name | limit 10 @@ -251,58 +253,88 @@ maroon | false ; testKnnWithNonPushableDisjunctions -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 +| where knn(rgb_vector, [128,128,0]) or length(color) > 10 | sort _score desc, color asc -| keep color +| keep color +| limit 10 ; color:text -olive -aqua marine -lemon chiffon -papaya whip +olive +sienna +chocolate +peru +golden rod +brown +firebrick +chartreuse +gray +green ; -testKnnWithNonPushableDisjunctionsOnComplexExpressions -required_capability: knn_function_v3 +testKnnWithNonPushableDisjunctionsAndMinCandidates +required_capability: knn_function_v4 from colors metadata _score -| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) +| where (knn(rgb_vector, [128,128,0], {"min_candidates": 2}) and length(color) > 10) or (knn(rgb_vector, [128,0,128], {"min_candidates": 2}) and primary == true) | sort _score desc, color asc | keep color, primary ; color:text | primary:boolean -olive | false -purple | false -indigo | false -; +gray | true +green | true +red | true +black | true +magenta | true +yellow | true +blue | true +aqua marine | false +papaya whip | false +lemon chiffon | false +white | true +cyan | true +; + +testKnnWithStats +required_capability: knn_function_v4 -testKnnInStatsNonPushable -required_capability: knn_function_v3 - -from colors -| where length(color) < 10 -| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) +from colors metadata _score +| where knn(rgb_vector, [128,128,0]) +| sort _score desc, color asc +| limit 15 +| stats c = count(*) ; -c: long -50 +c:long +15 ; -testKnnInStatsWithGrouping -required_capability: knn_function_v3 -required_capability: full_text_functions_in_stats_where +testKnnWithRerank +required_capability: knn_function_v4 +required_capability: rerank -from colors -| where length(color) < 10 -| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary +from colors metadata _score +| where knn(rgb_vector, [100,120,0]) +| sort _score desc, color asc +| limit 10 +| rerank rerank_score = "deepest blue" ON color WITH { "inference_id" : "test_reranker" } +| sort rerank_score desc, color asc +| keep color ; -c: long | primary: boolean -41 | false -9 | true +color:text +gray +peru +brown +green +olive +maroon +sienna +chocolate +firebrick +golden rod ; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index d44a9b458b082..21ec240d9f8f4 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -74,9 +74,10 @@ public void testKnnDefaults() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 10) + | WHERE knn(vector, %s) | KEEP id, _score, vector | SORT _score DESC + | LIMIT 10 """, Arrays.toString(queryVector)); try (var resp = run(query)) { @@ -113,9 +114,10 @@ public void testKnnOptions() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) + | WHERE knn(vector, %s) | KEEP id, _score, vector | SORT _score DESC + | LIMIT 5 """, Arrays.toString(queryVector)); try (var resp = run(query)) { @@ -131,12 +133,12 @@ public void testKnnNonPushedDown() { float[] queryVector = new float[numDims]; Arrays.fill(queryVector, 0.0f); - // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) OR id > 100 + | WHERE knn(vector, %s) OR id > 100 | KEEP id, _score, vector | SORT _score DESC + | LIMIT 5 """, Arrays.toString(queryVector)); try (var resp = run(query)) { @@ -155,7 +157,7 @@ public void testKnnWithPrefilters() { // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10 + | WHERE knn(vector, %s) AND id > 5 AND id <= 10 | KEEP id, _score, vector | SORT _score DESC | LIMIT 5 @@ -178,7 +180,8 @@ public void testKnnWithLookupJoin() { var query = String.format(Locale.ROOT, """ FROM test | LOOKUP JOIN test_lookup ON id - | WHERE KNN(lookup_vector, %s, 5) OR id > 100 + | WHERE KNN(lookup_vector, %s) OR id > 100 + | LIMIT 5 """, Arrays.toString(queryVector)); var error = expectThrows(VerificationException.class, () -> run(query)); @@ -193,7 +196,7 @@ public void testKnnWithLookupJoin() { @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var indexName = "test"; var client = client().admin().indices(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 19eac8bd9ad03..9a69e9c86fe10 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1291,7 +1291,7 @@ public enum Cap { /** * Support knn function */ - KNN_FUNCTION_V3(Build.current().isSnapshot()), + KNN_FUNCTION_V4(Build.current().isSnapshot()), /** * Support for the LIKE operator with a list of wildcards. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 9b794d9b9b7b5..f4d20dcafd1a0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -505,7 +505,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), def(Score.class, uni(Score::new), Score.NAME), def(Term.class, bi(Term::new), "term"), - def(Knn.class, quad(Knn::new), "knn"), + def(Knn.class, tri(Knn::new), "knn"), def(ToGeohash.class, ToGeohash::new, "to_geohash"), def(ToGeotile.class, ToGeotile::new, "to_geotile"), def(ToGeohex.class, ToGeohex::new, "to_geohex"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index c273da317dec2..c9e23fdd29387 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -384,18 +384,29 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()]; int i = 0; for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) { - shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher()); + shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher()); } return new LuceneQueryExpressionEvaluator.Factory(shardConfigs); } + /** + * Returns the query builder to be used when the function cannot be pushed down to Lucene, but uses a + * {@link org.elasticsearch.compute.lucene.LuceneQueryEvaluator} instead + * + * @return the query builder to be used in the {@link org.elasticsearch.compute.lucene.LuceneQueryEvaluator} + */ + protected QueryBuilder evaluatorQueryBuilder() { + // Use the same query builder as for the translation by default + return queryBuilder(); + } + @Override public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { List shardContexts = toScorer.shardContexts(); ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()]; int i = 0; for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) { - shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher()); + shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher()); } return new LuceneQueryScoreEvaluator.Factory(shardConfigs); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 0b64fb43909df..9add14da034b5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -7,15 +7,17 @@ package org.elasticsearch.xpack.esql.expression.function.vector; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; +import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -54,14 +56,11 @@ import static java.util.Map.entry; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; -import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; -import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; @@ -70,20 +69,26 @@ import static org.elasticsearch.xpack.esql.expression.Foldables.TypeResolutionValidator.forPreOptimizationValidation; import static org.elasticsearch.xpack.esql.expression.Foldables.resolveTypeQuery; -public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { - private final Logger log = LogManager.getLogger(getClass()); +public class Knn extends FullTextFunction + implements + OptionalArgument, + VectorFunction, + PostAnalysisPlanVerificationAware, + PostOptimizationVerificationAware { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); private final Expression field; // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes - private final transient Expression k; + private final transient Integer k; private final Expression options; // Expressions to be used as prefilters in knn query private final List filterExpressions; + public static final String MIN_CANDIDATES_OPTION = "min_candidates"; + public static final Map ALLOWED_OPTIONS = Map.ofEntries( - entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), + entry(MIN_CANDIDATES_OPTION, INTEGER), entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT), entry(BOOST_FIELD.getPreferredName(), FLOAT), entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT) @@ -105,13 +110,6 @@ public Knn( type = { "dense_vector" }, description = "Vector value to find top nearest neighbours for." ) Expression query, - @Param( - name = "k", - type = { "integer" }, - description = "The number of nearest neighbors to return from each shard. " - + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " - + "This value must be less than or equal to num_candidates." - ) Expression k, @MapParam( name = "options", params = { @@ -123,12 +121,13 @@ public Knn( + "Defaults to 1.0." ), @MapParam.MapParamEntry( - name = "num_candidates", + name = "min_candidates", type = "integer", valueHint = { "10" }, - description = "The number of nearest neighbor candidates to consider per shard while doing knn search. " - + "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. " - + "Defaults to 1.5 * k" + description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. " + + " KNN may use a higher number of candidates in case the query can't use a approximate results. " + + "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. " + + "Defaults to 1.5 * LIMIT used for the query." ), @MapParam.MapParamEntry( name = "similarity", @@ -150,32 +149,29 @@ public Knn( optional = true ) Expression options ) { - this(source, field, query, k, options, null, List.of()); + this(source, field, query, options, null, null, List.of()); } public Knn( Source source, Expression field, Expression query, - Expression k, Expression options, + Integer k, QueryBuilder queryBuilder, List filterExpressions ) { - super(source, query, expressionList(field, query, k, options), queryBuilder); + super(source, query, expressionList(field, query, options), queryBuilder); this.field = field; this.k = k; this.options = options; this.filterExpressions = filterExpressions; } - private static List expressionList(Expression field, Expression query, Expression k, Expression options) { + private static List expressionList(Expression field, Expression query, Expression options) { List result = new ArrayList<>(); result.add(field); result.add(query); - if (k != null) { - result.add(k); - } if (options != null) { result.add(options); } @@ -186,7 +182,7 @@ public Expression field() { return field; } - public Expression k() { + public Integer k() { return k; } @@ -205,7 +201,7 @@ public DataType dataType() { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS)); + return resolveField().and(resolveQuery()).and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS)); } private TypeResolution resolveField() { @@ -225,14 +221,9 @@ private TypeResolution resolveQuery() { return TypeResolution.TYPE_RESOLVED; } - private TypeResolution resolveK() { - if (k == null) { - // Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed - return TypeResolution.TYPE_RESOLVED; - } - - return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD)) - .and(isNotNull(k(), sourceText(), THIRD)); + public Knn replaceK(Integer k) { + Check.notNull(k, "k must not be null"); + return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions()); } public List queryAsObject() { @@ -246,16 +237,9 @@ public List queryAsObject() { throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query)); } - int getKIntValue() { - if (k() instanceof Literal literal) { - return (int) (Number) literal.value(); - } - throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k())); - } - @Override public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { - return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); + return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions()); } @Override @@ -271,37 +255,39 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { @Override protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { + assert k() != null : "Knn function must have a k value set before translation"; var fieldAttribute = Match.fieldAsFieldAttribute(field()); Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); - List queryFolded = queryAsObject(); - float[] queryAsFloats = new float[queryFolded.size()]; - for (int i = 0; i < queryFolded.size(); i++) { - queryAsFloats[i] = queryFolded.get(i).floatValue(); - } - int kValue = getKIntValue(); - - Map opts = queryOptions(); - opts.put(K_FIELD.getPreferredName(), kValue); + float[] queryAsFloats = queryAsFloats(); List filterQueries = new ArrayList<>(); for (Expression filterExpression : filterExpressions()) { if (filterExpression instanceof TranslationAware translationAware) { // We can only translate filter expressions that are translatable. In case any is not translatable, - // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them - // when creating an evaluator for the non-pushed down query + // Knn won't be pushed down so it's safe not to translate all filters and check them when creating an evaluator + // for the non-pushed down query if (translationAware.translatable(pushdownPredicates) == Translatable.YES) { filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); } } } - return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries); + return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries); + } + + private float[] queryAsFloats() { + List queryFolded = queryAsObject(); + float[] queryAsFloats = new float[queryFolded.size()]; + for (int i = 0; i < queryFolded.size(); i++) { + queryAsFloats[i] = queryFolded.get(i).floatValue(); + } + return queryAsFloats; } public Expression withFilters(List filterExpressions) { - return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); + return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions); } private Map queryOptions() throws InvalidArgumentException { @@ -312,6 +298,17 @@ private Map queryOptions() throws InvalidArgumentException { return options; } + protected QueryBuilder evaluatorQueryBuilder() { + // Either we couldn't push down due to non-pushable filters, or because it's part of a disjuncion. + // Uses a nearest neighbors exact query instead of an approximate one + var fieldAttribute = Match.fieldAsFieldAttribute(field()); + Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument"); + String fieldName = getNameFromFieldAttribute(fieldAttribute); + Map opts = queryOptions(); + + return new ExactKnnQueryBuilder(VectorData.fromFloats(queryAsFloats()), fieldName, (Float) opts.get(VECTOR_SIMILARITY_FIELD)); + } + @Override public BiConsumer postAnalysisPlanVerification() { return (plan, failures) -> { @@ -320,14 +317,24 @@ public BiConsumer postAnalysisPlanVerification() { }; } + @Override + public void postOptimizationVerification(Failures failures) { + // Check that a k has been set + if (k() == null) { + failures.add( + Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find") + ); + } + } + @Override public Expression replaceChildren(List newChildren) { return new Knn( source(), newChildren.get(0), newChildren.get(1), - newChildren.get(2), - newChildren.size() > 3 ? newChildren.get(3) : null, + newChildren.size() > 2 ? newChildren.get(2) : null, + k(), queryBuilder(), filterExpressions() ); @@ -335,7 +342,7 @@ public Expression replaceChildren(List newChildren) { @Override protected NodeInfo info() { - return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions()); + return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions()); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java index f4353c28476d2..ab41201ceb328 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -27,7 +27,7 @@ private VectorWritables() { public static List getNamedWritables() { List entries = new ArrayList<>(); - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { entries.add(Knn.ENTRY); } if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index dac533f872022..6f550524c5ca5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -44,6 +44,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownJoinPastProject; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushLimitToKnn; import org.elasticsearch.xpack.esql.optimizer.rules.logical.RemoveStatsOverride; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateAggExpressionWithEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateNestedExpressionWithEval; @@ -192,6 +193,7 @@ protected static Batch operators(boolean local) { new PruneColumns(), new PruneLiteralsInOrderBy(), new PushDownAndCombineLimits(), + new PushLimitToKnn(), new PushDownAndCombineFilters(), new PushDownConjunctionsToKnnPrefilters(), new PushDownAndCombineSample(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java new file mode 100644 index 0000000000000..a8503c300bfbc --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.util.Holder; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; + +/** + * Traverses the logical plan and pushes down the limit to the KNN function(s) in filter expressions, so KNN can use + * it to set k if not specified. + */ +public class PushLimitToKnn extends OptimizerRules.ParameterizedOptimizerRule { + + public PushLimitToKnn() { + super(OptimizerRules.TransformDirection.DOWN); + } + + @Override + public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) { + Holder breakerReached = new Holder<>(false); + Holder firstLimit = new Holder<>(false); + return limit.transformDown(plan -> { + if (breakerReached.get()) { + // We reached a breaker and don't want to continue processing + return plan; + } + if (plan instanceof Filter filter) { + Expression limitAppliedExpression = limitFilterExpressions(filter.condition(), limit, ctx); + if (limitAppliedExpression.equals(filter.condition()) == false) { + return filter.with(limitAppliedExpression); + } + } else if (plan instanceof Limit) { + // Break if it's not the initial limit + breakerReached.set(firstLimit.get()); + firstLimit.set(true); + } else if (plan instanceof TopN || plan instanceof Rerank || plan instanceof Aggregate) { + breakerReached.set(true); + } + + return plan; + }); + } + + /** + * Applies a limit to the filter expressions of a condition. Some filter expressions, such as KNN function, + * can be optimized by applying the limit directly to them. + */ + private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) { + return condition.transformDown(exp -> { + if (exp instanceof Knn knn) { + return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx())); + } + return exp; + }); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java index b218b897121df..fedddfa8bcaa4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -12,6 +12,7 @@ import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import java.util.ArrayList; import java.util.Arrays; @@ -20,8 +21,6 @@ import java.util.Objects; import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; -import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; -import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; public class KnnQuery extends Query { @@ -32,9 +31,12 @@ public class KnnQuery extends Query { private final List filterQueries; public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; + private final Integer k; - public KnnQuery(Source source, String field, float[] query, Map options, List filterQueries) { + public KnnQuery(Source source, String field, float[] query, Integer k, Map options, List filterQueries) { super(source); + assert k != null && k > 0 : "k must be a positive integer, but was: " + k; + this.k = k; assert options != null; this.field = field; this.query = query; @@ -44,16 +46,24 @@ public KnnQuery(Source source, String field, float[] query, Map @Override protected QueryBuilder asBuilder() { - Integer k = (Integer) options.get(K_FIELD.getPreferredName()); - Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName()); RescoreVectorBuilder rescoreVectorBuilder = null; Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD); if (oversample != null) { rescoreVectorBuilder = new RescoreVectorBuilder(oversample); } Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName()); - - KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity); + Integer minCandidates = (Integer) options.get(Knn.MIN_CANDIDATES_OPTION); + int adjustedK = Math.max(k, minCandidates == null ? 0 : minCandidates); + minCandidates = minCandidates == null ? null : Math.max(minCandidates, adjustedK); + + KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder( + field, + query, + adjustedK, + minCandidates, + rescoreVectorBuilder, + vectorSimilarity + ); for (QueryBuilder filter : filterQueries) { queryBuilder.addFilterQuery(filter); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 869a851a1fb34..97429ea091053 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -305,7 +305,7 @@ public final void test() throws Throwable { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V4.capabilityName()) ); assumeFalse( "lookup join disabled for csv tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index f26c14db41604..95a7204b5c71f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2349,20 +2349,19 @@ public void testImplicitCasting() { public void testDenseVectorImplicitCastingKnn() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); checkDenseVectorCastingKnn("float_vector"); } private static void checkDenseVectorCastingKnn(String fieldName) { var plan = analyze(String.format(Locale.ROOT, """ - from test | where knn(%s, [0.342, 0.164, 0.234], 10) + from test | where knn(%s, [0.342, 0.164, 0.234]) """, fieldName), "mapping-dense_vector.json"); var limit = as(plan, Limit.class); var filter = as(limit.child(), Filter.class); var knn = as(filter.condition(), Knn.class); - var field = knn.field(); var queryVector = as(knn.query(), Literal.class); assertEquals(DataType.DENSE_VECTOR, queryVector.dataType()); assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f))); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 4e0814d6cc6f5..815ae4bae6b89 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1268,8 +1268,8 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])"); } } @@ -1401,8 +1401,8 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function"); } } @@ -1456,8 +1456,8 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])"); } } @@ -1521,8 +1521,8 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function"); } } @@ -1592,7 +1592,7 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -2189,8 +2189,8 @@ public void testFullTextFunctionOptions() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})"); } } @@ -2282,10 +2282,9 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); - checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); - checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first"); + checkFullTextFunctionNullArgs("knn(vector, null)", "second"); } } @@ -2314,8 +2313,8 @@ public void testFullTextFunctionsInStats() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { - checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { + checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])"); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 002c519b001f8..f87e278bd4238 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -52,7 +52,7 @@ public static Iterable parameters() { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); } private static List testCaseSuppliers() { @@ -121,7 +121,7 @@ private static List addFunctionNamedParams(List args) { - Knn knn = new Knn(source, args.get(0), args.get(1), args.get(2), args.size() > 3 ? args.get(3) : null); + Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null); // We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and // thus test the serialization methods. But we can only do this if the parameters make sense . if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java index 0a42b1962bfe1..35c75d99ab925 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java @@ -32,6 +32,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution; import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; @@ -118,7 +119,7 @@ public static void init() { new EsqlFunctionRegistry(), getIndexResultTypes, enrichResolution, - emptyInferenceResolution() + defaultInferenceResolution() ), TEST_VERIFIER ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 9b161388d6cc3..d1bb7aeaa166a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1376,12 +1376,12 @@ public void testMultiMatchOptionsPushDown() { public void testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where KNN(dense_vector, [0.1, 0.2, 0.3], 5, - { "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) + | where KNN(dense_vector, [0.1, 0.2, 0.3], + { "similarity": 0.001, "min_candidates": 5000, "rescore_oversample": 7, "boost": 3.5 }) """; var analyzer = makeAnalyzer("mapping-all-types.json"); var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer); @@ -1392,12 +1392,69 @@ public void testKnnOptionsPushDown() { var expectedQuery = new KnnVectorQueryBuilder( "dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, - 5, - 10, + 5000, + 5000, new RescoreVectorBuilder(7), 0.001f ).boost(3.5f); - assertThat(expectedQuery.toString(), is(planStr.get())); + assertEquals(expectedQuery.toString(), planStr.get()); + } + + public void testKnnUsesLimitForK() { + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + String query = """ + from test + | where KNN(dense_vector, [0.1, 0.2, 0.3]) + | limit 10 + """; + var analyzer = makeAnalyzer("mapping-all-types.json"); + var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer); + + AtomicReference planStr = new AtomicReference<>(); + plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); + + var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 10, null, null, null); + assertEquals(expectedQuery.toString(), planStr.get()); + } + + public void testKnnKAndMinCandidatesLowerK() { + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + String query = """ + from test + | where KNN(dense_vector, [0.1, 0.2, 0.3], {"min_candidates": 50}) + | limit 10 + """; + var analyzer = makeAnalyzer("mapping-all-types.json"); + var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer); + + AtomicReference planStr = new AtomicReference<>(); + plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); + + var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null); + assertEquals(expectedQuery.toString(), planStr.get()); + } + + public void testKnnKAndMinCandidatesHigherK() { + assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + String query = """ + from test + | where KNN(dense_vector, [0.1, 0.2, 0.3], {"min_candidates": 10}) + | limit 50 + """; + var analyzer = makeAnalyzer("mapping-all-types.json"); + var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer); + + AtomicReference planStr = new AtomicReference<>(); + plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); + + var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null); + assertEquals(expectedQuery.toString(), planStr.get()); } /** @@ -1842,11 +1899,11 @@ public void testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) { } public void testKnnPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + | where knn(dense_vector, [0, 1, 2]) and integer > 10 """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -1859,12 +1916,12 @@ public void testKnnPrefilters() { query, unscore(rangeQuery("integer").gt(10)), "integer", - new Source(2, 45, "integer > 10") + new Source(2, 41, "integer > 10") ); KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( "dense_vector", new float[] { 0, 1, 2 }, - 10, + 1000, null, null, null @@ -1874,11 +1931,11 @@ public void testKnnPrefilters() { } public void testKnnPrefiltersWithMultipleFilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) + | where knn(dense_vector, [0, 1, 2]) | where integer > 10 | where keyword == "test" """; @@ -1900,7 +1957,7 @@ public void testKnnPrefiltersWithMultipleFilters() { KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( "dense_vector", new float[] { 0, 1, 2 }, - 10, + 1000, null, null, null @@ -1910,11 +1967,11 @@ public void testKnnPrefiltersWithMultipleFilters() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + | where knn(dense_vector, [0, 1, 2]) and integer > 10 """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -1929,13 +1986,13 @@ public void testPushDownConjunctionsToKnnPrefilter() { query, unscore(rangeQuery("integer").gt(10)), "integer", - new Source(2, 45, "integer > 10") + new Source(2, 41, "integer > 10") ); KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( "dense_vector", new float[] { 0, 1, 2 }, - 10, + 1000, null, null, null @@ -1947,11 +2004,11 @@ public void testPushDownConjunctionsToKnnPrefilter() { } public void testPushDownNegatedConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) and NOT integer > 10 + | where knn(dense_vector, [0, 1, 2]) and NOT integer > 10 """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -1966,13 +2023,13 @@ public void testPushDownNegatedConjunctionsToKnnPrefilter() { query, unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))), "integer", - new Source(2, 45, "NOT integer > 10") + new Source(2, 41, "NOT integer > 10") ); KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( "dense_vector", new float[] { 0, 1, 2 }, - 10, + 1000, null, null, null @@ -1984,11 +2041,11 @@ public void testPushDownNegatedConjunctionsToKnnPrefilter() { } public void testNotPushDownDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) or integer > 10 + | where knn(dense_vector, [0, 1, 2]) or integer > 10 """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -1999,12 +2056,12 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { var queryExec = as(field.child(), EsQueryExec.class); // The disjunction should not be pushed down to the KNN query - KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null); QueryBuilder rangeQueryBuilder = wrapWithSingleQuery( query, unscore(rangeQuery("integer").gt(10)), "integer", - new Source(2, 44, "integer > 10") + new Source(2, 40, "integer > 10") ); var expectedQuery = boolQuery().should(knnQueryBuilder).should(rangeQueryBuilder); @@ -2013,11 +2070,11 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { } public void testNotPushDownKnnWithNonPushablePrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where ((knn(dense_vector, [0, 1, 2], 10) AND integer > 10) and ((keyword == "test") or length(text) > 10)) + | where ((knn(dense_vector, [0, 1, 2]) AND integer > 10) and ((keyword == "test") or length(text) > 10)) """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -2040,19 +2097,19 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { query, unscore(rangeQuery("integer").gt(10)), "integer", - new Source(2, 47, "integer > 10") + new Source(2, 43, "integer > 10") ); assertEquals(integerGtQuery.toString(), queryExec.query().toString()); } public void testPushDownComplexNegationsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) - and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + | where ((knn(dense_vector, [0, 1, 2]) or NOT integer > 10) + and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6]))) """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -2072,18 +2129,18 @@ and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) query, unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))), "keyword", - new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6], 10))") + new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6]))") ); QueryBuilder notIntegerGt10 = wrapWithSingleQuery( query, unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))), "integer", - new Source(2, 46, "NOT integer > 10") + new Source(2, 42, "NOT integer > 10") ); - KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); - KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null); + KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null); firstKnn.addFilterQuery(notKeywordFilter); secondKnn.addFilterQuery(notIntegerGt10); @@ -2097,11 +2154,11 @@ and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) } public void testMultipleKnnQueriesInPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test - | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + | where ((knn(dense_vector, [0, 1, 2]) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6]))) """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -2111,24 +2168,24 @@ public void testMultipleKnnQueriesInPrefilters() { var field = as(project.child(), FieldExtractExec.class); var queryExec = as(field.child(), EsQueryExec.class); - KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null); // Integer range query (right side of first OR) QueryBuilder integerRangeQuery = wrapWithSingleQuery( query, unscore(rangeQuery("integer").gt(10)), "integer", - new Source(2, 46, "integer > 10") + new Source(2, 42, "integer > 10") ); // Second KNN query (right side of second OR) - KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null); // Keyword term query (left side of second OR) QueryBuilder keywordQuery = wrapWithSingleQuery( query, unscore(termQuery("keyword", "test")), "keyword", - new Source(2, 66, "keyword == \"test\"") + new Source(2, 62, "keyword == \"test\"") ); // First OR (knn1 OR integer > 10) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 67bb40214cb7b..c3c4a9e3f1038 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -8499,11 +8499,11 @@ public void testSampleNoPushDownChangePoint() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + | where knn(dense_vector, [0, 1, 2]) and integer > 10 """; var optimized = planTypes(query); @@ -8519,11 +8519,11 @@ public void testPushDownConjunctionsToKnnPrefilter() { } public void testPushDownMultipleFiltersToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) + | where knn(dense_vector, [0, 1, 2]) | where integer > 10 | where keyword == "test" """; @@ -8542,11 +8542,11 @@ public void testPushDownMultipleFiltersToKnnPrefilter() { } public void testNotPushDownDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test - | where knn(dense_vector, [0, 1, 2], 10) or integer > 10 + | where knn(dense_vector, [0, 1, 2]) or integer > 10 """; var optimized = planTypes(query); @@ -8559,7 +8559,7 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { } public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* and @@ -8576,7 +8576,7 @@ public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { var query = """ from test | where - ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0)) + ((knn(dense_vector, [0, 1, 2]) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0)) """; var optimized = planTypes(query); @@ -8594,7 +8594,7 @@ public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { } public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* or @@ -8611,7 +8611,7 @@ public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { var query = """ from test | where - ((knn(dense_vector, [0, 1, 2], 10) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0)) + ((knn(dense_vector, [0, 1, 2]) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0)) """; var optimized = planTypes(query); @@ -8626,7 +8626,7 @@ public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { } public void testMultipleKnnQueriesInPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* and @@ -8639,7 +8639,7 @@ public void testMultipleKnnQueriesInPrefilters() { */ var query = """ from test - | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + | where ((knn(dense_vector, [0, 1, 2]) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6]))) """; var optimized = planTypes(query); @@ -8668,6 +8668,156 @@ public void testMultipleKnnQueriesInPrefilters() { assertTrue(secondKnnFilters.contains(firstOr.right())); } + public void testKnnImplicitLimit() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2]) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + assertThat(knn.k(), equalTo(1000)); + } + + public void testKnnWithLimit() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2]) + | limit 10 + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + assertThat(knn.k(), equalTo(10)); + } + + public void testKnnWithTopN() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test metadata _score + | where knn(dense_vector, [0, 1, 2]) + | sort _score desc + | limit 10 + """; + var optimized = planTypes(query); + + var topN = as(optimized, TopN.class); + var filter = as(topN.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + assertThat(knn.k(), equalTo(10)); + } + + public void testKnnWithMultipleLimitsAfterTopN() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test metadata _score + | where knn(dense_vector, [0, 1, 2]) + | limit 20 + | sort _score desc + | limit 10 + """; + var optimized = planTypes(query); + + var topN = as(optimized, TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10)); + var limit = as(topN.child(), Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + assertThat(knn.k(), equalTo(20)); + } + + public void testKnnWithMultipleLimitsCombined() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test metadata _score + | where knn(dense_vector, [0, 1, 2]) + | limit 20 + | limit 10 + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(10)); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + assertThat(knn.k(), equalTo(10)); + } + + public void testKnnWithMultipleClauses() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test metadata _score + | where knn(dense_vector, [0, 1, 2]) and match(keyword, "test") + | where knn(dense_vector, [1, 2, 3]) + | sort _score + | limit 10 + """; + var optimized = planTypes(query); + + var topN = as(optimized, TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10)); + var filter = as(topN.child(), Filter.class); + var firstAnd = as(filter.condition(), And.class); + var fistKnn = as(firstAnd.right(), Knn.class); + assertThat(((Literal) fistKnn.query()).value(), is(List.of(1.0f, 2.0f, 3.0f))); + var secondAnd = as(firstAnd.left(), And.class); + var secondKnn = as(secondAnd.left(), Knn.class); + assertThat(((Literal) secondKnn.query()).value(), is(List.of(0.0f, 1.0f, 2.0f))); + } + + public void testKnnWithStats() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + assertThat( + typesError("from test | where knn(dense_vector, [0, 1, 2]) | stats c = count(*)"), + containsString("Knn function must be used with a LIMIT clause") + ); + } + + public void testKnnWithRerankAmdTopN() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + assertThat(typesError(""" + from test metadata _score + | where knn(dense_vector, [0, 1, 2]) + | rerank "some text" on text with { "inference_id" : "reranking-inference-id" } + | sort _score desc + | limit 10 + """), containsString("Knn function must be used with a LIMIT clause")); + } + + public void testKnnWithRerankAmdLimit() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + var query = """ + from test metadata _score + | where knn(dense_vector, [0, 1, 2]) + | rerank "some text" on text with { "inference_id" : "reranking-inference-id" } + | limit 100 + """; + + var optimized = planTypes(query); + + var rerank = as(optimized, Rerank.class); + var limit = as(rerank.child(), Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(100)); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + assertThat(knn.k(), equalTo(100)); + } + private LogicalPlanOptimizer getCustomRulesLogicalPlanOptimizer(List> batches) { LogicalOptimizerContext context = new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small()); LogicalPlanOptimizer customOptimizer = new LogicalPlanOptimizer(context) {