Skip to content

Commit ac7c3bc

Browse files
committed
knn functions won't appear as prefilters for other knn functions
1 parent 4c6bb4d commit ac7c3bc

File tree

3 files changed

+56
-44
lines changed

3 files changed

+56
-44
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@
1010
import org.elasticsearch.xpack.esql.core.expression.Expression;
1111
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
1212
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
13+
import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
1314
import org.elasticsearch.xpack.esql.plan.logical.Filter;
1415
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
1516

1617
import java.util.ArrayList;
1718
import java.util.List;
19+
import java.util.Objects;
1820
import java.util.Stack;
1921

2022
/**
2123
* Rewrites an expression tree to push down conjunctions in the prefilter of {@link Knn} functions.
24+
* Knn functions are never pushed down as prefilters of other knn functions, to avoid circular dependencies.
2225
* Given an expression tree like {@code (A OR B) AND (C AND knn())} this rule will rewrite it to
2326
* {@code (A OR B) AND (C AND knn(filterExpressions = [(A OR B), C]))}
2427
*/
@@ -56,12 +59,12 @@ private static Expression pushConjunctionsToKnn(Expression expression, Stack<Exp
5659
}
5760
yield and.replaceChildrenSameSize(List.of(newLeft, newRight));
5861
case Knn knn:
59-
// Create a copy of the filters, and check the number of existing filters. We don't check for equality
60-
// as having two knn functions on opposite sides of the And is valid, but could lead to infinite number
61-
// of changes as the knn functions would receive the other knn functions as filters and would be constantly
62-
// updating
63-
List<Expression> newFilters = new ArrayList<>(filters);
64-
if (newFilters.size() == knn.filterExpressions().size()) {
62+
// We don't want knn expressions to have other knn expressions as a prefilter to avoid circular dependencies
63+
List<Expression> newFilters = filters.stream()
64+
.map(PushDownConjunctionsToKnnPrefilters::removeKnn)
65+
.filter(Objects::nonNull)
66+
.toList();
67+
if (newFilters.equals(knn.filterExpressions())) {
6568
yield knn;
6669
}
6770
yield knn.withFilters(newFilters);
@@ -94,4 +97,34 @@ private static Expression pushConjunctionsToKnn(Expression expression, Stack<Exp
9497

9598
return result;
9699
}
100+
101+
/**
102+
* Removes knn functions from the expression tree
103+
* @param expression expression to process
104+
* @return expression without knn functions, or null if the expression is a knn function
105+
*/
106+
private static Expression removeKnn(Expression expression) {
107+
if (expression.children().isEmpty()) {
108+
return expression;
109+
}
110+
if (expression instanceof Knn) {
111+
return null;
112+
}
113+
114+
List<Expression> filteredChildren = expression.children()
115+
.stream()
116+
.map(PushDownConjunctionsToKnnPrefilters::removeKnn)
117+
.filter(Objects::nonNull)
118+
.toList();
119+
if (filteredChildren.equals(expression.children())) {
120+
return expression;
121+
} else if (filteredChildren.isEmpty()) {
122+
return null;
123+
} else if (expression instanceof BinaryLogic && filteredChildren.size() == 1) {
124+
// Simplify an AND / OR expression to a single child
125+
return filteredChildren.getFirst();
126+
} else {
127+
return expression.replaceChildrenSameSize(filteredChildren);
128+
}
129+
}
97130
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,7 +1944,6 @@ public void testPushDownConjunctionsToKnnPrefilter() {
19441944
assertEquals(expectedQuery.toString(), queryExec.query().toString());
19451945
}
19461946

1947-
19481947
public void testPushDownNegatedConjunctionsToKnnPrefilter() {
19491948
assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
19501949

@@ -2050,7 +2049,8 @@ public void testPushDownComplexNegationsToKnnPrefilter() {
20502049

20512050
String query = """
20522051
from test
2053-
| where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10) and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
2052+
| where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10)
2053+
and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
20542054
""";
20552055
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
20562056

@@ -2060,11 +2060,17 @@ public void testPushDownComplexNegationsToKnnPrefilter() {
20602060
var fieldExtract = as(project.child(), FieldExtractExec.class);
20612061
var queryExec = as(fieldExtract.child(), EsQueryExec.class);
20622062

2063+
QueryBuilder notKeywordQuery = wrapWithSingleQuery(
2064+
query,
2065+
unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
2066+
"keyword",
2067+
new Source(3, 12, "keyword == \"test\"")
2068+
);
20632069
QueryBuilder notKeywordFilter = wrapWithSingleQuery(
20642070
query,
20652071
unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
20662072
"keyword",
2067-
new Source(2, 74, "keyword == \"test\"")
2073+
new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6], 10))")
20682074
);
20692075

20702076
QueryBuilder notIntegerGt10 = wrapWithSingleQuery(
@@ -2075,21 +2081,13 @@ public void testPushDownComplexNegationsToKnnPrefilter() {
20752081
);
20762082

20772083
KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
2078-
KnnVectorQueryBuilder firstKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
20792084
KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
2080-
KnnVectorQueryBuilder secondKnnFilter = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
20812085

2082-
firstKnn.addFilterQuery(boolQuery()
2083-
.must(notKeywordFilter)
2084-
.must(unscore(boolQuery().mustNot(secondKnnFilter))));
2085-
2086-
secondKnn.addFilterQuery(boolQuery()
2087-
.should(firstKnnFilter)
2088-
.should(notIntegerGt10));
2086+
firstKnn.addFilterQuery(notKeywordFilter);
2087+
secondKnn.addFilterQuery(notIntegerGt10);
20892088

20902089
// Build the main boolean query structure
2091-
BoolQueryBuilder expectedQuery = boolQuery()
2092-
.must(notKeywordFilter) // NOT (keyword == "test")
2090+
BoolQueryBuilder expectedQuery = boolQuery().must(notKeywordQuery) // NOT (keyword == "test")
20932091
.must(unscore(boolQuery().mustNot(secondKnn)))
20942092
.must(boolQuery().should(firstKnn).should(notIntegerGt10));
20952093

@@ -2112,14 +2110,6 @@ public void testMultipleKnnQueriesInPrefilters() {
21122110
var queryExec = as(field.child(), EsQueryExec.class);
21132111

21142112
KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
2115-
KnnVectorQueryBuilder firstKnnQueryAsFilter = new KnnVectorQueryBuilder(
2116-
"dense_vector",
2117-
new float[] { 0, 1, 2 },
2118-
10,
2119-
null,
2120-
null,
2121-
null
2122-
);
21232113
// Integer range query (right side of first OR)
21242114
QueryBuilder integerRangeQuery = wrapWithSingleQuery(
21252115
query,
@@ -2130,14 +2120,6 @@ public void testMultipleKnnQueriesInPrefilters() {
21302120

21312121
// Second KNN query (right side of second OR)
21322122
KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
2133-
KnnVectorQueryBuilder secondKnnQueryAsFilter = new KnnVectorQueryBuilder(
2134-
"dense_vector",
2135-
new float[] { 4, 5, 6 },
2136-
10,
2137-
null,
2138-
null,
2139-
null
2140-
);
21412123

21422124
// Keyword term query (left side of second OR)
21432125
QueryBuilder keywordQuery = wrapWithSingleQuery(
@@ -2149,13 +2131,10 @@ public void testMultipleKnnQueriesInPrefilters() {
21492131

21502132
// First OR (knn1 OR integer > 10)
21512133
var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery);
2152-
var firstOrAsFilter = boolQuery().should(firstKnnQueryAsFilter).should(integerRangeQuery);
21532134
// Second OR (keyword == "test" OR knn2)
21542135
var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery);
2155-
var secondOrAsFilter = boolQuery().should(keywordQuery).should(secondKnnQueryAsFilter);
2156-
// Add prefilters to the knn queries. knn queries in prefilters don't have prefilters so we use copies of the queries
2157-
firstKnnQuery.addFilterQuery(secondOrAsFilter);
2158-
secondKnnQuery.addFilterQuery(firstOrAsFilter);
2136+
firstKnnQuery.addFilterQuery(keywordQuery);
2137+
secondKnnQuery.addFilterQuery(integerRangeQuery);
21592138

21602139
// Top-level AND combining both ORs
21612140
var expectedQuery = boolQuery().must(firstOr).must(secondOr);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8019,13 +8019,13 @@ public void testMultipleKnnQueriesInPrefilters() {
80198019
// First KNN should have the second OR, with its knn replaced by TRUE, as its filter
80208020
List<Expression> firstKnnFilters = firstKnn.filterExpressions();
80218021
assertThat(firstKnnFilters.size(), equalTo(1));
8022-
var secondOrWithoutFilters = secondOr.replaceChildren(List.of(secondOr.left(), secondKnn.withFilters(List.of())));
8023-
assertTrue(firstKnnFilters.contains(secondOrWithoutFilters));
8022+
var secondOrWithoutKnn = secondOr.replaceChildren(List.of(secondOr.left(), Literal.TRUE));
8023+
assertTrue(firstKnnFilters.contains(secondOrWithoutKnn));
80248024

80258025
// Second KNN should have the first OR, with its knn replaced by TRUE, as its filter
80268026
List<Expression> secondKnnFilters = secondKnn.filterExpressions();
80278027
assertThat(secondKnnFilters.size(), equalTo(1));
8028-
var firstOrWithoutFilters = firstOr.replaceChildren(List.of(firstKnn.withFilters(List.of()), firstOr.right()));
8028+
var firstOrWithoutFilters = firstOr.replaceChildren(List.of(Literal.TRUE, firstOr.right()));
80298029
assertTrue(secondKnnFilters.contains(firstOrWithoutFilters));
80308030
}
80318031
}

0 commit comments

Comments
 (0)