From b7a99fca3e63c299030f16b7dc3ba60b1b3052b0 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 10:25:12 +0200 Subject: [PATCH 01/22] Add prefilters to Knn --- .../esql/kibana/docs/functions/knn.md | 2 +- .../org/elasticsearch/TransportVersions.java | 1 + .../function/fulltext/FullTextFunction.java | 1 - .../esql/expression/function/vector/Knn.java | 45 ++++++++++++++++--- .../xpack/esql/querydsl/query/KnnQuery.java | 27 ++++++++++- 5 files changed, 65 insertions(+), 11 deletions(-) diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md index c7af797488ba4..f32319b080dbb 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md @@ -1,4 +1,4 @@ -% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. ### KNN Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors. diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 03fbcbc5b18fe..6cd4ad63d22f6 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -333,6 +333,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00); public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00); public static final TransportVersion ESQL_SPLIT_ON_BIG_VALUES = def(9_116_0_00); + public static final TransportVersion ESQL_KNN_FUNCTION_PREFILTER = def(9_117_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index ec29b4b658c76..e29fe681fed90 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -163,7 +163,6 @@ public boolean equals(Object obj) { @Override public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { - // In isolation, full text functions are pushable to source. We check if there are no disjunctions in Or conditions return Translatable.YES; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 63026fb9d7201..a2427a4577cb5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.expression.function.vector; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -70,6 +71,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes private final transient Expression k; private final Expression options; + // Expressions to be used as prefilters in knn query + private final List filterExpressions; public static final Map ALLOWED_OPTIONS = Map.ofEntries( entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), @@ -139,14 +142,23 @@ public Knn( optional = true ) Expression options ) { - this(source, field, query, k, options, null); + this(source, field, query, k, options, null, List.of()); } - private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) { + private Knn( + Source source, + Expression field, + Expression query, + Expression k, + Expression options, + QueryBuilder queryBuilder, + List filterExpressions + ) { super(source, query, expressionList(field, query, k, options), queryBuilder); this.field = field; this.k = k; this.options = options; + this.filterExpressions = filterExpressions; } private static List expressionList(Expression field, Expression query, Expression k, Expression options) { @@ -174,6 +186,10 @@ public Expression options() { return options; } + public List filterExpressions() { + return filterExpressions; + } + @Override public DataType dataType() { return DataType.BOOLEAN; @@ -257,7 +273,13 @@ protected Query translate(TranslatorHandler handler) { @Override public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { - return new Knn(source(), field(), query(), k(), options(), queryBuilder); + return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); + } + + public Expression withFilter(Expression filterExpression) { + List newFilterExpressions = new ArrayList<>(filterExpressions); + newFilterExpressions.add(filterExpression); + return new Knn(source(), field(), query(), k(), options(), queryBuilder(), List.copyOf(newFilterExpressions)); } private Map queryOptions() throws InvalidArgumentException { @@ -284,7 +306,8 @@ public Expression replaceChildren(List newChildren) { newChildren.get(1), newChildren.get(2), newChildren.size() > 3 ? newChildren.get(3) : null, - queryBuilder() + queryBuilder(), + filterExpressions() ); } @@ -303,7 +326,11 @@ private static Knn readFrom(StreamInput in) throws IOException { Expression field = in.readNamedWriteable(Expression.class); Expression query = in.readNamedWriteable(Expression.class); QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); - return new Knn(source, field, query, null, null, queryBuilder); + List filterExpressions = List.of(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_KNN_FUNCTION_PREFILTER)) { + filterExpressions = in.readNamedWriteableCollectionAsList(Expression.class); + } + return new Knn(source, field, query, null, null, queryBuilder, filterExpressions); } @Override @@ -312,6 +339,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(field()); out.writeNamedWriteable(query()); out.writeOptionalNamedWriteable(queryBuilder()); + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_KNN_FUNCTION_PREFILTER)) { + out.writeNamedWriteableCollection(filterExpressions()); + } } @Override @@ -322,12 +352,13 @@ public boolean equals(Object o) { Knn knn = (Knn) o; return Objects.equals(field(), knn.field()) && Objects.equals(query(), knn.query()) - && Objects.equals(queryBuilder(), knn.queryBuilder()); + && Objects.equals(queryBuilder(), knn.queryBuilder()) + && Objects.equals(filterExpressions(), knn.filterExpressions()); } @Override public int hashCode() { - return Objects.hash(field(), query(), queryBuilder()); + return Objects.hash(field(), query(), queryBuilder(), filterExpressions()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java index aa0e896dfc013..58c29c0c5a79c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -13,7 +13,9 @@ import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -27,15 +29,21 @@ public class KnnQuery extends Query { private final String field; private final float[] query; private final Map options; + private final List filterQueries; public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; public KnnQuery(Source source, String field, float[] query, Map options) { + this(source, field, query, options, List.of()); + } + + public KnnQuery(Source source, String field, float[] query, Map options, List filterQueries) { super(source); assert options != null; this.field = field; this.query = query; this.options = options; + this.filterQueries = new ArrayList<>(filterQueries); } @Override @@ -50,12 +58,21 @@ protected QueryBuilder asBuilder() { Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName()); KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity); + for (QueryBuilder filter : filterQueries) { + queryBuilder.addFilterQuery(filter); + } Number boost = (Number) options.get(BOOST_FIELD.getPreferredName()); if (boost != null) { queryBuilder.boost(boost.floatValue()); } return queryBuilder; } + + public KnnQuery withFilterQueries(List newFilterQueries) { + List combinedFilterQueries = new ArrayList<>(filterQueries); + combinedFilterQueries.addAll(newFilterQueries); + return new KnnQuery(source(), field, query, options, combinedFilterQueries); + } @Override protected String innerToString() { @@ -66,19 +83,25 @@ protected String innerToString() { public boolean equals(Object o) { if (super.equals(o) == false) return false; + if (o == null || getClass() != o.getClass()) return false; KnnQuery knnQuery = (KnnQuery) o; return Objects.equals(field, knnQuery.field) && Objects.deepEquals(query, knnQuery.query) - && Objects.equals(options, knnQuery.options); + && Objects.equals(options, knnQuery.options) + && Objects.equals(filterQueries, knnQuery.filterQueries); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options); + return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options, filterQueries); } @Override public boolean scorable() { return true; } + + public List filterQueries() { + return filterQueries; + } } From 61a1e4bf950a2239d544dc09f68634191a5eb48b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 12:23:04 +0200 Subject: [PATCH 02/22] Add logical plan optimizer rule to add prefilters --- .../esql/expression/function/vector/Knn.java | 6 +- .../esql/optimizer/LogicalPlanOptimizer.java | 2 + .../PushDownConjunctionsToKnnPrefilters.java | 74 +++++++++++++++++++ .../optimizer/LogicalPlanOptimizerTests.java | 21 ++++++ 4 files changed, 99 insertions(+), 4 deletions(-) create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index a2427a4577cb5..b930a3a855462 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -276,10 +276,8 @@ public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); } - public Expression withFilter(Expression filterExpression) { - List newFilterExpressions = new ArrayList<>(filterExpressions); - newFilterExpressions.add(filterExpression); - return new Knn(source(), field(), query(), k(), options(), queryBuilder(), List.copyOf(newFilterExpressions)); + public Expression withFilters(List filterExpressions) { + return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } private Map queryOptions() throws InvalidArgumentException { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 14a858f85fd2a..c5ccc52f5e4ed 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineOrderBy; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineSample; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownConjunctionsToKnnPrefilters; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan; @@ -192,6 +193,7 @@ protected static Batch operators(boolean local) { new PruneLiteralsInOrderBy(), new PushDownAndCombineLimits(), new PushDownAndCombineFilters(), + new PushDownConjunctionsToKnnPrefilters(), new PushDownAndCombineSample(), new PushDownInferencePlan(), new PushDownEval(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java new file mode 100644 index 0000000000000..de152eb3b3636 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.ArrayList; +import java.util.List; +import java.util.Stack; + +public class PushDownConjunctionsToKnnPrefilters extends OptimizerRules.OptimizerRule { + + @Override + protected LogicalPlan rule(Filter filter) { + Stack filters = new Stack<>(); + Expression condition = filter.condition(); + Expression newCondition = pushConjunctionsToKnn(condition, filters, null); + + return condition.equals(newCondition) ? filter : filter.with(newCondition); + } + + private static Expression pushConjunctionsToKnn(Expression expression, List filters, Expression addedFilter) { + if (addedFilter != null) { + filters.add(addedFilter); + } + Expression result = switch(expression) { + case And and: + Expression newLeft = pushConjunctionsToKnn(and.left(), filters, and.right()); + Expression newRight = pushConjunctionsToKnn(and.right(), filters, and.left()); + if (newLeft.equals(and.left()) && newRight.equals(and.right())) { + yield and; + } + yield and.replaceChildrenSameSize(List.of(newLeft, newRight)); + case Knn knn: + yield knn.withFilters(List.copyOf(filters)); + default: + List children = expression.children(); + boolean childrenChanged = false; + + // TODO This copies transformChildren + List transformedChildren = null; + + for (int i = 0, s = children.size(); i < s; i++) { + Expression child = children.get(i); + Expression next = pushConjunctionsToKnn(child, filters, null); + if (child.equals(next) == false) { + // lazy copy + replacement in place + if (childrenChanged == false) { + childrenChanged = true; + transformedChildren = new ArrayList<>(children); + } + transformedChildren.set(i, next); + } + } + + yield (childrenChanged ? expression.replaceChildrenSameSize(transformedChildren) : expression); + }; + + if (addedFilter != null) { + filters.removeLast(); + } + + return result; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 7795bf5c2d9ff..fe451c7d18721 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -74,6 +74,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; @@ -7849,4 +7850,24 @@ public void testSampleNoPushDownChangePoint() { var topN = as(changePoint.child(), TopN.class); var source = as(topN.child(), EsRelation.class); } + + public void testPushDownConjunctionsToKnnPrefilter() { + assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE_V3.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + var knn = as(and.left(), Knn.class); + List filterExpressions = knn.filterExpressions(); + assertThat(filterExpressions.size(), equalTo(1)); + var prefilter = as(filterExpressions.get(0), GreaterThan.class); + assertThat(and.right(), equalTo(prefilter)); + var esRelation = as(filter.child(), EsRelation.class); + } } From fc0e54b329c264a71d071d6c097a2ff005e2e52e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 12:49:18 +0200 Subject: [PATCH 03/22] Adds filters to KnnVectorQueryBuilder when translating --- .../function/fulltext/FullTextFunction.java | 4 +- .../expression/function/fulltext/Kql.java | 3 +- .../expression/function/fulltext/Match.java | 3 +- .../function/fulltext/MatchPhrase.java | 3 +- .../function/fulltext/MultiMatch.java | 3 +- .../function/fulltext/QueryString.java | 3 +- .../expression/function/fulltext/Term.java | 3 +- .../esql/expression/function/vector/Knn.java | 24 ++++++- .../xpack/esql/querydsl/query/KnnQuery.java | 8 +-- .../LocalPhysicalPlanOptimizerTests.java | 65 ++++++++++++++++++- 10 files changed, 101 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index e29fe681fed90..99c1ca236a6c0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -168,14 +168,14 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { @Override public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { - return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(handler); + return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(pushdownPredicates, handler); } public QueryBuilder queryBuilder() { return queryBuilder; } - protected abstract Query translate(TranslatorHandler handler); + protected abstract Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler); public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java index b373becca9965..df3cf5af84232 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery; @@ -93,7 +94,7 @@ protected NodeInfo info() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { return new KqlQuery(source(), Objects.toString(queryAsObject())); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java index e6d99d158aaaf..743263a878552 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.MatchQuery; @@ -423,7 +424,7 @@ public Object queryAsObject() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = fieldAsFieldAttribute(); Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java index 4a99227576611..a7f5282fa94b4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.MatchPhraseQuery; @@ -278,7 +279,7 @@ public Object queryAsObject() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = fieldAsFieldAttribute(); Check.notNull(fieldAttribute, "MatchPhrase must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java index 2c398c7f6c6f1..1178178c432fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.MultiMatchQuery; @@ -335,7 +336,7 @@ protected NodeInfo info() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { Map fieldsWithBoost = new HashMap<>(); for (Expression field : fields) { var fieldAttribute = Match.fieldAsFieldAttribute(field); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java index 4e201a17a4aec..a4c1b1f12fb56 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import java.io.IOException; @@ -345,7 +346,7 @@ protected NodeInfo info() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { return new QueryStringQuery(source(), Objects.toString(queryAsObject()), Map.of(), queryStringOptions()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java index 76188dc146ee6..cecef10a136f7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; @@ -130,7 +131,7 @@ protected TypeResolutions.ParamOrdinal queryParamOrdinal() { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { // Uses a term query that contributes to scoring return new TermQuery(source(), ((FieldAttribute) field()).name(), queryAsObject(), false, true); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index b930a3a855462..b1c0cb95dfe18 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; +import org.elasticsearch.xpack.esql.capabilities.TranslationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -34,6 +35,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.planner.TranslatorHandler; import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; @@ -252,10 +254,10 @@ private Map knnQueryOptions() throws InvalidArgumentException { } @Override - protected Query translate(TranslatorHandler handler) { + protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = Match.fieldAsFieldAttribute(field()); - Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); + Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument"); String fieldName = getNameFromFieldAttribute(fieldAttribute); @SuppressWarnings("unchecked") List queryFolded = (List) query().fold(FoldContext.small() /* TODO remove me */); @@ -268,7 +270,12 @@ protected Query translate(TranslatorHandler handler) { Map opts = queryOptions(); opts.put(K_FIELD.getPreferredName(), kValue); - return new KnnQuery(source(), fieldName, queryAsFloats, opts); + List filterQueries = new ArrayList<>(); + for (Expression filterExpression : filterExpressions()) { + filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); + } + + return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries); } @Override @@ -276,6 +283,17 @@ public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); } + @Override + public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { + Translatable translatable = super.translatable(pushdownPredicates); + // We need to check whether filter expressions are translatable as well + for(Expression filterExpression : filterExpressions()) { + translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates)); + } + + return translatable; + } + public Expression withFilters(List filterExpressions) { return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java index 58c29c0c5a79c..d982f7d309d54 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -33,10 +33,6 @@ public class KnnQuery extends Query { public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; - public KnnQuery(Source source, String field, float[] query, Map options) { - this(source, field, query, options, List.of()); - } - public KnnQuery(Source source, String field, float[] query, Map options, List filterQueries) { super(source); assert options != null; @@ -67,7 +63,7 @@ protected QueryBuilder asBuilder() { } return queryBuilder; } - + public KnnQuery withFilterQueries(List newFilterQueries) { List combinedFilterQueries = new ArrayList<>(filterQueries); combinedFilterQueries.addAll(newFilterQueries); @@ -100,7 +96,7 @@ public int hashCode() { public boolean scorable() { return true; } - + public List filterQueries() { return filterQueries; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 66b797afa426c..900c84b689a7f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -60,6 +60,8 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -104,6 +106,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Locale; @@ -1626,6 +1629,10 @@ private void testFullTextFunctionWithPushableConjunction(FullTextFunctionTestCas assertEquals(expected.toString(), esQuery.query().toString()); } + public void testKnn() { + testFullTextFunctionWithNonPushableDisjunction(new KnnFunctionTestCase()); + } + private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTestCase testCase) { String query = String.format(Locale.ROOT, """ from test @@ -1646,6 +1653,32 @@ private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTest assertThat(fieldExtract.child(), instanceOf(EsQueryExec.class)); } + public void testKnnPrefilters() { + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 45, "integer > 10") + ); + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null) + .addFilterQuery(expectedFilterQueryBuilder); + var expectedQuery = boolQuery() + .must(expectedKnnQueryBuilder) + .must(expectedFilterQueryBuilder); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCase testCase) { String query = String.format(Locale.ROOT, """ from test @@ -1665,11 +1698,12 @@ private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCas } private FullTextFunctionTestCase randomFullTextFunctionTestCase() { - return switch (randomIntBetween(0, 3)) { + return switch (randomIntBetween(0, 4)) { case 0 -> new MatchFunctionTestCase(); case 1 -> new MatchOperatorTestCase(); case 2 -> new KqlFunctionTestCase(); case 3 -> new QueryStringFunctionTestCase(); + case 4 -> new KnnFunctionTestCase(); default -> throw new IllegalStateException("Unexpected value"); }; } @@ -2190,4 +2224,33 @@ public String esqlQuery() { return "qstr(\"" + fieldName() + ": " + queryString() + "\")"; } } + + private class KnnFunctionTestCase extends FullTextFunctionTestCase { + + final int k; + + KnnFunctionTestCase() { + super(Knn.class, "dense_vector", randomVector()); + k = randomIntBetween(1, 10); + } + + private static Object randomVector() { + int numDims = randomIntBetween(10, 20); + float[] vector = new float[numDims]; + for (int i = 0; i < numDims; i++) { + vector[i] = randomFloat(); + } + return vector; + } + + @Override + public QueryBuilder queryBuilder() { + return new KnnVectorQueryBuilder(fieldName(), (float[]) queryString(), k, null, null, null); + } + + @Override + public String esqlQuery() { + return "knn(" + fieldName() + ", " + Arrays.toString(((float[]) queryString())) + ", " + k + ")"; + } + } } From 6a52b63d4029628512b4946948ec2fe288dd6458 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 12:52:38 +0200 Subject: [PATCH 04/22] Update KNN capability name --- .../src/main/resources/knn-function.csv-spec | 28 +++++++++---------- .../xpack/esql/plugin/KnnFunctionIT.java | 2 +- .../xpack/esql/action/EsqlCapabilities.java | 2 +- .../esql/expression/ExpressionWritables.java | 2 +- .../elasticsearch/xpack/esql/CsvTests.java | 2 +- .../xpack/esql/analysis/VerifierTests.java | 18 ++++++------ .../function/fulltext/KnnTests.java | 2 +- .../LocalPhysicalPlanOptimizerTests.java | 3 +- 8 files changed, 29 insertions(+), 30 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index c6105b82f2300..e1756d553ebb2 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,7 +3,7 @@ # top-n query at the shard level knnSearch -required_capability: knn_function_v2 +required_capability: knn_function_v3 // tag::knn-function[] from colors metadata _score @@ -31,7 +31,7 @@ chartreuse | [127.0, 255.0, 0.0] # https://github.com/elastic/elasticsearch/issues/129550 - Add as an example to knn function documentation knnSearchWithSimilarityOption-Ignore -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) @@ -47,7 +47,7 @@ wheat | [245.0, 222.0, 179.0] ; knnHybridSearch -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | where match(color, "blue") or knn(rgb_vector, [65,105,225], 140) @@ -70,7 +70,7 @@ yellow | [255.0, 255.0, 0.0] ; knnWithMultipleFunctions -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | where knn(rgb_vector, [128,128,0], 140) and match(color, "olive") @@ -83,7 +83,7 @@ olive | [128.0, 128.0, 0.0] ; knnAfterKeep -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | keep rgb_vector, color, _score @@ -102,7 +102,7 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | drop primary @@ -121,7 +121,7 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | eval composed_name = locate(color, " ") > 0 @@ -140,7 +140,7 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function_v2 +required_capability: knn_function_v3 # TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score @@ -161,7 +161,7 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function_v2 +required_capability: knn_function_v3 # TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score @@ -185,7 +185,7 @@ yellow | [255.0, 255.0, 0.0] ; knnWithNonPushableConjunction -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | eval composed_name = locate(color, " ") > 0 @@ -210,7 +210,7 @@ maroon | false # https://github.com/elastic/elasticsearch/issues/129550 testKnnWithNonPushableDisjunctions-Ignore -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 @@ -227,7 +227,7 @@ papaya whip # https://github.com/elastic/elasticsearch/issues/129550 testKnnWithNonPushableDisjunctionsOnComplexExpressions-Ignore -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors metadata _score | where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) @@ -242,7 +242,7 @@ indigo | false ; testKnnInStatsNonPushable -required_capability: knn_function_v2 +required_capability: knn_function_v3 from colors | where length(color) < 10 @@ -254,7 +254,7 @@ c: long ; testKnnInStatsWithGrouping -required_capability: knn_function_v2 +required_capability: knn_function_v3 required_capability: full_text_functions_in_stats_where from colors diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 61795addb1e79..29169c0b82997 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -136,7 +136,7 @@ public void testKnnWithLookupJoin() { @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); var indexName = "test"; var client = client().admin().indices(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 7f8c5efb45cc1..548eca59bae52 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1203,7 +1203,7 @@ public enum Cap { /** * Support knn function */ - KNN_FUNCTION_V2(Build.current().isSnapshot()), + KNN_FUNCTION_V3(Build.current().isSnapshot()), LIKE_WITH_LIST_OF_PATTERNS, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index a3f6d3a089d49..646ef30cf111c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -259,7 +259,7 @@ private static List fullText() { } private static List vector() { - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { return List.of(Knn.ENTRY); } return List.of(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 9062bdef62d76..0fade7ce4edc5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -298,7 +298,7 @@ public final void test() throws Throwable { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName()) ); assumeFalse( "lookup join disabled for csv tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 6555a303592f3..e263118581718 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1235,7 +1235,7 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); } } @@ -1368,7 +1368,7 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); } } @@ -1407,7 +1407,7 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); } } @@ -1472,7 +1472,7 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); } } @@ -1543,7 +1543,7 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -2071,7 +2071,7 @@ public void testFullTextFunctionOptions() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); } } @@ -2159,7 +2159,7 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); @@ -2185,7 +2185,7 @@ public void testFullTextFunctionsConstantArg() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsConstantArg("term(title, tags)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second"); checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third"); } @@ -2216,7 +2216,7 @@ public void testFullTextFunctionsInStats() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 4a5708b398b18..595eb58118a09 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -51,7 +51,7 @@ public static Iterable parameters() { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); } private static List testCaseSuppliers() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 900c84b689a7f..23d417bb9113a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -61,7 +61,6 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; -import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -1366,7 +1365,7 @@ public void testMultiMatchOptionsPushDown() { public void testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); String query = """ from test From a73973f60ea60708840275e678e84b8375f7e11e Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 13:19:57 +0200 Subject: [PATCH 05/22] Add tests --- .../optimizer/LogicalPlanOptimizerTests.java | 86 ++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index fe451c7d18721..54f35e45b9d6c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -7852,7 +7852,7 @@ public void testSampleNoPushDownChangePoint() { } public void testPushDownConjunctionsToKnnPrefilter() { - assumeTrue("sample must be enabled", EsqlCapabilities.Cap.SAMPLE_V3.isEnabled()); + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); var query = """ from test @@ -7870,4 +7870,88 @@ public void testPushDownConjunctionsToKnnPrefilter() { assertThat(and.right(), equalTo(prefilter)); var esRelation = as(filter.child(), EsRelation.class); } + + public void testNotPushDownDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) or integer > 10 + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var or = as(filter.condition(), Or.class); + var knn = as(or.left(), Knn.class); + List filterExpressions = knn.filterExpressions(); + assertThat(filterExpressions.size(), equalTo(0)); + } + + public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + /* + and + and + or + knn(dense_vector, [0, 1, 2], 10) + integer > 10 + keyword == "test" + or + short < 5 + double > 5.0 + */ + // Both conjunctions are pushed down to knn prefilters, disjunctions are not + var query = """ + from test + | where + ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0)) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + var leftAnd = as(and.left(), And.class); + var rightOr = as(and.right(), Or.class); + var leftOr = as(leftAnd.left(), Or.class); + var knn = as(leftOr.left(), Knn.class); + var rightOrPrefilter = as(knn.filterExpressions().get(0), Or.class); + assertThat(rightOr, equalTo(rightOrPrefilter)); + var leftAndPrefilter = as(knn.filterExpressions().get(1), Equals.class); + assertThat(leftAnd.right(), equalTo(leftAndPrefilter)); + } + + public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + /* + or + or + and + knn(dense_vector, [0, 1, 2], 10) + integer > 10 + keyword == "test" + and + short < 5 + double > 5.0 + */ + // Just the conjunction is pushed down to knn prefilters, disjunctions are not + var query = """ + from test + | where + ((knn(dense_vector, [0, 1, 2], 10) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0)) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var or = as(filter.condition(), Or.class); + var leftOr = as(or.left(), Or.class); + var leftAnd = as(leftOr.left(), And.class); + var knn = as(leftAnd.left(), Knn.class); + var rightAndPrefilter = as(knn.filterExpressions().get(0), GreaterThan.class); + assertThat(leftAnd.right(), equalTo(rightAndPrefilter)); + } } From ef12ee5d7c50d5d98cd2232451f7262fa98551b6 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 18:26:26 +0200 Subject: [PATCH 06/22] Avoid infinite loop with multiple knn expressions, add tests --- .../esql/expression/function/vector/Knn.java | 2 +- .../PushDownConjunctionsToKnnPrefilters.java | 35 +++- .../LocalPhysicalPlanOptimizerTests.java | 177 ++++++++++++++---- .../optimizer/LogicalPlanOptimizerTests.java | 45 +++++ 4 files changed, 220 insertions(+), 39 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index b1c0cb95dfe18..2e2d793cb90f1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -287,7 +287,7 @@ public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { Translatable translatable = super.translatable(pushdownPredicates); // We need to check whether filter expressions are translatable as well - for(Expression filterExpression : filterExpressions()) { + for (Expression filterExpression : filterExpressions()) { translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates)); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java index de152eb3b3636..570cc587352e5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java @@ -17,6 +17,11 @@ import java.util.List; import java.util.Stack; +/** + * Rewrites an expression tree to push down conjunctions in the prefilter of {@link Knn} functions. + * Given an expression tree like {@code (A OR B) AND (C AND knn())} this rule will rewrite it to + * {@code (A OR B) AND (C AND knn(filterExpressions = [(A OR B), C]))} +*/ public class PushDownConjunctionsToKnnPrefilters extends OptimizerRules.OptimizerRule { @Override @@ -28,12 +33,22 @@ protected LogicalPlan rule(Filter filter) { return condition.equals(newCondition) ? filter : filter.with(newCondition); } - private static Expression pushConjunctionsToKnn(Expression expression, List filters, Expression addedFilter) { + /** + * Updates knn function prefilters. This method processes conjunctions so knn functions on one side of the conjunction receive + * the other side of the conjunction as a prefilter + * + * @param expression expression to process recursively + * @param filters current filters to apply to the expression. They contain expressions on the other side of the traversed conjunctions + * @param addedFilter a new filter to add to the list of filters for the processing + * @return the updated expression, or the original expression if it doesn't need to be updated + */ + private static Expression pushConjunctionsToKnn(Expression expression, Stack filters, Expression addedFilter) { if (addedFilter != null) { - filters.add(addedFilter); + filters.push(addedFilter); } - Expression result = switch(expression) { + Expression result = switch (expression) { case And and: + // Traverse both sides of the And, using the other side as the added filter Expression newLeft = pushConjunctionsToKnn(and.left(), filters, and.right()); Expression newRight = pushConjunctionsToKnn(and.right(), filters, and.left()); if (newLeft.equals(and.left()) && newRight.equals(and.right())) { @@ -41,12 +56,20 @@ private static Expression pushConjunctionsToKnn(Expression expression, List newFilters = new ArrayList<>(filters); + if (newFilters.size() == knn.filterExpressions().size()) { + yield knn; + } + yield knn.withFilters(newFilters); default: List children = expression.children(); boolean childrenChanged = false; - // TODO This copies transformChildren + // This copies transformChildren algorithm to avoid unnecessary changes List transformedChildren = null; for (int i = 0, s = children.size(); i < s; i++) { @@ -66,7 +89,7 @@ private static Expression pushConjunctionsToKnn(Expression expression, List 10 - """; - var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); - - var limit = as(plan, LimitExec.class); - var exchange = as(limit.child(), ExchangeExec.class); - var project = as(exchange.child(), ProjectExec.class); - var field = as(project.child(), FieldExtractExec.class); - var queryExec = as(field.child(), EsQueryExec.class); - QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( - query, - unscore(rangeQuery("integer").gt(10)), - "integer", - new Source(2, 45, "integer > 10") - ); - KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null) - .addFilterQuery(expectedFilterQueryBuilder); - var expectedQuery = boolQuery() - .must(expectedKnnQueryBuilder) - .must(expectedFilterQueryBuilder); - assertEquals(expectedQuery.toString(), queryExec.query().toString()); - } - private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCase testCase) { String query = String.format(Locale.ROOT, """ from test @@ -1697,12 +1667,11 @@ private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCas } private FullTextFunctionTestCase randomFullTextFunctionTestCase() { - return switch (randomIntBetween(0, 4)) { + return switch (randomIntBetween(0, 3)) { case 0 -> new MatchFunctionTestCase(); case 1 -> new MatchOperatorTestCase(); case 2 -> new KqlFunctionTestCase(); case 3 -> new QueryStringFunctionTestCase(); - case 4 -> new KnnFunctionTestCase(); default -> throw new IllegalStateException("Unexpected value"); }; } @@ -1861,6 +1830,150 @@ public void testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) { aggExec.forEachDown(EsQueryExec.class, esQueryExec -> { assertNull(esQueryExec.query()); }); } + public void testKnnPrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 45, "integer > 10") + ); + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testPushDownConjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + // The filter condition should be pushed down to both the KNN query and the main query + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 45, "integer > 10") + ); + + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testNotPushDownDisjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) or integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + // The disjunction should not be pushed down to the KNN query + KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + QueryBuilder rangeQueryBuilder = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 44, "integer > 10") + ); + + var expectedQuery = boolQuery().should(knnQueryBuilder).should(rangeQueryBuilder); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + + public void testMultipleKnnQueriesInPrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + // Integer range query (right side of first OR) + QueryBuilder integerRangeQuery = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 45, "integer > 10") + ); + + // Second KNN query (right side of second OR) + KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + + // Keyword term query (left side of second OR) + QueryBuilder keywordQuery = wrapWithSingleQuery( + query, + unscore(termQuery("keyword", "test")), + "keyword", + new Source(2, 87, "keyword == \"test\"") + ); + + // First OR (knn1 OR integer > 10) + var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery); + // Second OR (keyword == "test" OR knn2) + var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery.addFilterQuery(firstOr)); + firstKnnQuery.addFilterQuery(secondOr); + + // Top-level AND combining both ORs + var expectedQuery = boolQuery().must(firstOr).must(secondOr); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + public void testParallelizeTimeSeriesPlan() { assumeTrue("requires snapshot builds", Build.current().isSnapshot()); var query = "TS k8s | STATS max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1h)"; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 54f35e45b9d6c..d754c5232ef87 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -7954,4 +7954,49 @@ public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() { var rightAndPrefilter = as(knn.filterExpressions().get(0), GreaterThan.class); assertThat(leftAnd.right(), equalTo(rightAndPrefilter)); } + + public void testMultipleKnnQueriesInPrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + /* + and + or + knn(dense_vector, [0, 1, 2], 10) + integer > 10 + or + keyword == "test" + knn(dense_vector, [4, 5, 6], 10) + */ + var query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var and = as(filter.condition(), And.class); + + // First OR (knn1 OR integer > 10) + var firstOr = as(and.left(), Or.class); + var firstKnn = as(firstOr.left(), Knn.class); + var integerGt = as(firstOr.right(), GreaterThan.class); + + // Second OR (keyword == "test" OR knn2) + var secondOr = as(and.right(), Or.class); + var keywordEq = as(secondOr.left(), Equals.class); + var secondKnn = as(secondOr.right(), Knn.class); + + // First KNN should have the second OR as its filter + List firstKnnFilters = firstKnn.filterExpressions(); + assertThat(firstKnnFilters.size(), equalTo(1)); + var secondOrWithoutFilters = secondOr.replaceChildren(List.of(secondOr.left(), secondKnn.withFilters(List.of()))); + assertTrue(firstKnnFilters.contains(secondOrWithoutFilters)); + + // Second KNN should have the first OR as its filter + List secondKnnFilters = secondKnn.filterExpressions(); + assertThat(secondKnnFilters.size(), equalTo(1)); + var firstOrWithoutFilters = firstOr.replaceChildren(List.of(firstKnn.withFilters(List.of()), firstOr.right())); + assertTrue(secondKnnFilters.contains(firstOrWithoutFilters)); + } } From 616b725ffca920c3e34449c68a226185a0b12ef3 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 9 Jul 2025 19:08:19 +0200 Subject: [PATCH 07/22] Add test for not pushing down knn when there are non-pushable prefilters --- .../LocalPhysicalPlanOptimizerTests.java | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index e1a9d602bd009..b71a79b8d6459 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -61,6 +61,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -1928,6 +1929,40 @@ public void testNotPushDownDisjunctionsToKnnPrefilter() { assertEquals(expectedQuery.toString(), queryExec.query().toString()); } + public void testNotPushDownKnnWithNonPushablePrefilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) AND integer > 10) and ((keyword == "test") or length(text) > 10)) + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var secondLimit = as(field.child(), LimitExec.class); + var filter = as(secondLimit.child(), FilterExec.class); + var and = as(filter.condition(), And.class); + var knn = as(and.left(), Knn.class); + assertEquals("(keyword == \"test\") or length(text) > 10", knn.filterExpressions().get(0).toString()); + assertEquals("integer > 10", knn.filterExpressions().get(1).toString()); + + var fieldExtract = as(filter.child(), FieldExtractExec.class); + var queryExec = as(fieldExtract.child(), EsQueryExec.class); + + // The query should only contain the pushable condition + QueryBuilder integerGtQuery = wrapWithSingleQuery( + query, + unscore(rangeQuery("integer").gt(10)), + "integer", + new Source(2, 47, "integer > 10") + ); + + assertEquals(integerGtQuery.toString(), queryExec.query().toString()); + } + public void testMultipleKnnQueriesInPrefilters() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); From 3ae0e9bcf9d027e4f937629cd733c53341ac243f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 10 Jul 2025 10:00:40 +0200 Subject: [PATCH 08/22] Execute premapper after logical optimisation, to allow query builders to be generated with prefilters --- .../xpack/esql/plugin/KnnFunctionIT.java | 22 ++++++++++++++++++- .../xpack/esql/session/EsqlSession.java | 5 +++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 29169c0b82997..90e7612d2e49c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -113,7 +113,28 @@ public void testKnnNonPushedDown() { assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); } } + public void testKnnWithPrefilters() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, 5) AND id > 5 + | KEEP id, floats, _score, vector + | SORT _score DESC + | LIMIT 5 + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + List> valuesList = EsqlTestUtils.getValuesList(resp); + // K = 5, 1 more for every id > 10 + assertEquals(5, valuesList.size()); + } + } public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; Arrays.fill(queryVector, 1.0f); @@ -202,6 +223,5 @@ private void createAndPopulateLookupIndex(IndicesAdminClient client, String look var createRequest = client.prepareCreate(lookupIndexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index b6dd0c40f3481..044c423867b37 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -198,9 +198,10 @@ public void execute(EsqlQueryRequest request, EsqlExecutionInfo executionInfo, P analyzedPlan(parsed, executionInfo, request.filter(), new EsqlCCSUtils.CssPartialErrorsActionListener(executionInfo, listener) { @Override public void onResponse(LogicalPlan analyzedPlan) { + LogicalPlan optimizedPlan = optimizedPlan(analyzedPlan); preMapper.preMapper( - analyzedPlan, - listener.delegateFailureAndWrap((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, optimizedPlan(p), l)) + optimizedPlan, + listener.delegateFailureAndWrap((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, p, l)) ); } }); From 50d6e195a24cf97ae93aa42bf1daf4bd2e216fa7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 10 Jul 2025 12:26:00 +0200 Subject: [PATCH 09/22] Fix generating QueryBuilder prefilters for knn when they cannot be pushed down --- .../src/main/resources/knn-function.csv-spec | 17 ++++++++ .../esql/expression/function/vector/Knn.java | 41 +++++++++++-------- .../LocalPhysicalPlanOptimizerTests.java | 28 +++++++++++-- 3 files changed, 65 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index e1756d553ebb2..5812bd8ad2005 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -266,3 +266,20 @@ c: long | primary: boolean 41 | false 9 | true ; + +testKnnUsesPrefiltering +required_capability: knn_function_v3 + +from colors metadata _score +| where knn(rgb_vector, [255, 0, 0], 5) and primary == true +| sort _score desc, color asc +| keep color +; + +color:text +red +gray +black +magenta +yellow +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 2e2d793cb90f1..b0c4dc7dce582 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -253,6 +253,22 @@ private Map knnQueryOptions() throws InvalidArgumentException { return matchOptions; } + @Override + public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { + return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); + } + + @Override + public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { + Translatable translatable = super.translatable(pushdownPredicates); + // We need to check whether filter expressions are translatable as well + for (Expression filterExpression : filterExpressions()) { + translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates)); + } + + return translatable; + } + @Override protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) { var fieldAttribute = Match.fieldAsFieldAttribute(field()); @@ -272,28 +288,19 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato List filterQueries = new ArrayList<>(); for (Expression filterExpression : filterExpressions()) { - filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); + if (filterExpression instanceof TranslationAware translationAware) { + // We can only translate filter expressions that are translatable. In case any is not translatable, + // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them + // when creating an evaluator for the non-pushed down query + if (translationAware.translatable(pushdownPredicates) == Translatable.YES) { + filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder()); + } + } } return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries); } - @Override - public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { - return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions()); - } - - @Override - public Translatable translatable(LucenePushdownPredicates pushdownPredicates) { - Translatable translatable = super.translatable(pushdownPredicates); - // We need to check whether filter expressions are translatable as well - for (Expression filterExpression : filterExpressions()) { - translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates)); - } - - return translatable; - } - public Expression withFilters(List filterExpressions) { return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index b71a79b8d6459..d0e97270f53e3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1979,30 +1979,50 @@ public void testMultipleKnnQueriesInPrefilters() { var queryExec = as(field.child(), EsQueryExec.class); KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + KnnVectorQueryBuilder firstKnnQueryAsFilter = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ); // Integer range query (right side of first OR) QueryBuilder integerRangeQuery = wrapWithSingleQuery( query, unscore(rangeQuery("integer").gt(10)), "integer", - new Source(2, 45, "integer > 10") + new Source(2, 46, "integer > 10") ); // Second KNN query (right side of second OR) KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + KnnVectorQueryBuilder secondKnnQueryAsFilter = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 4, 5, 6 }, + 10, + null, + null, + null + ); // Keyword term query (left side of second OR) QueryBuilder keywordQuery = wrapWithSingleQuery( query, unscore(termQuery("keyword", "test")), "keyword", - new Source(2, 87, "keyword == \"test\"") + new Source(2, 66, "keyword == \"test\"") ); // First OR (knn1 OR integer > 10) var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery); + var firstOrAsFilter = boolQuery().should(firstKnnQueryAsFilter).should(integerRangeQuery); // Second OR (keyword == "test" OR knn2) - var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery.addFilterQuery(firstOr)); - firstKnnQuery.addFilterQuery(secondOr); + var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery); + var secondOrAsFilter = boolQuery().should(keywordQuery).should(secondKnnQueryAsFilter); + // Add prefilters to the knn queries. knn queries in prefilters don't have prefilters so we use copies of the queries + firstKnnQuery.addFilterQuery(secondOrAsFilter); + secondKnnQuery.addFilterQuery(firstOrAsFilter); // Top-level AND combining both ORs var expectedQuery = boolQuery().must(firstOr).must(secondOr); From 32715cef4ba8673b14c12e810bf04762bd2e2376 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 10 Jul 2025 13:31:27 +0200 Subject: [PATCH 10/22] Update CSV tests now that knn uses prefilters --- .../src/main/resources/knn-function.csv-spec | 29 ++++--------------- .../xpack/esql/plugin/KnnFunctionIT.java | 2 ++ 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index 5812bd8ad2005..6f16b9400fcab 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -11,7 +11,6 @@ from colors metadata _score | sort _score desc, color asc // end::knn-function[] | keep color, rgb_vector -| limit 10 ; // tag::knn-function-result[] @@ -50,11 +49,10 @@ knnHybridSearch required_capability: knn_function_v3 from colors metadata _score -| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140) +| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10) | where primary == true | sort _score desc, color asc | keep color, rgb_vector -| limit 10 ; color:text | rgb_vector:dense_vector @@ -69,17 +67,18 @@ red | [255.0, 0.0, 0.0] yellow | [255.0, 255.0, 0.0] ; -knnWithMultipleFunctions +knnWithPrefilter required_capability: knn_function_v3 from colors metadata _score -| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive") +| where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green")) | sort _score desc, color asc | keep color, rgb_vector ; color:text | rgb_vector:dense_vector olive | [128.0, 128.0, 0.0] +green | [0.0, 128.0, 0.0] ; knnAfterKeep @@ -163,9 +162,8 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] knnWithDisjunctionAndFiltersConjunction required_capability: knn_function_v3 -# TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true +| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true | keep color, rgb_vector, _score | sort _score desc, color asc | drop _score @@ -266,20 +264,3 @@ c: long | primary: boolean 41 | false 9 | true ; - -testKnnUsesPrefiltering -required_capability: knn_function_v3 - -from colors metadata _score -| where knn(rgb_vector, [255, 0, 0], 5) and primary == true -| sort _score desc, color asc -| keep color -; - -color:text -red -gray -black -magenta -yellow -; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 90e7612d2e49c..d7ac65d5a31ab 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -113,6 +113,7 @@ public void testKnnNonPushedDown() { assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); } } + public void testKnnWithPrefilters() { float[] queryVector = new float[numDims]; Arrays.fill(queryVector, 1.0f); @@ -135,6 +136,7 @@ public void testKnnWithPrefilters() { assertEquals(5, valuesList.size()); } } + public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; Arrays.fill(queryVector, 1.0f); From 62b0d1cb3bc747d2e446db98928ad0f567c57b2d Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 13:07:00 +0200 Subject: [PATCH 11/22] Remove TransportVersion as knn is not released yet --- .../main/java/org/elasticsearch/TransportVersions.java | 1 - .../xpack/esql/expression/function/vector/Knn.java | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 4a17aca737f75..35f5423df0ffb 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -337,7 +337,6 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_TYPE = def(9_118_0_00); public static final TransportVersion ESQL_FIXED_INDEX_LIKE = def(9_119_0_00); public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00); - public static final TransportVersion ESQL_KNN_FUNCTION_PREFILTER = def(9_121_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index b0c4dc7dce582..387eb8323293b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.expression.function.vector; -import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -349,10 +348,7 @@ private static Knn readFrom(StreamInput in) throws IOException { Expression field = in.readNamedWriteable(Expression.class); Expression query = in.readNamedWriteable(Expression.class); QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); - List filterExpressions = List.of(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_KNN_FUNCTION_PREFILTER)) { - filterExpressions = in.readNamedWriteableCollectionAsList(Expression.class); - } + List filterExpressions = in.readNamedWriteableCollectionAsList(Expression.class); return new Knn(source, field, query, null, null, queryBuilder, filterExpressions); } @@ -362,9 +358,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(field()); out.writeNamedWriteable(query()); out.writeOptionalNamedWriteable(queryBuilder()); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_KNN_FUNCTION_PREFILTER)) { - out.writeNamedWriteableCollection(filterExpressions()); - } + out.writeNamedWriteableCollection(filterExpressions()); } @Override From b347bbacf702d94e2c53b4eabeadaadc08159a65 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 13:11:29 +0200 Subject: [PATCH 12/22] Add queryBuilder and filterExpresions to NodeInfo --- .../xpack/esql/expression/function/vector/Knn.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index 387eb8323293b..d314d1ae43323 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -335,7 +335,7 @@ public Expression replaceChildren(List newChildren) { @Override protected NodeInfo info() { - return NodeInfo.create(this, Knn::new, field(), query(), k(), options()); + return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions()); } @Override From 10837649f06679c70fb4a19b1249b6f3964ca355 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 13:12:23 +0200 Subject: [PATCH 13/22] Remove unnecessary method --- .../elasticsearch/xpack/esql/querydsl/query/KnnQuery.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java index d470674a8a6ec..b218b897121df 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -64,12 +64,6 @@ protected QueryBuilder asBuilder() { return queryBuilder; } - public KnnQuery withFilterQueries(List newFilterQueries) { - List combinedFilterQueries = new ArrayList<>(filterQueries); - combinedFilterQueries.addAll(newFilterQueries); - return new KnnQuery(source(), field, query, options, combinedFilterQueries); - } - @Override protected String innerToString() { return "knn(" + field + ", " + Arrays.toString(query) + " options={" + options + "}))"; From 6e4c1a910d4b0dd3a937803da1e416ebc4766fff Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 16:20:49 +0200 Subject: [PATCH 14/22] Fix test for node info --- .../esql/expression/function/EsqlFunctionRegistry.java | 6 +++++- .../xpack/esql/expression/function/vector/Knn.java | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 630c9c2008a13..682015b511b8a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -478,7 +478,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(LastOverTime.class, uni(LastOverTime::new), "last_over_time"), def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"), def(Term.class, bi(Term::new), "term"), - def(Knn.class, Knn::new, "knn"), + def(Knn.class, quad(Knn::new), "knn"), def(StGeohash.class, StGeohash::new, "st_geohash"), def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"), def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"), @@ -1204,4 +1204,8 @@ private static TernaryBuilder tri(TernaryBuilder func return function; } + private static QuaternaryBuilder quad(QuaternaryBuilder function) { + return function; + } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index d314d1ae43323..61528521c3749 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -146,7 +146,7 @@ public Knn( this(source, field, query, k, options, null, List.of()); } - private Knn( + public Knn( Source source, Expression field, Expression query, From 91d4a0a7196aece309a6d32e5768478cdd345b24 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 16:26:27 +0200 Subject: [PATCH 15/22] Fix typo --- .../rules/logical/PushDownConjunctionsToKnnPrefilters.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java index 570cc587352e5..0f3ffcca77629 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java @@ -58,7 +58,7 @@ private static Expression pushConjunctionsToKnn(Expression expression, Stack newFilters = new ArrayList<>(filters); if (newFilters.size() == knn.filterExpressions().size()) { From a99e75f1dcce49c289b79b241c928b6f4750af7c Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 16:36:58 +0200 Subject: [PATCH 16/22] Add multiple filters test --- .../LocalPhysicalPlanOptimizerTests.java | 36 +++++++++++++++++++ .../optimizer/LogicalPlanOptimizerTests.java | 23 ++++++++++++ 2 files changed, 59 insertions(+) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 3c76f61ffa8f1..7a038f8bb80f0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1871,6 +1871,42 @@ public void testKnnPrefilters() { assertEquals(expectedQuery.toString(), queryExec.query().toString()); } + public void testKnnPrefiltersWithMultipleFilters() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) + | where integer > 10 + | where keyword == "test" + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + var integerFilter = wrapWithSingleQuery(query, unscore(rangeQuery("integer").gt(10)), "integer", new Source(3, 8, "integer > 10")); + var keywordFilter = wrapWithSingleQuery( + query, + unscore(termQuery("keyword", "test")), + "keyword", + new Source(4, 8, "keyword == \"test\"") + ); + QueryBuilder expectedFilterQueryBuilder = boolQuery().must(integerFilter).must(keywordFilter); + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(integerFilter).must(keywordFilter); + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + public void testPushDownConjunctionsToKnnPrefilter() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index f7e7a29db8806..320bd61a719d7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -7877,6 +7877,29 @@ public void testPushDownConjunctionsToKnnPrefilter() { var esRelation = as(filter.child(), EsRelation.class); } + public void testPushDownMultipleFiltersToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + var query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) + | where integer > 10 + | where keyword == "test" + """; + var optimized = planTypes(query); + + var limit = as(optimized, Limit.class); + var filter = as(limit.child(), Filter.class); + var firstAnd = as(filter.condition(), And.class); + var knn = as(firstAnd.left(), Knn.class); + var prefilterAnd = as(firstAnd.right(), And.class); + as(prefilterAnd.left(), GreaterThan.class); + as(prefilterAnd.right(), Equals.class); + List filterExpressions = knn.filterExpressions(); + assertThat(filterExpressions.size(), equalTo(1)); + assertThat(prefilterAnd, equalTo(filterExpressions.get(0))); + } + public void testNotPushDownDisjunctionsToKnnPrefilter() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); From 4c6bb4dc26d19c8664d3e851312254775345e5c5 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Fri, 11 Jul 2025 18:03:07 +0200 Subject: [PATCH 17/22] Add negations tests --- .../src/main/resources/knn-function.csv-spec | 50 ++++++++++- .../LocalPhysicalPlanOptimizerTests.java | 89 +++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index 6f16b9400fcab..670469187cea4 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -81,6 +81,29 @@ olive | [128.0, 128.0, 0.0] green | [0.0, 128.0, 0.0] ; +knnWithNegatedPrefilter +required_capability: knn_function_v3 + +from colors metadata _score +| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) +| sort _score desc, color asc +| keep color, rgb_vector +| LIMIT 10 +; + +color:text | rgb_vector:dense_vector +sienna | [160.0, 82.0, 45.0] +peru | [205.0, 133.0, 63.0] +golden rod | [218.0, 165.0, 32.0] +brown | [165.0, 42.0, 42.0] +firebrick | [178.0, 34.0, 34.0] +chartreuse | [127.0, 255.0, 0.0] +gray | [128.0, 128.0, 128.0] +green | [0.0, 128.0, 0.0] +maroon | [128.0, 0.0, 0.0] +orange | [255.0, 165.0, 0.0] +; + knnAfterKeep required_capability: knn_function_v3 @@ -141,12 +164,10 @@ golden rod | true knnWithConjunction required_capability: knn_function_v3 -# TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*" +| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" | sort _score desc, color asc | keep color, hex_code, rgb_vector -| limit 10 ; color:text | hex_code:keyword | rgb_vector:dense_vector @@ -182,6 +203,29 @@ red | [255.0, 0.0, 0.0] yellow | [255.0, 255.0, 0.0] ; +knnWithNegationsAndFiltersConjunction +required_capability: knn_function_v3 + +from colors metadata _score +| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) +| sort _score desc, color asc +| keep color, rgb_vector +| limit 10 +; + +color:text | rgb_vector:dense_vector +cyan | [0.0, 255.0, 255.0] +turquoise | [64.0, 224.0, 208.0] +aqua marine | [127.0, 255.0, 212.0] +teal | [0.0, 128.0, 128.0] +silver | [192.0, 192.0, 192.0] +gray | [128.0, 128.0, 128.0] +gainsboro | [220.0, 220.0, 220.0] +thistle | [216.0, 191.0, 216.0] +lavender | [230.0, 230.0, 250.0] +azure | [240.0, 255.0, 255.0] +; + knnWithNonPushableConjunction required_capability: knn_function_v3 diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 7a038f8bb80f0..4bb05c7965d3c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1944,6 +1944,44 @@ public void testPushDownConjunctionsToKnnPrefilter() { assertEquals(expectedQuery.toString(), queryExec.query().toString()); } + + public void testPushDownNegatedConjunctionsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where knn(dense_vector, [0, 1, 2], 10) and NOT integer > 10 + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var queryExec = as(field.child(), EsQueryExec.class); + + // The filter condition should be pushed down to both the KNN query and the main query + QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))), + "integer", + new Source(2, 45, "NOT integer > 10") + ); + + KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0, 1, 2 }, + 10, + null, + null, + null + ).addFilterQuery(expectedFilterQueryBuilder); + + var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + public void testNotPushDownDisjunctionsToKnnPrefilter() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); @@ -2007,6 +2045,57 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { assertEquals(integerGtQuery.toString(), queryExec.query().toString()); } + public void testPushDownComplexNegationsToKnnPrefilter() { + assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); + + String query = """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; + var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var queryExec = as(fieldExtract.child(), EsQueryExec.class); + + QueryBuilder notKeywordFilter = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))), + "keyword", + new Source(2, 74, "keyword == \"test\"") + ); + + QueryBuilder notIntegerGt10 = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))), + "integer", + new Source(2, 46, "NOT integer > 10") + ); + + KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + KnnVectorQueryBuilder firstKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); + KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + KnnVectorQueryBuilder secondKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); + + firstKnn.addFilterQuery(boolQuery() + .must(notKeywordFilter) + .must(unscore(boolQuery().mustNot(secondKnnFilter)))); + + secondKnn.addFilterQuery(boolQuery() + .should(firstKnnFilter) + .should(notIntegerGt10)); + + // Build the main boolean query structure + BoolQueryBuilder expectedQuery = boolQuery() + .must(notKeywordFilter) // NOT (keyword == "test") + .must(unscore(boolQuery().mustNot(secondKnn))) + .must(boolQuery().should(firstKnn).should(notIntegerGt10)); + + assertEquals(expectedQuery.toString(), queryExec.query().toString()); + } + public void testMultipleKnnQueriesInPrefilters() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); From 47a6d996335f97aca3311e5216b462f29b2ec2ba Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 11 Jul 2025 16:12:07 +0000 Subject: [PATCH 18/22] [CI] Auto commit changes from spotless --- .../LocalPhysicalPlanOptimizerTests.java | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 4bb05c7965d3c..ca6ab2ff90ce3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1944,7 +1944,6 @@ public void testPushDownConjunctionsToKnnPrefilter() { assertEquals(expectedQuery.toString(), queryExec.query().toString()); } - public void testPushDownNegatedConjunctionsToKnnPrefilter() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); @@ -2048,10 +2047,11 @@ public void testNotPushDownKnnWithNonPushablePrefilters() { public void testPushDownComplexNegationsToKnnPrefilter() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); - String query = """ - from test - | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) - """; + String query = + """ + from test + | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); var limit = as(plan, LimitExec.class); @@ -2079,17 +2079,12 @@ public void testPushDownComplexNegationsToKnnPrefilter() { KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); KnnVectorQueryBuilder secondKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); - firstKnn.addFilterQuery(boolQuery() - .must(notKeywordFilter) - .must(unscore(boolQuery().mustNot(secondKnnFilter)))); + firstKnn.addFilterQuery(boolQuery().must(notKeywordFilter).must(unscore(boolQuery().mustNot(secondKnnFilter)))); - secondKnn.addFilterQuery(boolQuery() - .should(firstKnnFilter) - .should(notIntegerGt10)); + secondKnn.addFilterQuery(boolQuery().should(firstKnnFilter).should(notIntegerGt10)); // Build the main boolean query structure - BoolQueryBuilder expectedQuery = boolQuery() - .must(notKeywordFilter) // NOT (keyword == "test") + BoolQueryBuilder expectedQuery = boolQuery().must(notKeywordFilter) // NOT (keyword == "test") .must(unscore(boolQuery().mustNot(secondKnn))) .must(boolQuery().should(firstKnn).should(notIntegerGt10)); From ac7c3bcce7c9532dc325c8799c4fda9a9159c02f Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 14 Jul 2025 13:21:36 +0200 Subject: [PATCH 19/22] knn functions won't appear as prefilters for other knn functions --- .../PushDownConjunctionsToKnnPrefilters.java | 45 ++++++++++++++--- .../LocalPhysicalPlanOptimizerTests.java | 49 ++++++------------- .../optimizer/LogicalPlanOptimizerTests.java | 6 +-- 3 files changed, 56 insertions(+), 44 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java index 0f3ffcca77629..aa4bb203b4346 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java @@ -10,15 +10,18 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Stack; /** * Rewrites an expression tree to push down conjunctions in the prefilter of {@link Knn} functions. + * knn functions won't contain other knn functions as a prefilter, to avoid circular dependencies. * Given an expression tree like {@code (A OR B) AND (C AND knn())} this rule will rewrite it to * {@code (A OR B) AND (C AND knn(filterExpressions = [(A OR B), C]))} */ @@ -56,12 +59,12 @@ private static Expression pushConjunctionsToKnn(Expression expression, Stack newFilters = new ArrayList<>(filters); - if (newFilters.size() == knn.filterExpressions().size()) { + // We don't want knn expressions to have other knn expressions as a prefilter to avoid circular dependencies + List newFilters = filters.stream() + .map(PushDownConjunctionsToKnnPrefilters::removeKnn) + .filter(Objects::nonNull) + .toList(); + if (newFilters.equals(knn.filterExpressions())) { yield knn; } yield knn.withFilters(newFilters); @@ -94,4 +97,34 @@ private static Expression pushConjunctionsToKnn(Expression expression, Stack filteredChildren = expression.children() + .stream() + .map(PushDownConjunctionsToKnnPrefilters::removeKnn) + .filter(Objects::nonNull) + .toList(); + if (filteredChildren.equals(expression.children())) { + return expression; + } else if (filteredChildren.isEmpty()) { + return null; + } else if (expression instanceof BinaryLogic && filteredChildren.size() == 1) { + // Simplify an AND / OR expression to a single child + return filteredChildren.getFirst(); + } else { + return expression.replaceChildrenSameSize(filteredChildren); + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 4bb05c7965d3c..a604e1d26d313 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1944,7 +1944,6 @@ public void testPushDownConjunctionsToKnnPrefilter() { assertEquals(expectedQuery.toString(), queryExec.query().toString()); } - public void testPushDownNegatedConjunctionsToKnnPrefilter() { assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()); @@ -2050,7 +2049,8 @@ public void testPushDownComplexNegationsToKnnPrefilter() { String query = """ from test - | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) + | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) + and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10))) """; var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json")); @@ -2060,11 +2060,17 @@ public void testPushDownComplexNegationsToKnnPrefilter() { var fieldExtract = as(project.child(), FieldExtractExec.class); var queryExec = as(fieldExtract.child(), EsQueryExec.class); + QueryBuilder notKeywordQuery = wrapWithSingleQuery( + query, + unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))), + "keyword", + new Source(3, 12, "keyword == \"test\"") + ); QueryBuilder notKeywordFilter = wrapWithSingleQuery( query, unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))), "keyword", - new Source(2, 74, "keyword == \"test\"") + new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6], 10))") ); QueryBuilder notIntegerGt10 = wrapWithSingleQuery( @@ -2075,21 +2081,13 @@ public void testPushDownComplexNegationsToKnnPrefilter() { ); KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); - KnnVectorQueryBuilder firstKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); - KnnVectorQueryBuilder secondKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); - firstKnn.addFilterQuery(boolQuery() - .must(notKeywordFilter) - .must(unscore(boolQuery().mustNot(secondKnnFilter)))); - - secondKnn.addFilterQuery(boolQuery() - .should(firstKnnFilter) - .should(notIntegerGt10)); + firstKnn.addFilterQuery(notKeywordFilter); + secondKnn.addFilterQuery(notIntegerGt10); // Build the main boolean query structure - BoolQueryBuilder expectedQuery = boolQuery() - .must(notKeywordFilter) // NOT (keyword == "test") + BoolQueryBuilder expectedQuery = boolQuery().must(notKeywordQuery) // NOT (keyword == "test") .must(unscore(boolQuery().mustNot(secondKnn))) .must(boolQuery().should(firstKnn).should(notIntegerGt10)); @@ -2112,14 +2110,6 @@ public void testMultipleKnnQueriesInPrefilters() { var queryExec = as(field.child(), EsQueryExec.class); KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null); - KnnVectorQueryBuilder firstKnnQueryAsFilter = new KnnVectorQueryBuilder( - "dense_vector", - new float[] { 0, 1, 2 }, - 10, - null, - null, - null - ); // Integer range query (right side of first OR) QueryBuilder integerRangeQuery = wrapWithSingleQuery( query, @@ -2130,14 +2120,6 @@ public void testMultipleKnnQueriesInPrefilters() { // Second KNN query (right side of second OR) KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null); - KnnVectorQueryBuilder secondKnnQueryAsFilter = new KnnVectorQueryBuilder( - "dense_vector", - new float[] { 4, 5, 6 }, - 10, - null, - null, - null - ); // Keyword term query (left side of second OR) QueryBuilder keywordQuery = wrapWithSingleQuery( @@ -2149,13 +2131,10 @@ public void testMultipleKnnQueriesInPrefilters() { // First OR (knn1 OR integer > 10) var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery); - var firstOrAsFilter = boolQuery().should(firstKnnQueryAsFilter).should(integerRangeQuery); // Second OR (keyword == "test" OR knn2) var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery); - var secondOrAsFilter = boolQuery().should(keywordQuery).should(secondKnnQueryAsFilter); - // Add prefilters to the knn queries. knn queries in prefilters don't have prefilters so we use copies of the queries - firstKnnQuery.addFilterQuery(secondOrAsFilter); - secondKnnQuery.addFilterQuery(firstOrAsFilter); + firstKnnQuery.addFilterQuery(keywordQuery); + secondKnnQuery.addFilterQuery(integerRangeQuery); // Top-level AND combining both ORs var expectedQuery = boolQuery().must(firstOr).must(secondOr); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 320bd61a719d7..f3c9569502d25 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -8019,13 +8019,13 @@ public void testMultipleKnnQueriesInPrefilters() { // First KNN should have the second OR as its filter List firstKnnFilters = firstKnn.filterExpressions(); assertThat(firstKnnFilters.size(), equalTo(1)); - var secondOrWithoutFilters = secondOr.replaceChildren(List.of(secondOr.left(), secondKnn.withFilters(List.of()))); - assertTrue(firstKnnFilters.contains(secondOrWithoutFilters)); + var secondOrWithoutKnn = secondOr.replaceChildren(List.of(secondOr.left(), Literal.TRUE)); + assertTrue(firstKnnFilters.contains(secondOrWithoutKnn)); // Second KNN should have the first OR as its filter List secondKnnFilters = secondKnn.filterExpressions(); assertThat(secondKnnFilters.size(), equalTo(1)); - var firstOrWithoutFilters = firstOr.replaceChildren(List.of(firstKnn.withFilters(List.of()), firstOr.right())); + var firstOrWithoutFilters = firstOr.replaceChildren(List.of(Literal.TRUE, firstOr.right())); assertTrue(secondKnnFilters.contains(firstOrWithoutFilters)); } } From dd29fecad3d0258083794d176eef614b7b8f3da9 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 14 Jul 2025 15:20:13 +0200 Subject: [PATCH 20/22] Fix test --- .../xpack/esql/optimizer/LogicalPlanOptimizerTests.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index f3c9569502d25..e301c1610bd7b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -8013,19 +8013,17 @@ public void testMultipleKnnQueriesInPrefilters() { // Second OR (keyword == "test" OR knn2) var secondOr = as(and.right(), Or.class); - var keywordEq = as(secondOr.left(), Equals.class); + as(secondOr.left(), Equals.class); var secondKnn = as(secondOr.right(), Knn.class); // First KNN should have the second OR as its filter List firstKnnFilters = firstKnn.filterExpressions(); assertThat(firstKnnFilters.size(), equalTo(1)); - var secondOrWithoutKnn = secondOr.replaceChildren(List.of(secondOr.left(), Literal.TRUE)); - assertTrue(firstKnnFilters.contains(secondOrWithoutKnn)); + assertTrue(firstKnnFilters.contains(secondOr.left())); // Second KNN should have the first OR as its filter List secondKnnFilters = secondKnn.filterExpressions(); assertThat(secondKnnFilters.size(), equalTo(1)); - var firstOrWithoutFilters = firstOr.replaceChildren(List.of(Literal.TRUE, firstOr.right())); - assertTrue(secondKnnFilters.contains(firstOrWithoutFilters)); + assertTrue(secondKnnFilters.contains(firstOr.right())); } } From b03f2a340709b9aa49256c4889a15d904e001c8c Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Mon, 14 Jul 2025 16:21:55 +0200 Subject: [PATCH 21/22] Fix test --- .../java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index d7ac65d5a31ab..9ae1c980337f1 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -186,7 +186,7 @@ public void setup() throws IOException { var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - numDocs = randomIntBetween(10, 20); + numDocs = randomIntBetween(15, 25); numDims = randomIntBetween(3, 10); IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; float value = 0.0f; From 8a550a1c4af595f89c34bda5890e51c987dea95b Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Wed, 16 Jul 2025 08:36:34 +0200 Subject: [PATCH 22/22] Fix merge --- .../xpack/esql/expression/function/vector/VectorWritables.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java index f1bf291b7715e..a4274bf28de4b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java @@ -27,7 +27,7 @@ private VectorWritables() { public static List getNamedWritables() { List entries = new ArrayList<>(); - if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) { entries.add(Knn.ENTRY); } if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {