Skip to content

Commit 62ad977

Browse files
committed
Update folding past LOOKUP JOIN
1 parent d3d1d24 commit 62ad977

File tree

3 files changed

+80
-49
lines changed

3 files changed

+80
-49
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneColumns.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.xpack.esql.plan.logical.Project;
2727
import org.elasticsearch.xpack.esql.plan.logical.Sample;
2828
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
29+
import org.elasticsearch.xpack.esql.plan.logical.join.StubRelation;
2930
import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier;
3031
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
3132
import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier;
@@ -119,9 +120,26 @@ private static LogicalPlan pruneColumns(Aggregate aggregate, AttributeSet.Builde
119120
p = aggregate.with(aggregate.groupings(), remaining);
120121
}
121122
} else {
122-
p = inlineJoin && aggregate.groupings().containsAll(remaining) // not expecting high groups cardinality
123-
? new Project(aggregate.source(), aggregate.child(), remaining)
124-
: aggregate.with(aggregate.groupings(), remaining);
123+
if (inlineJoin && aggregate.groupings().containsAll(remaining)) { // not expecting high groups cardinality
124+
// It's an INLINEJOIN and all remaining attributes are groupings, which are already part of the IJ output (from the
125+
// left-hand side).
126+
if (aggregate.child() instanceof StubRelation stub) {
127+
var message = "Aggregate groups references ["
128+
+ remaining
129+
+ "] not in child's (StubRelation) output: ["
130+
+ stub.outputSet()
131+
+ "]";
132+
assert stub.outputSet().containsAll(Expressions.asAttributes(remaining)) : message;
133+
134+
p = emptyLocalRelation(aggregate);
135+
} else {
136+
// There are no aggregates to compute, just output the groupings; these are already in the IJ output, so only
137+
// restrict the output to what remained.
138+
p = new Project(aggregate.source(), aggregate.child(), remaining);
139+
}
140+
} else { // not an INLINEJOIN or there are actually aggregates to compute
141+
p = aggregate.with(aggregate.groupings(), remaining);
142+
}
125143
}
126144
}
127145

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ public static void init() {
100100
EsqlTestUtils.TEST_CFG,
101101
new EsqlFunctionRegistry(),
102102
getIndexResultAirports,
103+
defaultLookupResolution(),
103104
enrichResolution,
104105
emptyInferenceResolution()
105106
),

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
import java.util.List;
140140
import java.util.Locale;
141141
import java.util.Map;
142+
import java.util.Set;
142143
import java.util.function.BiFunction;
143144
import java.util.function.Function;
144145

@@ -5751,13 +5752,12 @@ public void testInlinestatsNestedExpressionsInGroups() {
57515752
var limit = as(plan, Limit.class); // TODO: this needs to go
57525753
var inline = as(limit.child(), InlineJoin.class);
57535754
var eval = as(inline.left(), Eval.class);
5754-
assertThat(eval.fields(), hasSize(1));
5755-
assertThat(Expressions.attribute(eval.fields().get(0)).name(), is("emp_no % 2"));
5755+
assertThat(Expressions.names(eval.fields()), is(List.of("emp_no % 2")));
57565756
limit = asLimit(eval.child(), 1000, false);
57575757
var agg = as(inline.right(), Aggregate.class);
57585758
var groupings = agg.groupings();
5759-
var aggs = agg.aggregates();
57605759
var ref = as(groupings.get(0), ReferenceAttribute.class);
5760+
var aggs = agg.aggregates();
57615761
assertThat(aggs.get(1), is(ref));
57625762
assertThat(eval.fields().get(0).toAttribute(), is(ref));
57635763
assertThat(eval.fields().get(0).name(), is("emp_no % 2"));
@@ -5793,10 +5793,7 @@ public void testInlinestatsGetsPrunedEntirely() {
57935793
var plan = optimizedPlan(query);
57945794

57955795
var project = as(plan, Project.class);
5796-
var projections = project.projections();
5797-
assertThat(projections.size(), equalTo(2));
5798-
assertThat(projections.get(0).name(), equalTo("x"));
5799-
assertThat(projections.get(1).name(), equalTo("emp_no"));
5796+
assertThat(Expressions.names(project.projections()), is(List.of("x", "emp_no")));
58005797
var topN = as(project.child(), TopN.class);
58015798
assertThat(topN.order().size(), is(1));
58025799
var relation = as(topN.child(), EsRelation.class);
@@ -5820,10 +5817,7 @@ public void testDoubleInlinestatsGetsPrunedEntirely() {
58205817
var plan = optimizedPlan(query);
58215818

58225819
var project = as(plan, Project.class);
5823-
var projections = project.projections();
5824-
assertThat(projections.size(), equalTo(2));
5825-
assertThat(projections.get(0).name(), equalTo("x"));
5826-
assertThat(projections.get(1).name(), equalTo("emp_no"));
5820+
assertThat(Expressions.names(project.projections()), is(List.of("x", "emp_no")));
58275821
var topN = as(project.child(), TopN.class);
58285822
assertThat(topN.order().size(), is(1));
58295823
var relation = as(topN.child(), EsRelation.class);
@@ -5853,24 +5847,17 @@ public void testInlinestatsGetsPrunedPartially() {
58535847
var plan = optimizedPlan(query);
58545848

58555849
var project = as(plan, Project.class);
5856-
var projections = project.projections();
5857-
assertThat(projections.size(), equalTo(3));
5858-
assertThat(projections.get(0).name(), equalTo("x"));
5859-
assertThat(projections.get(1).name(), equalTo("a"));
5860-
assertThat(projections.get(2).name(), equalTo("emp_no"));
5850+
assertThat(Expressions.names(project.projections()), is(List.of("x", "a", "emp_no")));
58615851
var upperLimit = asLimit(project.child(), 1, true);
58625852
var inlineJoin = as(upperLimit.child(), InlineJoin.class);
5863-
assertThat(inlineJoin.config().matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("emp_no{f}")));
5853+
assertThat(Expressions.names(inlineJoin.config().matchFields()), is(List.of("emp_no")));
58645854
// Left
58655855
var limit = as(inlineJoin.left(), Limit.class); // TODO: this needs to go
58665856
assertThat(limit.limit().fold(FoldContext.small()), equalTo(1));
58675857
var relation = as(limit.child(), EsRelation.class);
58685858
// Right
58695859
var agg = as(inlineJoin.right(), Aggregate.class);
5870-
assertMap(
5871-
agg.output().stream().map(Object::toString).toList(),
5872-
matchesList().item(startsWith("a{r}")).item(startsWith("emp_no{f}"))
5873-
);
5860+
assertMap(Expressions.names(agg.output()), is(List.of("a", "emp_no")));
58745861
var stub = as(agg.child(), StubRelation.class);
58755862
}
58765863

@@ -5891,27 +5878,62 @@ public void testTrippleInlinestatsGetsPrunedPartially() {
58915878
var plan = optimizedPlan(query);
58925879

58935880
var project = as(plan, Project.class);
5894-
var projections = project.projections();
5895-
assertThat(projections.size(), equalTo(3));
5896-
assertThat(projections.get(0).name(), equalTo("x"));
5897-
assertThat(projections.get(1).name(), equalTo("a"));
5898-
assertThat(projections.get(2).name(), equalTo("emp_no"));
5881+
assertThat(Expressions.names(project.projections()), is(List.of("x", "a", "emp_no")));
58995882
var upperLimit = asLimit(project.child(), 1, true);
59005883
var inlineJoin = as(upperLimit.child(), InlineJoin.class);
5901-
assertThat(inlineJoin.config().matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("emp_no{f}")));
5884+
assertThat(Expressions.names(inlineJoin.config().matchFields()), is(List.of("emp_no")));
59025885
// Left
59035886
var limit = as(inlineJoin.left(), Limit.class);
59045887
assertThat(limit.limit().fold(FoldContext.small()), equalTo(1));
59055888
var relation = as(limit.child(), EsRelation.class);
59065889
// Right
59075890
var agg = as(inlineJoin.right(), Aggregate.class);
5908-
assertMap(
5909-
agg.output().stream().map(Object::toString).toList(),
5910-
matchesList().item(startsWith("a{r}")).item(startsWith("emp_no{f}"))
5911-
);
5891+
assertMap(Expressions.names(agg.output()), is(List.of("a", "emp_no")));
59125892
var stub = as(agg.child(), StubRelation.class);
59135893
}
59145894

5895+
/*
5896+
* Project[[abbrev{f}#19, scalerank{f}#21 AS backup_scalerank#4, language_name{f}#28 AS scalerank#11]]
5897+
* \_TopN[[Order[abbrev{f}#19,DESC,FIRST]],5[INTEGER]]
5898+
* \_Join[LEFT,[scalerank{f}#21],[scalerank{f}#21],[language_code{f}#27]]
5899+
* |_EsRelation[airports][abbrev{f}#19, city{f}#25, city_location{f}#26, coun..]
5900+
* \_EsRelation[languages_lookup][LOOKUP][language_code{f}#27, language_name{f}#28]
5901+
*/
5902+
// Checks the optimized plan for a query that runs INLINESTATS after a LOOKUP JOIN but never keeps the
// INLINESTATS result column (`count`): the assertions below expect Project -> TopN -> Join -> two EsRelations,
// i.e. no InlineJoin node between the TopN and the lookup Join.
public void testInlinestatsWithLookupJoin() {
5903+
var query = """
5904+
FROM airports
5905+
| EVAL backup_scalerank = scalerank
5906+
| RENAME scalerank AS language_code
5907+
| LOOKUP JOIN languages_lookup ON language_code
5908+
| RENAME language_name as scalerank
5909+
| DROP language_code
5910+
| INLINESTATS count=COUNT(*) BY scalerank
5911+
| SORT abbrev DESC
5912+
| KEEP abbrev, *scalerank
5913+
| LIMIT 5
5914+
""";
5915+
// Skip (rather than fail) when the LOOKUP JOIN capability is not enabled.
assumeTrue("Requires LOOKUP JOIN", EsqlCapabilities.Cap.JOIN_LOOKUP_V12.isEnabled());
5916+
// NOTE(review): helper is defined outside this view; presumably it handles builds where INLINESTATS is
// unavailable and signals that the rest of the test should be skipped - confirm against its definition.
if (releaseBuildForInlinestats(query)) {
5917+
return;
5918+
}
5919+
5920+
// NOTE(review): planAirports is defined outside this view; presumably it analyzes and optimizes the query
// against the airports test index - confirm.
var plan = planAirports(query);
5921+
// Top of the plan: only the three columns selected by KEEP, in this order.
var project = as(plan, Project.class);
5922+
assertThat(Expressions.names(project.projections()), is(List.of("abbrev", "backup_scalerank", "scalerank")));
5923+
// A single sort key is expected: abbrev, descending, nulls first.
var topN = as(project.child(), TopN.class);
5924+
assertThat(topN.order().size(), is(1));
5925+
var order = as(topN.order().get(0), Order.class);
5926+
assertThat(order.direction(), equalTo(Order.OrderDirection.DESC));
5927+
assertThat(order.nullsPosition(), equalTo(Order.NullsPosition.FIRST));
5928+
assertThat(Expressions.name(order.child()), equalTo("abbrev"));
5929+
// The TopN's direct child must be the lookup Join itself; the match field is expected under its
// original name (scalerank), not the renamed language_code.
var join = as(topN.child(), Join.class);
5930+
assertThat(Expressions.names(join.config().matchFields()), is(List.of("scalerank")));
5931+
// Both join inputs are plain index relations: the source index on the left, the lookup index on the right.
var left = as(join.left(), EsRelation.class);
5932+
assertThat(left.concreteIndices(), is(Set.of("airports")));
5933+
var right = as(join.right(), EsRelation.class);
5934+
assertThat(right.concreteIndices(), is(Set.of("languages_lookup")));
5935+
}
5936+
59155937
/*
59165938
* EsqlProject[[avg{r}#4, emp_no{f}#9, first_name{f}#10]]
59175939
* \_Limit[10[INTEGER],true]
@@ -5938,30 +5960,20 @@ public void testInlinestatsWithAvg() {
59385960
var plan = optimizedPlan(query);
59395961

59405962
var esqlProject = as(plan, EsqlProject.class);
5941-
var projections = esqlProject.projections();
5942-
assertThat(projections.size(), equalTo(3));
5943-
assertThat(projections.get(0).name(), equalTo("avg"));
5944-
assertThat(projections.get(1).name(), equalTo("emp_no"));
5945-
assertThat(projections.get(2).name(), equalTo("first_name"));
5963+
assertThat(Expressions.names(esqlProject.projections()), is(List.of("avg", "emp_no", "first_name")));
59465964
var upperLimit = asLimit(esqlProject.child(), 10, true);
59475965
var inlineJoin = as(upperLimit.child(), InlineJoin.class);
5948-
assertThat(inlineJoin.config().matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("emp_no{f}")));
5966+
assertThat(Expressions.names(inlineJoin.config().matchFields()), is(List.of("emp_no")));
59495967
// Left
5950-
var limit = as(inlineJoin.left(), Limit.class); // TODO: this needs to go
5951-
assertThat(limit.limit().fold(FoldContext.small()), equalTo(10));
5968+
var limit = asLimit(inlineJoin.left(), 10, false); // TODO: this needs to go
59525969
var relation = as(limit.child(), EsRelation.class);
59535970
// Right
59545971
var project = as(inlineJoin.right(), Project.class);
59555972
assertThat(Expressions.names(project.projections()), contains("avg", "emp_no"));
59565973
var eval = as(project.child(), Eval.class);
5957-
assertThat(eval.fields().size(), equalTo(1));
5958-
var field = eval.fields().getFirst();
5959-
assertThat(field.toAttribute().name(), equalTo("avg"));
5974+
assertThat(Expressions.names(eval.fields()), is(List.of("avg")));
59605975
var agg = as(eval.child(), Aggregate.class);
5961-
assertMap(
5962-
agg.output().stream().map(Object::toString).toList(),
5963-
matchesList().item(startsWith("$$SUM$avg$0")).item(startsWith("$$COUNT$avg$1")).item(startsWith("emp_no{f}"))
5964-
);
5976+
assertMap(Expressions.names(agg.output()), is(List.of("$$SUM$avg$0", "$$COUNT$avg$1", "emp_no")));
59655977
var stub = as(agg.child(), StubRelation.class);
59665978
}
59675979

0 commit comments

Comments
 (0)