Skip to content

Commit a38f6b9

Browse files
fix predicate push down
1 parent fb221c9 commit a38f6b9

File tree

5 files changed

+182
-46
lines changed

5 files changed

+182
-46
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/resources/subquery.csv-spec

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,49 @@ null | null | null | 1 | 172.21.2.162
9292
null | null | null | 4 | 172.21.3.15
9393
;
9494

95+
subqueryInFromWithStatsInSubqueryConjunctiveFilterInMainQuery
96+
required_capability: fork_v9
97+
required_capability: subquery_in_from_command
98+
99+
FROM employees, (FROM sample_data
100+
| STATS cnt = count(*) by client_ip )
101+
, (FROM sample_data_str
102+
| STATS cnt = count(*) by client_ip )
103+
metadata _index
104+
| EVAL client_ip = client_ip::ip
105+
| WHERE client_ip == "172.21.3.15" AND cnt >0
106+
| SORT emp_no, client_ip
107+
| KEEP _index, emp_no, languages, cnt, client_ip
108+
;
109+
110+
_index:keyword | emp_no:integer | languages:integer | cnt:long | client_ip:ip
111+
null | null | null | 4 | 172.21.3.15
112+
null | null | null | 4 | 172.21.3.15
113+
;
114+
115+
subqueryInFromWithStatsInSubqueryDisjunctiveFilterInMainQuery
116+
required_capability: fork_v9
117+
required_capability: subquery_in_from_command
118+
119+
FROM employees, (FROM sample_data
120+
| STATS cnt = count(*) by client_ip )
121+
, (FROM sample_data_str
122+
| STATS cnt = count(*) by client_ip )
123+
metadata _index
124+
| EVAL client_ip = client_ip::ip
125+
| WHERE ( emp_no >= 10091 AND emp_no < 10094) OR client_ip == "172.21.3.15"
126+
| SORT emp_no, client_ip
127+
| KEEP _index, emp_no, languages, cnt, client_ip
128+
;
129+
130+
_index:keyword | emp_no:integer | languages:integer | cnt:long | client_ip:ip
131+
employees | 10091 | 3 | null | null
132+
employees | 10092 | 1 | null | null
133+
employees | 10093 | 3 | null | null
134+
null | null | null | 4 | 172.21.3.15
135+
null | null | null | 4 | 172.21.3.15
136+
;
137+
95138
subqueryInFromWithStatsInMainQuery
96139
required_capability: fork_v9
97140
required_capability: subquery_in_from_command

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFilters.java

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -329,24 +329,31 @@ private static LogicalPlan maybePushDownPastUnionAll(Filter filter, UnionAll uni
329329
if (pushable.isEmpty()) {
330330
return filter; // nothing to push down
331331
}
332-
// Preserve the filter on top of UnionAll if not all pushable predicates can be pushed down into UnionAll children.
333-
// This happens when the pushable predicate contains ReferenceAttribute that cannot be mapped to children's output correctly.
334-
boolean preserveOriginalFilterOnTopOfUnionAll = false;
335-
// Push the filter down to each child of the UnionAll, the child of a UnionAll is always a project
336-
// followed by an optional eval and then limit added by fork and then the real child
332+
// Push the filter down to each child of the UnionAll. The child of a UnionAll is always
333+
// a project followed by an optional eval and then a limit, or a limit added by fork, and
334+
// then the real child. If the pattern is unrecognized, keep the filter and UnionAll plan unchanged.
337335
List<LogicalPlan> newChildren = new ArrayList<>();
338336
boolean changed = false;
339337
for (LogicalPlan child : unionAll.children()) {
340-
if (child instanceof Project project) {
341-
LogicalPlan newChild = maybePushDownFilterPastEvalAndLimitForUnionAllChild(pushable, project);
342-
if (newChild != child) {
343-
changed = true;
344-
} else {
345-
preserveOriginalFilterOnTopOfUnionAll = true;
346-
}
338+
LogicalPlan newChild = switch (child) {
339+
case Project project -> maybePushDownFilterPastProjectForUnionAllChild(pushable, project);
340+
case Limit limit -> maybePushDownFilterPastLimitForUnionAllChild(pushable, limit);
341+
default -> null; // TODO add a general push down for unexpected pattern
342+
};
343+
344+
if (newChild == null) {
345+
// Unexpected pattern, keep plan unchanged without pushing down filters
346+
return filter;
347+
}
348+
349+
if (newChild != child) {
350+
changed = true;
347351
newChildren.add(newChild);
348-
} else { // unexpected pattern, just add the child as is
349-
newChildren.add(child);
352+
} else {
353+
// Theoretically, all the pushable predicates should be pushed down into each child.
354+
// If any child is left unchanged, preserve the filter on top of UnionAll to make sure
355+
// correct results are returned and to avoid an infinite loop of the rule.
356+
return filter;
350357
}
351358
}
352359

@@ -355,34 +362,16 @@ private static LogicalPlan maybePushDownPastUnionAll(Filter filter, UnionAll uni
355362
}
356363

357364
LogicalPlan newUnionAll = unionAll.replaceChildren(newChildren);
358-
if (preserveOriginalFilterOnTopOfUnionAll) {
359-
// Preserve the filter on top of UnionAll as some pushable predicates cannot be pushed down
360-
// to make sure correct results are returned
361-
return filter.replaceChild(newUnionAll);
362-
}
363365
if (nonPushable.isEmpty()) {
364366
return newUnionAll;
365367
} else {
366368
return filter.with(newUnionAll, Predicates.combineAnd(nonPushable));
367369
}
368370
}
369371

370-
private static LogicalPlan maybePushDownFilterPastEvalAndLimitForUnionAllChild(List<Expression> pushable, Project project) {
371-
List<Expression> resolvedPushable = new ArrayList<>();
372-
// Make sure the pushable predicates can find their corresponding attributes in the child project
373-
for (Expression exp : pushable) {
374-
Expression replaced = resolveUnionAllOutputByName(exp, project.projections());
375-
if (replaced == null || replaced == exp) {
376-
// cannot find the attribute in the child project, cannot push down this filter
377-
return project;
378-
} else {
379-
resolvedPushable.add(replaced);
380-
}
381-
}
382-
if (resolvedPushable.size() != pushable.size()) {
383-
// Some pushable predicates cannot be resolved to the child project, cannot push down.
384-
// This should not happen, however we need to be cautious here, if the predicate is removed from the main query,
385-
// and it is not pushed down into the UnionAll child, the result will be incorrect.
372+
private static LogicalPlan maybePushDownFilterPastProjectForUnionAllChild(List<Expression> pushable, Project project) {
373+
List<Expression> resolvedPushable = resolvePushableAgainstOutput(pushable, project.projections());
374+
if (resolvedPushable == null) {
386375
return project;
387376
}
388377
LogicalPlan child = project.child();
@@ -395,6 +384,35 @@ private static LogicalPlan maybePushDownFilterPastEvalAndLimitForUnionAllChild(L
395384
return project;
396385
}
397386

387+
private static LogicalPlan maybePushDownFilterPastLimitForUnionAllChild(List<Expression> pushable, Limit limit) {
388+
List<Expression> resolvedPushable = resolvePushableAgainstOutput(pushable, limit.output());
389+
if (resolvedPushable == null) {
390+
return limit;
391+
}
392+
return pushDownFilterPastLimitForUnionAllChild(resolvedPushable, limit);
393+
}
394+
395+
/**
396+
* Attempts to resolve all pushable expressions against the given output attributes.
397+
* Returns a fully resolved list if successful, or null if any expression cannot be resolved.
398+
*/
399+
private static List<Expression> resolvePushableAgainstOutput(List<Expression> pushable, List<? extends NamedExpression> output) {
400+
List<Expression> resolved = new ArrayList<>();
401+
for (Expression exp : pushable) {
402+
Expression replaced = resolveUnionAllOutputByName(exp, output);
403+
// Make sure the pushable predicates can find their corresponding attributes in the output
404+
if (replaced == null || replaced == exp) {
405+
// cannot find the attribute in the given output, cannot push down this filter
406+
return null;
407+
}
408+
resolved.add(replaced);
409+
}
410+
// If some pushable predicates cannot be resolved against the output, the filter cannot be pushed down.
411+
// This should not happen; however, we need to be cautious here: if a predicate is removed from
412+
// the main query but is not pushed down into the UnionAll child, the result will be incorrect.
413+
return resolved.size() == pushable.size() ? resolved : null;
414+
}
415+
398416
private static LogicalPlan pushDownFilterPastEvalForUnionAllChild(List<Expression> pushable, Project project, Eval eval) {
399417
// if the pushable references any attribute created by the eval, we cannot push down
400418
AttributeMap<Expression> evalAliases = buildEvaAliases(eval);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnionAll.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ public BiConsumer<LogicalPlan, Failures> postOptimizationPlanVerification() {
106106
return UnionAll::checkNestedUnionAlls;
107107
}
108108

109+
/**
110+
* Defer the check for nested UnionAlls until after the logical planner, as some of the nested subqueries may be flattened
111+
* by the logical planner in the future.
112+
*/
109113
private static void checkNestedUnionAlls(LogicalPlan logicalPlan, Failures failures) {
110114
if (logicalPlan instanceof UnionAll unionAll) {
111115
unionAll.forEachDown(UnionAll.class, otherUnionAll -> {

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ public static void init() {
203203
EsqlTestUtils.TEST_CFG,
204204
new EsqlFunctionRegistry(),
205205
getIndexResult,
206-
emptyMap(),
206+
defaultLookupResolution(),
207207
enrichResolution,
208208
emptyInferenceResolution(),
209209
Map.of("test1", subqueryIndex1, "languages", subqueryIndex2)

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineFiltersTests.java

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1818
import org.elasticsearch.xpack.esql.core.expression.Literal;
1919
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
20+
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
2021
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
2122
import org.elasticsearch.xpack.esql.core.tree.Source;
2223
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -1231,18 +1232,38 @@ public void testPushDownFilterPastUnionAllAndCombineWithFilterInSubquery() {
12311232
public void testPushDownFilterOnReferenceAttributesPastUnionAllDebug() {
12321233
assumeTrue("Requires subquery in FROM command support", EsqlCapabilities.Cap.SUBQUERY_IN_FROM_COMMAND.isEnabled());
12331234
var plan = planSubquery("""
1234-
FROM test, (FROM test1 | where salary < 100000 | EVAL x = 1, y = emp_no, z = emp_no + 1)
1235+
FROM test
1236+
, (FROM test1
1237+
| where salary < 100000
1238+
| EVAL x = 1, y = emp_no, z = emp_no + 1)
1239+
, (FROM languages
1240+
| STATS cnt = COUNT(*) by language_code
1241+
| RENAME language_code AS z, cnt AS y
1242+
| EVAL x = 1)
1243+
, (FROM test1
1244+
| RENAME languages AS language_code
1245+
| LOOKUP JOIN languages_lookup ON language_code
1246+
| RENAME emp_no AS x, salary AS y, language_code AS z)
12351247
| WHERE x is not null and y is not null and z > 0
12361248
""");
12371249

12381250
Limit limit = as(plan, Limit.class);
12391251
UnionAll unionAll = as(limit.child(), UnionAll.class);
1240-
assertEquals(2, unionAll.children().size());
1252+
assertEquals(4, unionAll.children().size());
12411253

12421254
LocalRelation child1 = as(unionAll.children().get(0), LocalRelation.class);
12431255

12441256
EsqlProject child2 = as(unionAll.children().get(1), EsqlProject.class);
1245-
Limit childLimit = as(child2.child(), Limit.class);
1257+
Filter filter = as(child2.child(), Filter.class);
1258+
IsNotNull isNotNull = as(filter.condition(), IsNotNull.class);
1259+
ReferenceAttribute y = as(isNotNull.field(), ReferenceAttribute.class);
1260+
assertEquals("y", y.name());
1261+
Eval eval = as(filter.child(), Eval.class);
1262+
List<Alias> aliases = eval.fields();
1263+
assertEquals(2, aliases.size());
1264+
assertEquals("language_name", aliases.get(0).name());
1265+
assertEquals("y", aliases.get(1).name());
1266+
Limit childLimit = as(eval.child(), Limit.class);
12461267
Subquery subquery = as(childLimit.child(), Subquery.class);
12471268
Project project = as(subquery.child(), Project.class);
12481269
Filter childFilter = as(project.child(), Filter.class);
@@ -1251,8 +1272,8 @@ public void testPushDownFilterOnReferenceAttributesPastUnionAllDebug() {
12511272
assertEquals("z", z.name());
12521273
Literal right = as(greaterThan.right(), Literal.class);
12531274
assertEquals(0, right.value());
1254-
Eval eval = as(childFilter.child(), Eval.class);
1255-
List<Alias> aliases = eval.fields();
1275+
eval = as(childFilter.child(), Eval.class);
1276+
aliases = eval.fields();
12561277
assertEquals(2, aliases.size());
12571278
Alias aliasX = aliases.get(0);
12581279
assertEquals("x", aliasX.name());
@@ -1261,17 +1282,67 @@ public void testPushDownFilterOnReferenceAttributesPastUnionAllDebug() {
12611282
Alias aliasZ = aliases.get(1);
12621283
assertEquals("z", aliasZ.name());
12631284
childFilter = as(eval.child(), Filter.class);
1264-
And and = as(childFilter.condition(), And.class);
1265-
IsNotNull isNotNull = as(and.right(), IsNotNull.class);
1266-
FieldAttribute emp_no = as(isNotNull.field(), FieldAttribute.class);
1267-
assertEquals("emp_no", emp_no.name());
1268-
LessThan lessThan = as(and.left(), LessThan.class);
1285+
LessThan lessThan = as(childFilter.condition(), LessThan.class);
12691286
FieldAttribute salaryField = as(lessThan.left(), FieldAttribute.class);
12701287
assertEquals("salary", salaryField.name());
12711288
Literal literal = as(lessThan.right(), Literal.class);
12721289
assertEquals(100000, literal.value());
12731290
EsRelation relation = as(childFilter.child(), EsRelation.class);
12741291
assertEquals("test1", relation.indexPattern());
1292+
1293+
EsqlProject child3 = as(unionAll.children().get(2), EsqlProject.class);
1294+
eval = as(child3.child(), Eval.class);
1295+
limit = as(eval.child(), Limit.class);
1296+
subquery = as(limit.child(), Subquery.class);
1297+
eval = as(subquery.child(), Eval.class);
1298+
filter = as(eval.child(), Filter.class);
1299+
And and = as(filter.condition(), And.class);
1300+
isNotNull = as(and.left(), IsNotNull.class);
1301+
y = as(isNotNull.field(), ReferenceAttribute.class);
1302+
assertEquals("y", y.name());
1303+
greaterThan = as(and.right(), GreaterThan.class);
1304+
z = as(greaterThan.left(), ReferenceAttribute.class);
1305+
assertEquals("z", z.name());
1306+
right = as(greaterThan.right(), Literal.class);
1307+
assertEquals(0, right.value());
1308+
Aggregate aggregate = as(filter.child(), Aggregate.class);
1309+
List<Expression> groupings = aggregate.groupings();
1310+
assertEquals(1, groupings.size());
1311+
FieldAttribute language_code = as(groupings.get(0), FieldAttribute.class);
1312+
assertEquals("language_code", language_code.name());
1313+
List<? extends NamedExpression> aggregates = aggregate.aggregates();
1314+
assertEquals(2, aggregates.size());
1315+
assertEquals("y", aggregates.get(0).name());
1316+
assertEquals("z", aggregates.get(1).name());
1317+
relation = as(aggregate.child(), EsRelation.class);
1318+
assertEquals("languages", relation.indexPattern());
1319+
1320+
EsqlProject child4 = as(unionAll.children().get(3), EsqlProject.class);
1321+
filter = as(child4.child(), Filter.class);
1322+
isNotNull = as(filter.condition(), IsNotNull.class);
1323+
ReferenceAttribute x = as(isNotNull.field(), ReferenceAttribute.class);
1324+
assertEquals("y", x.name());
1325+
eval = as(filter.child(), Eval.class);
1326+
aliases = eval.fields();
1327+
assertEquals(4, aliases.size());
1328+
limit = as(eval.child(), Limit.class);
1329+
subquery = as(limit.child(), Subquery.class);
1330+
project = as(subquery.child(), Project.class);
1331+
Join lookupJoin = as(project.child(), Join.class);
1332+
Filter leftFilter = as(lookupJoin.left(), Filter.class);
1333+
and = as(leftFilter.condition(), And.class);
1334+
isNotNull = as(and.left(), IsNotNull.class);
1335+
FieldAttribute emp_no = as(isNotNull.field(), FieldAttribute.class);
1336+
assertEquals("emp_no", emp_no.name());
1337+
greaterThan = as(and.right(), GreaterThan.class);
1338+
language_code = as(greaterThan.left(), FieldAttribute.class);
1339+
assertEquals("languages", language_code.name());
1340+
right = as(greaterThan.right(), Literal.class);
1341+
assertEquals(0, right.value());
1342+
relation = as(leftFilter.child(), EsRelation.class);
1343+
assertEquals("test1", relation.indexPattern());
1344+
relation = as(lookupJoin.right(), EsRelation.class);
1345+
assertEquals("languages_lookup", relation.indexPattern());
12751346
}
12761347

12771348
public void testPushDownFilterOnReferenceAttributesAndFieldAttributesPastUnionAllDebug() {

0 commit comments

Comments
 (0)