Skip to content

Commit ef12ee5

Browse files
committed
Avoid infinite loop with multiple knn expressions, add tests
1 parent a73973f commit ef12ee5

File tree

4 files changed

+220
-39
lines changed

4 files changed

+220
-39
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
287287
public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
288288
Translatable translatable = super.translatable(pushdownPredicates);
289289
// We need to check whether filter expressions are translatable as well
290-
for(Expression filterExpression : filterExpressions()) {
290+
for (Expression filterExpression : filterExpressions()) {
291291
translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates));
292292
}
293293

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
import java.util.List;
1818
import java.util.Stack;
1919

20+
/**
21+
* Rewrites an expression tree to push down conjunctions in the prefilter of {@link Knn} functions.
22+
* Given an expression tree like {@code (A OR B) AND (C AND knn())} this rule will rewrite it to
23+
* {@code (A OR B) AND (C AND knn(filterExpressions = [(A OR B), C]))}
24+
*/
2025
public class PushDownConjunctionsToKnnPrefilters extends OptimizerRules.OptimizerRule<Filter> {
2126

2227
@Override
@@ -28,25 +33,43 @@ protected LogicalPlan rule(Filter filter) {
2833
return condition.equals(newCondition) ? filter : filter.with(newCondition);
2934
}
3035

31-
private static Expression pushConjunctionsToKnn(Expression expression, List<Expression> filters, Expression addedFilter) {
36+
/**
37+
* Updates knn function prefilters. This method processes conjunctions so knn functions on one side of the conjunction receive
38+
* the other side of the conjunction as a prefilter.
39+
*
40+
* @param expression expression to process recursively
41+
* @param filters current filters to apply to the expression. They contain expressions on the other side of the traversed conjunctions
42+
* @param addedFilter a new filter to add to the list of filters for the processing
43+
* @return the updated expression, or the original expression if it doesn't need to be updated
44+
*/
45+
private static Expression pushConjunctionsToKnn(Expression expression, Stack<Expression> filters, Expression addedFilter) {
3246
if (addedFilter != null) {
33-
filters.add(addedFilter);
47+
filters.push(addedFilter);
3448
}
35-
Expression result = switch(expression) {
49+
Expression result = switch (expression) {
3650
case And and:
51+
// Traverse both sides of the And, using the other side as the added filter
3752
Expression newLeft = pushConjunctionsToKnn(and.left(), filters, and.right());
3853
Expression newRight = pushConjunctionsToKnn(and.right(), filters, and.left());
3954
if (newLeft.equals(and.left()) && newRight.equals(and.right())) {
4055
yield and;
4156
}
4257
yield and.replaceChildrenSameSize(List.of(newLeft, newRight));
4358
case Knn knn:
44-
yield knn.withFilters(List.copyOf(filters));
59+
// Create a copy of the filters, and check the number of existing filters. We don't check for equality
60+
// as having two knn functions on opposite sides of the And is valid, but could lead to an infinite number
61+
// of changes as the knn functions would receive the other knn functions as filters and would be constantly
62+
// updating
63+
List<Expression> newFilters = new ArrayList<>(filters);
64+
if (newFilters.size() == knn.filterExpressions().size()) {
65+
yield knn;
66+
}
67+
yield knn.withFilters(newFilters);
4568
default:
4669
List<Expression> children = expression.children();
4770
boolean childrenChanged = false;
4871

49-
// TODO This copies transformChildren
72+
// This copies the transformChildren algorithm to avoid unnecessary changes
5073
List<Expression> transformedChildren = null;
5174

5275
for (int i = 0, s = children.size(); i < s; i++) {
@@ -66,7 +89,7 @@ private static Expression pushConjunctionsToKnn(Expression expression, List<Expr
6689
};
6790

6891
if (addedFilter != null) {
69-
filters.removeLast();
92+
filters.pop();
7093
}
7194

7295
return result;

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 145 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,10 +1628,6 @@ private void testFullTextFunctionWithPushableConjunction(FullTextFunctionTestCas
16281628
assertEquals(expected.toString(), esQuery.query().toString());
16291629
}
16301630

1631-
public void testKnn() {
1632-
testFullTextFunctionWithNonPushableDisjunction(new KnnFunctionTestCase());
1633-
}
1634-
16351631
private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTestCase testCase) {
16361632
String query = String.format(Locale.ROOT, """
16371633
from test
@@ -1652,32 +1648,6 @@ private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTest
16521648
assertThat(fieldExtract.child(), instanceOf(EsQueryExec.class));
16531649
}
16541650

1655-
public void testKnnPrefilters() {
1656-
String query = """
1657-
from test
1658-
| where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1659-
""";
1660-
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
1661-
1662-
var limit = as(plan, LimitExec.class);
1663-
var exchange = as(limit.child(), ExchangeExec.class);
1664-
var project = as(exchange.child(), ProjectExec.class);
1665-
var field = as(project.child(), FieldExtractExec.class);
1666-
var queryExec = as(field.child(), EsQueryExec.class);
1667-
QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
1668-
query,
1669-
unscore(rangeQuery("integer").gt(10)),
1670-
"integer",
1671-
new Source(2, 45, "integer > 10")
1672-
);
1673-
KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null)
1674-
.addFilterQuery(expectedFilterQueryBuilder);
1675-
var expectedQuery = boolQuery()
1676-
.must(expectedKnnQueryBuilder)
1677-
.must(expectedFilterQueryBuilder);
1678-
assertEquals(expectedQuery.toString(), queryExec.query().toString());
1679-
}
1680-
16811651
private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCase testCase) {
16821652
String query = String.format(Locale.ROOT, """
16831653
from test
@@ -1697,12 +1667,11 @@ private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCas
16971667
}
16981668

16991669
private FullTextFunctionTestCase randomFullTextFunctionTestCase() {
1700-
return switch (randomIntBetween(0, 4)) {
1670+
return switch (randomIntBetween(0, 3)) {
17011671
case 0 -> new MatchFunctionTestCase();
17021672
case 1 -> new MatchOperatorTestCase();
17031673
case 2 -> new KqlFunctionTestCase();
17041674
case 3 -> new QueryStringFunctionTestCase();
1705-
case 4 -> new KnnFunctionTestCase();
17061675
default -> throw new IllegalStateException("Unexpected value");
17071676
};
17081677
}
@@ -1861,6 +1830,150 @@ public void testFullTextFunctionWithStatsBy(FullTextFunctionTestCase testCase) {
18611830
aggExec.forEachDown(EsQueryExec.class, esQueryExec -> { assertNull(esQueryExec.query()); });
18621831
}
18631832

1833+
public void testKnnPrefilters() {
1834+
assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
1835+
1836+
String query = """
1837+
from test
1838+
| where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1839+
""";
1840+
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
1841+
1842+
var limit = as(plan, LimitExec.class);
1843+
var exchange = as(limit.child(), ExchangeExec.class);
1844+
var project = as(exchange.child(), ProjectExec.class);
1845+
var field = as(project.child(), FieldExtractExec.class);
1846+
var queryExec = as(field.child(), EsQueryExec.class);
1847+
QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
1848+
query,
1849+
unscore(rangeQuery("integer").gt(10)),
1850+
"integer",
1851+
new Source(2, 45, "integer > 10")
1852+
);
1853+
KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
1854+
"dense_vector",
1855+
new float[] { 0, 1, 2 },
1856+
10,
1857+
null,
1858+
null,
1859+
null
1860+
).addFilterQuery(expectedFilterQueryBuilder);
1861+
var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder);
1862+
assertEquals(expectedQuery.toString(), queryExec.query().toString());
1863+
}
1864+
1865+
public void testPushDownConjunctionsToKnnPrefilter() {
1866+
assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
1867+
1868+
String query = """
1869+
from test
1870+
| where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1871+
""";
1872+
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
1873+
1874+
var limit = as(plan, LimitExec.class);
1875+
var exchange = as(limit.child(), ExchangeExec.class);
1876+
var project = as(exchange.child(), ProjectExec.class);
1877+
var field = as(project.child(), FieldExtractExec.class);
1878+
var queryExec = as(field.child(), EsQueryExec.class);
1879+
1880+
// The filter condition should be pushed down to both the KNN query and the main query
1881+
QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
1882+
query,
1883+
unscore(rangeQuery("integer").gt(10)),
1884+
"integer",
1885+
new Source(2, 45, "integer > 10")
1886+
);
1887+
1888+
KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
1889+
"dense_vector",
1890+
new float[] { 0, 1, 2 },
1891+
10,
1892+
null,
1893+
null,
1894+
null
1895+
).addFilterQuery(expectedFilterQueryBuilder);
1896+
1897+
var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder);
1898+
1899+
assertEquals(expectedQuery.toString(), queryExec.query().toString());
1900+
}
1901+
1902+
public void testNotPushDownDisjunctionsToKnnPrefilter() {
1903+
assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
1904+
1905+
String query = """
1906+
from test
1907+
| where knn(dense_vector, [0, 1, 2], 10) or integer > 10
1908+
""";
1909+
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
1910+
1911+
var limit = as(plan, LimitExec.class);
1912+
var exchange = as(limit.child(), ExchangeExec.class);
1913+
var project = as(exchange.child(), ProjectExec.class);
1914+
var field = as(project.child(), FieldExtractExec.class);
1915+
var queryExec = as(field.child(), EsQueryExec.class);
1916+
1917+
// The disjunction should not be pushed down to the KNN query
1918+
KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
1919+
QueryBuilder rangeQueryBuilder = wrapWithSingleQuery(
1920+
query,
1921+
unscore(rangeQuery("integer").gt(10)),
1922+
"integer",
1923+
new Source(2, 44, "integer > 10")
1924+
);
1925+
1926+
var expectedQuery = boolQuery().should(knnQueryBuilder).should(rangeQueryBuilder);
1927+
1928+
assertEquals(expectedQuery.toString(), queryExec.query().toString());
1929+
}
1930+
1931+
public void testMultipleKnnQueriesInPrefilters() {
1932+
assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
1933+
1934+
String query = """
1935+
from test
1936+
| where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
1937+
""";
1938+
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
1939+
1940+
var limit = as(plan, LimitExec.class);
1941+
var exchange = as(limit.child(), ExchangeExec.class);
1942+
var project = as(exchange.child(), ProjectExec.class);
1943+
var field = as(project.child(), FieldExtractExec.class);
1944+
var queryExec = as(field.child(), EsQueryExec.class);
1945+
1946+
KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
1947+
// Integer range query (right side of first OR)
1948+
QueryBuilder integerRangeQuery = wrapWithSingleQuery(
1949+
query,
1950+
unscore(rangeQuery("integer").gt(10)),
1951+
"integer",
1952+
new Source(2, 45, "integer > 10")
1953+
);
1954+
1955+
// Second KNN query (right side of second OR)
1956+
KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
1957+
1958+
// Keyword term query (left side of second OR)
1959+
QueryBuilder keywordQuery = wrapWithSingleQuery(
1960+
query,
1961+
unscore(termQuery("keyword", "test")),
1962+
"keyword",
1963+
new Source(2, 87, "keyword == \"test\"")
1964+
);
1965+
1966+
// First OR (knn1 OR integer > 10)
1967+
var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery);
1968+
// Second OR (keyword == "test" OR knn2)
1969+
var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery.addFilterQuery(firstOr));
1970+
firstKnnQuery.addFilterQuery(secondOr);
1971+
1972+
// Top-level AND combining both ORs
1973+
var expectedQuery = boolQuery().must(firstOr).must(secondOr);
1974+
assertEquals(expectedQuery.toString(), queryExec.query().toString());
1975+
}
1976+
18641977
public void testParallelizeTimeSeriesPlan() {
18651978
assumeTrue("requires snapshot builds", Build.current().isSnapshot());
18661979
var query = "TS k8s | STATS max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1h)";

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7954,4 +7954,49 @@ public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() {
79547954
var rightAndPrefilter = as(knn.filterExpressions().get(0), GreaterThan.class);
79557955
assertThat(leftAnd.right(), equalTo(rightAndPrefilter));
79567956
}
7957+
7958+
public void testMultipleKnnQueriesInPrefilters() {
7959+
assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
7960+
7961+
/*
7962+
and
7963+
or
7964+
knn(dense_vector, [0, 1, 2], 10)
7965+
integer > 10
7966+
or
7967+
keyword == "test"
7968+
knn(dense_vector, [4, 5, 6], 10)
7969+
*/
7970+
var query = """
7971+
from test
7972+
| where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
7973+
""";
7974+
var optimized = planTypes(query);
7975+
7976+
var limit = as(optimized, Limit.class);
7977+
var filter = as(limit.child(), Filter.class);
7978+
var and = as(filter.condition(), And.class);
7979+
7980+
// First OR (knn1 OR integer > 10)
7981+
var firstOr = as(and.left(), Or.class);
7982+
var firstKnn = as(firstOr.left(), Knn.class);
7983+
var integerGt = as(firstOr.right(), GreaterThan.class);
7984+
7985+
// Second OR (keyword == "test" OR knn2)
7986+
var secondOr = as(and.right(), Or.class);
7987+
var keywordEq = as(secondOr.left(), Equals.class);
7988+
var secondKnn = as(secondOr.right(), Knn.class);
7989+
7990+
// First KNN should have the second OR as its filter
7991+
List<Expression> firstKnnFilters = firstKnn.filterExpressions();
7992+
assertThat(firstKnnFilters.size(), equalTo(1));
7993+
var secondOrWithoutFilters = secondOr.replaceChildren(List.of(secondOr.left(), secondKnn.withFilters(List.of())));
7994+
assertTrue(firstKnnFilters.contains(secondOrWithoutFilters));
7995+
7996+
// Second KNN should have the first OR as its filter
7997+
List<Expression> secondKnnFilters = secondKnn.filterExpressions();
7998+
assertThat(secondKnnFilters.size(), equalTo(1));
7999+
var firstOrWithoutFilters = firstOr.replaceChildren(List.of(firstKnn.withFilters(List.of()), firstOr.right()));
8000+
assertTrue(secondKnnFilters.contains(firstOrWithoutFilters));
8001+
}
79578002
}

0 commit comments

Comments
 (0)