diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java index d91df60621fce..d268206cff3ff 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java @@ -61,10 +61,16 @@ protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards) } public Block executeQuery(Page page) { - // Lucene based operators retrieve DocVectors as first block - Block block = page.getBlock(0); - assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input"; - DocVector docs = (DocVector) block.asVector(); + // Search for DocVector block + Block docBlock = null; + for (int i = 0; i < page.getBlockCount(); i++) { + if (page.getBlock(i) instanceof DocBlock) { + docBlock = page.getBlock(i); + break; + } + } + assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock"; + DocVector docs = (DocVector) docBlock.asVector(); try { if (docs.singleSegmentNonDecreasing()) { return evalSingleSegmentNonDecreasing(docs).asBlock(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java index 2afc885d71124..1c3d522fda5ab 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java @@ -9,7 +9,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.DoubleVector; import org.elasticsearch.compute.data.Page; @@ -46,9 +45,9 @@ public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int sco @Override protected Page process(Page page) { - assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount(); - assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector(); - assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector(); + assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition; + assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector + : "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector(); Block[] blocks = new Block[page.getBlockCount()]; for (int i = 0; i < page.getBlockCount(); i++) { diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index ce8061534ddbb..0be345645401b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,7 +3,7 @@ # top-n query at the shard level knnSearch -required_capability: knn_function_v3 +required_capability: knn_function_v4 // tag::knn-function[] from colors metadata _score @@ -30,7 +30,7 @@ chartreuse | [127.0, 255.0, 0.0] ; knnSearchWithSimilarityOption -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) @@ -46,7 +46,7 @@ wheat | [245.0, 222.0, 179.0] ; knnHybridSearch -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where match(color, "blue") or knn(rgb_vector, [65,105,225], 10) @@ -68,7 +68,7 @@ yellow | [255.0, 255.0, 0.0] ; knnWithPrefilter -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green")) @@ -82,7 +82,7 @@ green | [0.0, 128.0, 0.0] ; knnWithNegatedPrefilter -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) @@ -105,7 +105,7 @@ orange | [255.0, 165.0, 0.0] ; knnAfterKeep -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | keep rgb_vector, color, _score @@ -124,7 +124,7 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | drop primary @@ -143,7 +143,7 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 @@ -162,7 +162,7 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" @@ -181,7 +181,7 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true @@ -204,7 +204,7 @@ yellow | [255.0, 255.0, 0.0] ; knnWithNegationsAndFiltersConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) @@ -227,62 +227,76 @@ azure | [240.0, 255.0, 255.0] ; knnWithNonPushableConjunction -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], 140) and composed_name == false +| where knn(rgb_vector, [100,100,0], 10) and composed_name == false | sort _score desc, color asc -| keep color, composed_name -| limit 10 +| keep color ; -color:text | composed_name:boolean -olive | false -sienna | false -chocolate | false -peru | false -brown | false -firebrick | false -chartreuse | false -gray | false -green | false -maroon | false +color:text +olive +sienna +brown +green +maroon +firebrick +chocolate +peru +gray ; testKnnWithNonPushableDisjunctions -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors metadata _score -| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 +| where knn(rgb_vector, [128,128,0], 10) or length(color) > 10 | sort _score desc, color asc | keep color ; color:text olive +sienna +chocolate +peru +golden rod +brown +firebrick +chartreuse +green +maroon aqua marine lemon chiffon papaya whip ; -testKnnWithNonPushableDisjunctionsOnComplexExpressions -required_capability: knn_function_v3 +testKnnWithNonPushableConjunctionsOnComplexExpressions +required_capability: knn_function_v4 from colors metadata _score -| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) +| where knn(rgb_vector, [128,128,0], 10) and length(color) < 7 and knn(rgb_vector, [128,0,128], 10) and primary == false | sort _score desc, color asc -| keep color, primary +| keep color ; -color:text | primary:boolean -olive | false -purple | false -indigo | false +color:text +brown +coral +gold +maroon +olive +orange +peru +salmon +sienna +tomato ; testKnnInStatsNonPushable -required_capability: knn_function_v3 +required_capability: knn_function_v4 from colors | where length(color) < 10 @@ -294,7 +308,7 @@ c: long ; testKnnInStatsWithGrouping -required_capability: knn_function_v3 +required_capability: knn_function_v4 required_capability: full_text_functions_in_stats_where from colors diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 9ae1c980337f1..5f38f2fcfdf8b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -31,7 +32,9 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.equalTo; +@TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug") public class KnnFunctionIT extends AbstractEsqlIntegTestCase { private final Map> indexedVectors = new HashMap<>(); @@ -157,9 +160,71 @@ public void testKnnWithLookupJoin() { ); } + public void testKnnNotPushedDown() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, and 5 from the non-pushed down disjunction. They are disjoint so we get 10 as a result + var query = String.format(Locale.ROOT, """ + FROM test + | WHERE knn(vector, %s, 5) OR (length(keyword) > 5 AND length(keyword) <= 10) + | KEEP id, vector, keyword + | SORT id ASC + | LIMIT 20 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "vector", "keyword")); + assertColumnTypes(resp.columns(), List.of("integer", "dense_vector", "keyword")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(10, valuesList.size()); + } + } + + public void testKnnPrefiltersNotPushedDown() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, but must be prefiltered with the non-pushed down conjunction + var query = String.format(Locale.ROOT, """ + FROM test + | WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10 + | SORT id ASC + | LIMIT 20 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + // No added columns + assertThat(resp.columns().size(), equalTo(4)); + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(5, valuesList.size()); + } + } + + public void testKnnPrefiltersNotPushedDownWithScoring() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, but must be prefiltered with the non-pushed down conjunction + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, 5) AND length(keyword) > 5 AND length(keyword) <= 10 + | SORT id ASC + | LIMIT 20 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + // No added columns + assertThat(resp.columns().size(), equalTo(5)); + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(5, valuesList.size()); + } + } + @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var indexName = "test"; var client = client().admin().indices(); @@ -176,6 +241,9 @@ public void setup() throws IOException { .startObject("floats") .field("type", "float") .endObject() + .startObject("keyword") + .field("type", "keyword") + .endObject() .endObject() .endObject(); @@ -195,7 +263,8 @@ public void setup() throws IOException { for (int j = 0; j < numDims; j++) { vector.add(value++); } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + docs[i] = prepareIndex("test").setId("" + i) + .setSource("id", String.valueOf(i), "floats", vector, "vector", vector, "keyword", randomAlphaOfLength(i)); indexedVectors.put(i, vector); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 35f6d1af2e76f..5d750b0c00a8b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1223,7 +1223,7 @@ public enum Cap { /** * Support knn function */ - KNN_FUNCTION_V3(Build.current().isSnapshot()), + KNN_FUNCTION_V4(Build.current().isSnapshot()), /** * Support for the LIKE operator with a list of wildcards. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java new file mode 100644 index 0000000000000..1ff093e3fdc36 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/PostOptimizationPlanVerificationAware.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.capabilities; + +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.function.BiConsumer; + +/** + * Interface implemented by expressions or plans that require validation post logical optimization, + * when the plan and references have been not just resolved but also replaced. + * The interface is similar to {@link PostOptimizationVerificationAware}, but focused on the tree structure + */ +public interface PostOptimizationPlanVerificationAware { + + /** + * Allows the implementer to return a consumer that will perform self-validation in the context of the tree structure the implementer + * is part of. This usually involves checking the type and configuration of the children or that of the parent. + */ + BiConsumer postOptimizationVerification(); +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index b2fa2a13f6710..10f65b57c58c6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -182,6 +182,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity; import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -493,7 +494,8 @@ private static FunctionDefinition[][] snapshotFunctions() { def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"), def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"), def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"), - def(DotProduct.class, DotProduct::new, "v_dot_product") } }; + def(DotProduct.class, DotProduct::new, "v_dot_product"), + def(ExactNN.class, tri(ExactNN::new), "exact_nn") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index b5378db783f46..8f553423bd17f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -377,7 +377,7 @@ public static void fieldVerifier(LogicalPlan plan, FullTextFunction function, Ex } @Override - public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { List shardContexts = toEvaluator.shardContexts(); ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()]; int i = 0; @@ -388,7 +388,7 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua } @Override - public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { + public final ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) { List shardContexts = toScorer.shardContexts(); ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()]; int i = 0; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java new file mode 100644 index 0000000000000..0942cb11f238d --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/ExactNN.java @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.Check; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; +import org.elasticsearch.xpack.esql.querydsl.query.ExactNNQuery; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.function.BiConsumer; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; + +/** + * Exact nearest neighbour search using a dense_vector similarity function. Used to translate {@link Knn} into exact search + * when it can't be pushed down to Lucene. Not exposed to users directly. + */ +public class ExactNN extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "ExactNN", + ExactNN::readFrom + ); + + private final Expression field; + private final Expression minimumSimilarity; + + @FunctionInfo( + returnType = "boolean", + preview = true, + description = "Finds all nearest vectors to a query vector, as measured by a similarity metric. " + + "performs brute force search over all vectors in the index.", + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } + ) + public ExactNN( + Source source, + @Param(name = "field", type = { "dense_vector" }, description = "Field that the query will target.") Expression field, + @Param( + name = "query", + type = { "dense_vector" }, + description = "Vector value to find top nearest neighbours for." + ) Expression query, + @Param( + name = "similarity", + type = { "double" }, + optional = true, + description = "The minimum similarity required for a document to be considered a match. " + + "The similarity value calculated relates to the raw similarity used, not the document score." + ) Expression minimumSimilarity + ) { + this(source, field, query, minimumSimilarity, null); + } + + public ExactNN(Source source, Expression field, Expression query, Expression minimumSimilarity, QueryBuilder queryBuilder) { + super(source, query, minimumSimilarity == null ? List.of(field, query) : List.of(field, query, minimumSimilarity), queryBuilder); + this.field = field; + this.minimumSimilarity = minimumSimilarity; + } + + public Expression field() { + return field; + } + + public Expression minimumSimilarity() { + return minimumSimilarity; + } + + @Override + public DataType dataType() { + return DataType.BOOLEAN; + } + + @Override + protected TypeResolution resolveParams() { + return resolveField().and(resolveQuery()).and(resolveMinimumSimilarity()); + } + + private TypeResolution resolveField() { + return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector")); + } + + private TypeResolution resolveQuery() { + return isNotNull(query(), sourceText(), SECOND).and( + isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector") + ); + } + + private TypeResolution resolveMinimumSimilarity() { + if (minimumSimilarity == null) { + return TypeResolution.TYPE_RESOLVED; + } + + return isNotNull(minimumSimilarity(), sourceText(), THIRD).and(isNumeric(minimumSimilarity(), sourceText(), THIRD)); + } + + @Override + public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { + return new ExactNN(source(), field(), query(), minimumSimilarity(), queryBuilder); + } + + @Override + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { + var fieldAttribute = Match.fieldAsFieldAttribute(field()); + + Check.notNull(fieldAttribute, "Exact must have a field attribute as the first argument"); + String fieldName = getNameFromFieldAttribute(fieldAttribute); + @SuppressWarnings("unchecked") + List queryFolded = (List) query().fold(FoldContext.small() /* TODO remove me */); + float[] queryAsFloats = new float[queryFolded.size()]; + for (int i = 0; i < queryFolded.size(); i++) { + queryAsFloats[i] = queryFolded.get(i).floatValue(); + } + Float similarity = minimumSimilarity != null ? ((Number) minimumSimilarity().fold(FoldContext.small())).floatValue() : null; + + return new ExactNNQuery(source(), fieldName, queryAsFloats, similarity); + } + + @Override + public BiConsumer postAnalysisPlanVerification() { + return (plan, failures) -> { + super.postAnalysisPlanVerification().accept(plan, failures); + fieldVerifier(plan, this, field, failures); + }; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new ExactNN( + source(), + newChildren.get(0), + newChildren.get(1), + newChildren.size() > 2 ? newChildren.get(2) : null, + queryBuilder() + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, ExactNN::new, field(), query(), minimumSimilarity(), queryBuilder()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + private static ExactNN readFrom(StreamInput in) throws IOException { + Source source = Source.readFrom((PlanStreamInput) in); + Expression field = in.readNamedWriteable(Expression.class); + Expression query = in.readNamedWriteable(Expression.class); + Expression minimumSimilarity = in.readOptionalNamedWriteable(Expression.class); + QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); + return new ExactNN(source, field, query, minimumSimilarity, queryBuilder); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field()); + out.writeNamedWriteable(query()); + out.writeOptionalNamedWriteable(minimumSimilarity()); + out.writeOptionalNamedWriteable(queryBuilder()); + } + + @Override + public boolean equals(Object o) { + // ExactNN does not serialize options, as they get included in the query builder. We need to override equals and hashcode to + // ignore options when comparing two ExactNN functions + if (o == null || getClass() != o.getClass()) return false; + ExactNN exact = (ExactNN) o; + return Objects.equals(field(), exact.field()) + && Objects.equals(query(), exact.query()) + && Objects.equals(minimumSimilarity(), exact.minimumSimilarity()) + && Objects.equals(queryBuilder(), exact.queryBuilder()); + } + + @Override + public int hashCode() { + return Objects.hash(field(), query(), minimumSimilarity(), queryBuilder()); + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index cab5ec862d7f5..29ed31f908132 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -11,7 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.capabilities.PostOptimizationPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; @@ -34,25 +34,30 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import static java.util.Map.entry; import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; +import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -65,7 +70,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; -public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware { +public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostOptimizationPlanVerificationAware { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); @@ -259,8 +264,8 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato for (Expression filterExpression : filterExpressions()) { if (filterExpression instanceof TranslationAware translationAware) { // We can only translate filter expressions that are translatable. In case any is not translatable, - // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them - // when creating an evaluator for the non-pushed down query + // Knn won't be pushed down so it's safe not to translate all filters and check them when creating an evaluator + // for the non-pushed down query if (translationAware.translatable(pushdownPredicates) == Translatable.YES) { filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); } @@ -274,6 +279,18 @@ public Expression withFilters(List filterExpressions) { return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } + public Collection nonPushableFilters() { + List nonPushableFilters = new ArrayList<>(); + for (Expression filterExpression : filterExpressions()) { + if (filterExpression instanceof TranslationAware translationAware) { + if (translationAware.translatable(LucenePushdownPredicates.DEFAULT) == Translatable.NO) { + nonPushableFilters.add(filterExpression); + } + } + } + return nonPushableFilters; + } + private Map queryOptions() throws InvalidArgumentException { Map options = new HashMap<>(); if (options() != null) { @@ -290,6 +307,30 @@ public BiConsumer postAnalysisPlanVerification() { }; } + @Override + public BiConsumer postOptimizationVerification() { + return (plan, failures) -> { + if (plan instanceof Filter f) { + f.condition().forEachDown(Or.class, or -> { + or.forEachDown(Knn.class, knn -> { + Collection nonPushableFilters = knn.nonPushableFilters(); + if (nonPushableFilters.isEmpty() == false) { + failures.add( + fail( + plan, + "knn function [{}] cannot be used in an OR clause when it is being filtered with " + + "the following AND conditions: {}.", + knn.sourceText(), + nonPushableFilters.stream().map(Expression::sourceText).collect(Collectors.joining(", ")) + ) + ); + } + }); + }); + } + }; + } + @Override public Expression replaceChildren(List newChildren) { return new Knn( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java index d44c5681438b0..5a6d1aa70d554 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -27,8 +27,10 @@ private VectorWritables() { public static List getNamedWritables() { List entries = new ArrayList<>(); - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { entries.add(Knn.ENTRY); + // ExactNN is needed as a KNN optimization + entries.add(ExactNN.ENTRY); } if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) { entries.add(CosineSimilarity.ENTRY); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index dac533f872022..8ecf12ef4e0c1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -48,6 +48,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateAggExpressionWithEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateNestedExpressionWithEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAliasingEvalWithProject; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceLimitAndSortAsTopN; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceOrderByExpressionWithEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceRegexMatch; @@ -194,6 +195,7 @@ protected static Batch operators(boolean local) { new PushDownAndCombineLimits(), new PushDownAndCombineFilters(), new PushDownConjunctionsToKnnPrefilters(), + new ReplaceKnnWithNoPushedDownFilters(), new PushDownAndCombineSample(), new PushDownInferencePlan(), new PushDownEval(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java index 4a04b46be295a..b3af1489e71f5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.capabilities.PostOptimizationPlanVerificationAware; import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.optimizer.rules.PlanConsistencyChecker; @@ -39,10 +40,14 @@ void checkPlanConsistency(LogicalPlan optimizedPlan, Failures failures, Failures if (failures.hasFailures() == false) { if (p instanceof PostOptimizationVerificationAware pova) { pova.postOptimizationVerification(failures); + } else if (p instanceof PostOptimizationPlanVerificationAware popva) { + popva.postOptimizationVerification().accept(p, failures); } p.forEachExpression(ex -> { if (ex instanceof PostOptimizationVerificationAware va) { va.postOptimizationVerification(failures); + } else if (ex instanceof PostOptimizationPlanVerificationAware pva) { + pva.postOptimizationVerification().accept(p, failures); } }); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java new file mode 100644 index 0000000000000..f6d16002903bd --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceKnnWithNoPushedDownFilters.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.Holder; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Score; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; +import static org.elasticsearch.xpack.esql.core.expression.Attribute.rawTemporaryName; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; + +/** + * Replaces KNN queries with non pushable prefilters used in filters. + * + * A query like: + * WHERE knn(field1, [..], 10) AND non-pushable-filter + * + * Will be replaced with: + * | WHERE non-pushable-filter + * | EVAL knn_score = SCORE(exact_nn(field1, [..])) + * | TOPN 10 knn_score DESC + * | WHERE knn_score > 0 + * | DROP knn_score + */ +public class ReplaceKnnWithNoPushedDownFilters extends OptimizerRules.OptimizerRule { + + public static final String EXACT_SCORE_ATTR_NAME = "knn_score"; + + public ReplaceKnnWithNoPushedDownFilters() { + super(UP); + } + + @Override + protected LogicalPlan rule(Filter filter) { + Expression condition = filter.condition(); + + Holder> knnQueries = new Holder<>(new ArrayList<>()); + Expression conditionWithoutKnns = condition.transformDown(Knn.class, knn -> replaceNonPushableKnnByTrue(knn, knnQueries)); + if (conditionWithoutKnns.equals(condition)) { + return filter; + } + + // Check that knn is not part of a disjunction + Holder hasNonPushableDisjunctions = new Holder<>(false); + filter.condition().forEachDown(Or.class, or -> { + or.forEachDown(Knn.class, knn -> { + Collection nonPushableFilters = knn.nonPushableFilters(); + if (nonPushableFilters.isEmpty() == false) { + hasNonPushableDisjunctions.set(true); + } + }); + }); + if (hasNonPushableDisjunctions.get()) { + return filter; + } + + // Replace knn with scoring expressions of exact queries + List exactQueries = knnQueries.get().stream().map(ReplaceKnnWithNoPushedDownFilters::replaceKnnByExactQuery).toList(); + assert exactQueries.isEmpty() == false; + + // Create an Eval for scoring the exact queries + List exactScoreAliases = exactQueryScoreAliases(exactQueries); + LogicalPlan scoringPlan = new Eval(EMPTY, filter.with(conditionWithoutKnns), exactScoreAliases); + List scoreAttrs = exactScoreAliases.stream().map(Alias::toAttribute).toList(); + + // Sort on the scores, limit on the minimum k from the queries + TopN topN = createTopN(scoreAttrs, knnQueries.get(), scoringPlan); + + // Filter on scores > 0 + Filter scoreFilter = createScoreFilter(scoreAttrs, topN); + + // Drop the scores + return new Project(EMPTY, scoreFilter, filter.output()); + } + + private static Expression replaceKnnByExactQuery(Knn knn) { + Expression minimumSimilarity = knn.options() == null + ? null + : ((MapExpression) knn.options()).get(VECTOR_SIMILARITY_FIELD.getPreferredName()); + ExactNN exact = new ExactNN(knn.source(), knn.field(), knn.query(), minimumSimilarity); + // Replaces query builder as it was not resolved during post analysis phase + return exact.replaceQueryBuilder( + TranslatorHandler.TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, exact).toQueryBuilder() + ); + } + + private static List exactQueryScoreAliases(List exactQueries) { + List scoringAliases = new ArrayList<>(); + for (int i = 0; i < exactQueries.size(); i++) { + String name = rawTemporaryName(EXACT_SCORE_ATTR_NAME, String.valueOf(i)); + Alias alias = new Alias(EMPTY, name, new Score(EMPTY, exactQueries.get(i))); + scoringAliases.add(alias); + } + return scoringAliases; + } + + private static Filter createScoreFilter(List scoreAttrs, LogicalPlan planToFilter) { + Expression scoreComparison = null; + for (Attribute scoringAttr : scoreAttrs) { + GreaterThan gt = new GreaterThan(EMPTY, scoringAttr, new Literal(EMPTY, 0.0, DataType.DOUBLE)); + if (scoreComparison == null) { + scoreComparison = gt; + } else { + scoreComparison = new And(EMPTY, gt, scoreComparison); + } + } + + return new Filter(EMPTY, planToFilter, scoreComparison); + } + + private static Expression replaceNonPushableKnnByTrue(Knn knn, Holder> replaced) { + if (knn.nonPushableFilters().isEmpty()) { + return knn; + } + replaced.get().add(knn); + return Literal.TRUE; + } + + private static TopN createTopN(List scoreAttrs, List knnQueries, LogicalPlan scoringPlan) { + List orders = scoreAttrs.stream() + .map(a -> new Order(EMPTY, a, Order.OrderDirection.DESC, Order.NullsPosition.LAST)) + .toList(); + int minimumK = knnQueries.stream().mapToInt(knn -> (Integer) knn.k().fold(FoldContext.small())).min().orElseThrow(); + return new TopN(EMPTY, scoringPlan, orders, new Literal(EMPTY, minimumK, DataType.INTEGER)); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java new file mode 100644 index 0000000000000..d45b10a07b708 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/ExactNNQuery.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.querydsl.query; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.vectors.ExactKnnQueryBuilder; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.Arrays; +import java.util.Objects; + +public class ExactNNQuery extends Query { + + private final String field; + private final float[] query; + private final Float minimumSimilarity; + + public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; + + public ExactNNQuery(Source source, String field, float[] query, Float minimumSimilarity) { + super(source); + this.field = field; + this.query = query; + this.minimumSimilarity = minimumSimilarity; + } + + @Override + protected QueryBuilder asBuilder() { + return new ExactKnnQueryBuilder(VectorData.fromFloats(query), field, minimumSimilarity); + } + + @Override + protected String innerToString() { + return "exactNN(" + field + ", " + Arrays.toString(query) + " minimumSimilarity=" + minimumSimilarity + ")"; + } + + @Override + public boolean equals(Object o) { + if (super.equals(o) == false) return false; + + if (o == null || getClass() != o.getClass()) return false; + ExactNNQuery query = (ExactNNQuery) o; + return Objects.equals(field, query.field) + && Objects.deepEquals(this.query, query.query) + && Objects.equals(minimumSimilarity, query.minimumSimilarity); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), minimumSimilarity); + } + + @Override + public boolean scorable() { + return true; + } + + @Override + public boolean containsPlan() { + return false; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index b38b3089823d5..f64414cf2a817 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -305,7 +305,7 @@ public final void test() throws Throwable { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V4.capabilityName()) ); assumeFalse( "lookup join disabled for csv tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 86c5fed0f6c24..7d92345ebf0f6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1244,7 +1244,7 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); } } @@ -1377,7 +1377,7 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); } @@ -1432,7 +1432,7 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); } } @@ -1497,7 +1497,7 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); } } @@ -1568,7 +1568,7 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -2147,7 +2147,7 @@ public void testFullTextFunctionOptions() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); } } @@ -2235,7 +2235,7 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); @@ -2261,7 +2261,7 @@ public void testFullTextFunctionsConstantArg() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsConstantArg("term(title, tags)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second"); checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third"); } @@ -2292,7 +2292,7 @@ public void testFullTextFunctionsInStats() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) { checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 595eb58118a09..3ce00fadb95cc 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -51,7 +51,7 @@ public static Iterable parameters() { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); } private static List testCaseSuppliers() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index cd6371e4d4d5e..997fe859632d1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -65,8 +65,9 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Score; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; -import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -143,6 +144,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters.EXACT_SCORE_ATTR_NAME; import static org.elasticsearch.xpack.esql.plan.physical.AbstractPhysicalPlanSerializationTests.randomEstimatedRowSize; import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType; import static org.hamcrest.Matchers.contains; @@ -1377,7 +1379,7 @@ public void testMultiMatchOptionsPushDown() { public void testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1843,7 +1845,7 @@ public void testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) { } public void testKnnPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1875,7 +1877,7 @@ public void testKnnPrefilters() { } public void testKnnPrefiltersWithMultipleFilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1911,7 +1913,7 @@ public void testKnnPrefiltersWithMultipleFilters() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1948,7 +1950,7 @@ public void testPushDownConjunctionsToKnnPrefilter() { } public void testPushDownNegatedConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -1985,7 +1987,7 @@ public void testPushDownNegatedConjunctionsToKnnPrefilter() { } public void testNotPushDownDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -2014,7 +2016,7 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { } public void testNotPushDownKnnWithNonPushablePrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -2022,19 +2024,58 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); - var limit = as(plan, LimitExec.class); - var exchange = as(limit.child(), ExchangeExec.class); - var project = as(exchange.child(), ProjectExec.class); - var field = as(project.child(), FieldExtractExec.class); - var secondLimit = as(field.child(), LimitExec.class); - var filter = as(secondLimit.child(), FilterExec.class); - var and = as(filter.condition(), And.class); - var knn = as(and.left(), Knn.class); - assertEquals("(keyword == \"test\") or length(text) > 10", knn.filterExpressions().get(0).toString()); - assertEquals("integer > 10", knn.filterExpressions().get(1).toString()); + var project = as(plan, ProjectExec.class); + assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); - var fieldExtract = as(filter.child(), FieldExtractExec.class); - var queryExec = as(fieldExtract.child(), EsQueryExec.class); + // LimitExec + var limit = as(project.child(), LimitExec.class); + assertThat(as(limit.limit(), Literal.class).value(), is(1000)); + + // FilterExec on $$knn_score$0 > 0.0 + var filter = as(limit.child(), FilterExec.class); + var gt = as(filter.condition(), GreaterThan.class); + ReferenceAttribute scoreAttr = as(gt.left(), ReferenceAttribute.class); + assertThat(scoreAttr.name(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt.right().fold(FoldContext.small()), is(0.0)); + + // TopNExec on $$knn_score$0 desc + var topN = as(filter.child(), TopNExec.class); + assertThat(as(topN.limit(), Literal.class).value(), is(10)); + assertThat(Expressions.name(topN.order().getFirst().child()), equalTo(scoreAttr.name())); + + // ExchangeExec + var exchange = as(topN.child(), ExchangeExec.class); + + // ProjectExec (with score column) + var project2 = as(exchange.child(), ProjectExec.class); + assertTrue(project2.output().contains(scoreAttr)); + + var fieldExtract = as(project2.child(), FieldExtractExec.class); + + var topN2 = as(fieldExtract.child(), TopNExec.class); + + // EvalExec for score + var eval = as(topN2.child(), EvalExec.class); + var scoreAlias = as(eval.fields().getFirst(), Alias.class); + assertThat(scoreAlias.name(), containsString(EXACT_SCORE_ATTR_NAME)); + var score = as(scoreAlias.child(), Score.class); + var exactNN = as(score.children().getFirst(), ExactNN.class); + var field = as(exactNN.field(), FieldAttribute.class); + assertThat(field.name(), equalTo("dense_vector")); + assertThat(exactNN.query().toString(), equalTo("[0.0, 1.0, 2.0]")); + + // FieldExtractExec for dense_vector + var fieldExtract2 = as(eval.child(), FieldExtractExec.class); + + // FilterExec for OR + var filter2 = as(fieldExtract2.child(), FilterExec.class); + var or = as(filter2.condition(), Or.class); + + // FieldExtractExec for keyword, text + var fieldExtract3 = as(filter2.child(), FieldExtractExec.class); + + // EsQueryExec for integer > 10 + var esQuery = as(fieldExtract3.child(), EsQueryExec.class); // The query should only contain the pushable condition QueryBuilder integerGtQuery = wrapWithSingleQuery( @@ -2044,11 +2085,11 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { new Source(2, 47, "integer > 10") ); - assertEquals(integerGtQuery.toString(), queryExec.query().toString()); + assertEquals(integerGtQuery.toString(), esQuery.query().toString()); } public void testPushDownComplexNegationsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test @@ -2098,7 +2139,7 @@ and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) } public void testMultipleKnnQueriesInPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); String query = """ from test diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index a0dd67105097d..d35f97c36795b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -58,6 +58,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Score; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; @@ -74,6 +75,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.function.vector.ExactNN; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; @@ -182,6 +184,7 @@ import static org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison.BinaryComparisonOperation.LTE; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceKnnWithNoPushedDownFilters.EXACT_SCORE_ATTR_NAME; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.contains; @@ -7861,7 +7864,7 @@ public void testSampleNoPushDownChangePoint() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test @@ -7881,7 +7884,7 @@ public void testPushDownConjunctionsToKnnPrefilter() { } public void testPushDownMultipleFiltersToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test @@ -7904,7 +7907,7 @@ public void testPushDownMultipleFiltersToKnnPrefilter() { } public void testNotPushDownDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); var query = """ from test @@ -7921,7 +7924,7 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { } public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* and @@ -7956,7 +7959,7 @@ public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { } public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* or @@ -7988,7 +7991,7 @@ public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { } public void testMultipleKnnQueriesInPrefilters() { - assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); /* and @@ -8122,4 +8125,146 @@ public List output() { assertThat(e.getMessage(), containsString("Output has changed from")); } + public void testKnnWithNonPushablePrefiltersNoScoring() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + // Disjunctions with pushable conditions are allowed + var plan = planTypes(""" + from types + | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and match(text, "hello") and length(keyword) > 10 + """); + + var project = as(plan, Project.class); + assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); + + var limit = as(project.child(), Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // Next: Filter[$$knn_score$0 > 0.0] + var filter = as(limit.child(), Filter.class); + var gt = as(filter.condition(), GreaterThan.class); + ReferenceAttribute scoreAttr = as(gt.left(), ReferenceAttribute.class); + assertThat(scoreAttr.toString(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt.right().fold(FoldContext.small()), equalTo(0.0)); + + // Next: TopN[..., 10] + var topN = as(filter.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); + assertThat(topN.order().getFirst().child(), equalTo(scoreAttr)); + + // Next: Eval[SCORE(EXACTNN(...)) AS $$knn_score$...] + var eval = as(topN.child(), Eval.class); + var alias = as(eval.fields().get(0), Alias.class); + assertThat(alias.name(), equalTo(scoreAttr.name())); + var score = as(alias.child(), Score.class); + var exactNN = as(score.children().getFirst(), ExactNN.class); + var field = as(exactNN.field(), FieldAttribute.class); + assertThat(field.name(), equalTo("dense_vector")); + assertThat(exactNN.query().toString(), equalTo("[0.1, 0.2, 0.3]")); + + var prefilter = as(eval.child(), Filter.class); + var and = as(prefilter.condition(), And.class); + as(and.left(), Match.class); + var lenGt = as(and.right(), GreaterThan.class); + assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); + assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); + + // Next: EsRelation[types] + var esRelation = as(prefilter.child(), EsRelation.class); + } + + public void testKnnWithNonPushablePrefiltersScoringMultipleKnn() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + // Disjunctions with pushable conditions are allowed + var plan = planTypes(""" + from types metadata _score + | where knn(dense_vector, [0.1, 0.2, 0.3], 5) and knn(dense_vector, [0.4, 0.5, 0.6], 7) + and match(text, "hello") and length(keyword) > 10 + """); + + var project = as(plan, Project.class); + assertFalse(project.projections().stream().anyMatch(p -> p.toString().contains(EXACT_SCORE_ATTR_NAME))); + + var limit = as(project.child(), Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // Next: Filter[$$knn_score$1 > 0.0 AND $$knn_score$0 > 0.0] + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + + // Both sides are GreaterThan for the two score attrs + var gt1 = as(and.left(), GreaterThan.class); + var gt2 = as(and.right(), GreaterThan.class); + + ReferenceAttribute scoreAttr0 = as(gt1.left(), ReferenceAttribute.class); + assertThat(scoreAttr0.name(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt1.right().fold(FoldContext.small()), equalTo(0.0)); + ReferenceAttribute scoreAttr1 = as(gt2.left(), ReferenceAttribute.class); + assertThat(scoreAttr1.name(), containsString(EXACT_SCORE_ATTR_NAME)); + assertThat(gt2.right().fold(FoldContext.small()), equalTo(0.0)); + + // Next: TopN[..., 5] + var topN = as(filter.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); + assertThat(topN.order().size(), equalTo(2)); + assertThat(topN.order().get(0).child(), equalTo(scoreAttr1)); + assertThat(topN.order().get(1).child(), equalTo(scoreAttr0)); + + // Next: Eval[SCORE(EXACTNN(...)) AS $$knn_score$0, ...] + var eval = as(topN.child(), Eval.class); + assertThat(eval.fields().size(), equalTo(2)); + var alias0 = as(eval.fields().get(0), Alias.class); + assertThat(alias0.name(), equalTo(scoreAttr1.name())); + var score0 = as(alias0.child(), Score.class); + var exactNN0 = as(score0.children().getFirst(), ExactNN.class); + var field0 = as(exactNN0.field(), FieldAttribute.class); + assertThat(field0.name(), equalTo("dense_vector")); + assertThat(exactNN0.query().toString(), equalTo("[0.1, 0.2, 0.3]")); + + var alias1 = as(eval.fields().get(1), Alias.class); + assertThat(alias1.name(), equalTo(scoreAttr0.name())); + var score1 = as(alias1.child(), Score.class); + var exactNN1 = as(score1.children().getFirst(), ExactNN.class); + var field1 = as(exactNN1.field(), FieldAttribute.class); + assertThat(field1.name(), equalTo("dense_vector")); + assertThat(exactNN1.query().toString(), equalTo("[0.4, 0.5, 0.6]")); + + // Next: Filter[MATCH(...) AND LENGTH(keyword) > 10] + var prefilter = as(eval.child(), Filter.class); + var andPref = as(prefilter.condition(), And.class); + as(andPref.left(), Match.class); + var lenGt = as(andPref.right(), GreaterThan.class); + assertThat(Expressions.name(lenGt.left()), containsString("length(keyword)")); + assertThat(lenGt.right().fold(FoldContext.small()), equalTo(10)); + + // Next: EsRelation[types] + var esRelation = as(prefilter.child(), EsRelation.class); + } + + public void testKnnInDisjunctionsWithNonPushablePrefilters() { + assumeTrue("requires KNN", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()); + + // Disjunctions with non-pushable conditions as a prefilter must fail + assertThat( + typesError( + "from types | where (knn(dense_vector, [0.1, 0.2, 0.3], 10) or match(text, \"hello\")) " + "and length(keyword) > 10" + ), + containsString( + "knn function [knn(dense_vector, [0.1, 0.2, 0.3], 10)] cannot be used in an OR clause " + + "when it is being filtered with the following AND conditions: length(keyword) > 10." + ) + ); + + assertThat( + typesError( + "from types | where ((knn(dense_vector, [0.1, 0.2, 0.3], 10) and match(text, \"hello\")) or keyword == \"hello\")" + + "and (length(keyword) > 10 or long == 50)" + ), + containsString( + "knn function [knn(dense_vector, [0.1, 0.2, 0.3], 10)] cannot be used in an OR clause " + + "when it is being filtered with the following AND conditions: length(keyword) > 10 or long == 50." + ) + ); + } }