Skip to content

Commit 52893e6

Browse files
committed
Adding support for pruning columns in fork branches
1 parent 594a373 commit 52893e6

File tree

7 files changed

+641
-19
lines changed

7 files changed

+641
-19
lines changed

docs/changelog/137907.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 137907
2+
summary: Prune columns when using fork
3+
area: ES|QL
4+
type: bug
5+
issues:
6+
- 136365

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/AttributeSet.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,15 @@ public static Builder builder(int expectedSize) {
212212
return new Builder(AttributeMap.builder(expectedSize));
213213
}
214214

215+
public static Builder forkBuilder() {
216+
return new ForkBuilder(AttributeMap.builder());
217+
}
218+
215219
public static class Builder {
216220

217-
private final AttributeMap.Builder<Object> mapBuilder;
221+
protected final AttributeMap.Builder<Object> mapBuilder;
218222

219-
private Builder(AttributeMap.Builder<Object> mapBuilder) {
223+
protected Builder(AttributeMap.Builder<Object> mapBuilder) {
220224
this.mapBuilder = mapBuilder;
221225
}
222226

@@ -266,4 +270,25 @@ public void clear() {
266270
mapBuilder.keySet().clear();
267271
}
268272
}
273+
274+
/**
275+
* This class extends {@code Builder}, but its {@code contains} method also matches {@code NamedExpression} instances by their name.
276+
* This is useful for Fork plans, where branches may have different Attribute IDs but share a common output schema,
277+
* allowing equality checks of used attributes based on their names.
278+
*/
279+
public static class ForkBuilder extends Builder {
280+
281+
private ForkBuilder(AttributeMap.Builder<Object> mapBuilder) {
282+
super(mapBuilder);
283+
}
284+
285+
@Override
286+
public boolean contains(Object o) {
287+
if (super.contains(o)) {
288+
return true;
289+
}
290+
return o instanceof NamedExpression
291+
&& mapBuilder.keySet().stream().anyMatch(x -> x.name().equals(((NamedExpression) o).name()));
292+
}
293+
}
269294
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropagateNullable;
3535
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PropgateUnmappedFields;
3636
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumns;
37+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumnsInForkBranches;
3738
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneEmptyAggregates;
3839
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneFilters;
3940
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneLiteralsInOrderBy;
@@ -215,6 +216,7 @@ protected static Batch<LogicalPlan> operators() {
215216
new PruneFilters(),
216217
new PruneColumns(),
217218
new PruneLiteralsInOrderBy(),
219+
new PruneColumnsInForkBranches(),
218220
new PushDownAndCombineLimits(),
219221
new PushLimitToKnn(),
220222
new PushDownAndCombineFilters(),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneColumns.java

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,17 @@
2525
import org.elasticsearch.xpack.esql.plan.logical.Project;
2626
import org.elasticsearch.xpack.esql.plan.logical.Sample;
2727
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
28+
import org.elasticsearch.xpack.esql.plan.logical.UnionAll;
2829
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
2930
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
3031
import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier;
3132
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
3233
import org.elasticsearch.xpack.esql.rule.Rule;
3334

3435
import java.util.ArrayList;
36+
import java.util.HashSet;
3537
import java.util.List;
38+
import java.util.Set;
3639

3740
import static org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneEmptyPlans.skipPlan;
3841

@@ -46,9 +49,8 @@ public LogicalPlan apply(LogicalPlan plan) {
4649
return pruneColumns(plan, plan.outputSet().asBuilder(), false);
4750
}
4851

49-
private static LogicalPlan pruneColumns(LogicalPlan plan, AttributeSet.Builder used, boolean inlineJoin) {
52+
static LogicalPlan pruneColumns(LogicalPlan plan, AttributeSet.Builder used, boolean inlineJoin) {
5053
Holder<Boolean> forkPresent = new Holder<>(false);
51-
5254
// while going top-to-bottom (upstream)
5355
return plan.transformDown(p -> {
5456
// Note: It is NOT required to do anything special for binary plans like JOINs, except INLINE STATS. It is perfectly fine that
@@ -58,17 +60,13 @@ private static LogicalPlan pruneColumns(LogicalPlan plan, AttributeSet.Builder u
5860
// same index fields will have different name ids in the left and right hand sides - as in the extreme example
5961
// `FROM lookup_idx | LOOKUP JOIN lookup_idx ON key_field`.
6062

61-
// TODO: revisit with every new command
62-
// skip nodes that simply pass the input through and use no references
63-
if (p instanceof Limit || p instanceof Sample) {
63+
if (forkPresent.get()) {
6464
return p;
6565
}
6666

67-
if (p instanceof Fork) {
68-
forkPresent.set(true);
69-
}
70-
// pruning columns for Fork branches can have the side effect of having misaligned outputs
71-
if (forkPresent.get()) {
67+
// TODO: revisit with every new command
68+
// skip nodes that simply pass the input through and use no references
69+
if (p instanceof Limit || p instanceof Sample) {
7270
return p;
7371
}
7472

@@ -83,6 +81,10 @@ private static LogicalPlan pruneColumns(LogicalPlan plan, AttributeSet.Builder u
8381
case Eval eval -> pruneColumnsInEval(eval, used, recheck);
8482
case Project project -> inlineJoin ? pruneColumnsInProject(project, used) : p;
8583
case EsRelation esr -> pruneColumnsInEsRelation(esr, used);
84+
case Fork fork -> {
85+
forkPresent.set(true);
86+
yield pruneColumnsInFork(fork, used);
87+
}
8688
default -> p;
8789
};
8890
} while (recheck.get());
@@ -94,7 +96,7 @@ private static LogicalPlan pruneColumns(LogicalPlan plan, AttributeSet.Builder u
9496
});
9597
}
9698

97-
private static LogicalPlan pruneColumnsInAggregate(Aggregate aggregate, AttributeSet.Builder used, boolean inlineJoin) {
99+
static LogicalPlan pruneColumnsInAggregate(Aggregate aggregate, AttributeSet.Builder used, boolean inlineJoin) {
98100
LogicalPlan p = aggregate;
99101

100102
var remaining = pruneUnusedAndAddReferences(aggregate.aggregates(), used);
@@ -134,7 +136,7 @@ private static LogicalPlan pruneColumnsInAggregate(Aggregate aggregate, Attribut
134136
return p;
135137
}
136138

137-
private static LogicalPlan pruneColumnsInInlineJoinRight(InlineJoin ij, AttributeSet.Builder used, Holder<Boolean> recheck) {
139+
static LogicalPlan pruneColumnsInInlineJoinRight(InlineJoin ij, AttributeSet.Builder used, Holder<Boolean> recheck) {
138140
LogicalPlan p = ij;
139141

140142
used.addAll(ij.references());
@@ -155,7 +157,7 @@ private static LogicalPlan pruneColumnsInInlineJoinRight(InlineJoin ij, Attribut
155157
return p;
156158
}
157159

158-
private static LogicalPlan pruneColumnsInEval(Eval eval, AttributeSet.Builder used, Holder<Boolean> recheck) {
160+
static LogicalPlan pruneColumnsInEval(Eval eval, AttributeSet.Builder used, Holder<Boolean> recheck) {
159161
LogicalPlan p = eval;
160162

161163
var remaining = pruneUnusedAndAddReferences(eval.fields(), used);
@@ -173,7 +175,7 @@ private static LogicalPlan pruneColumnsInEval(Eval eval, AttributeSet.Builder us
173175
}
174176

175177
// Note: only run when the Project is a descendant of an InlineJoin.
176-
private static LogicalPlan pruneColumnsInProject(Project project, AttributeSet.Builder used) {
178+
static LogicalPlan pruneColumnsInProject(Project project, AttributeSet.Builder used) {
177179
LogicalPlan p = project;
178180

179181
var remaining = pruneUnusedAndAddReferences(project.projections(), used);
@@ -184,7 +186,7 @@ private static LogicalPlan pruneColumnsInProject(Project project, AttributeSet.B
184186
return p;
185187
}
186188

187-
private static LogicalPlan pruneColumnsInEsRelation(EsRelation esr, AttributeSet.Builder used) {
189+
static LogicalPlan pruneColumnsInEsRelation(EsRelation esr, AttributeSet.Builder used) {
188190
LogicalPlan p = esr;
189191

190192
if (esr.indexMode() == IndexMode.LOOKUP) {
@@ -200,6 +202,36 @@ private static LogicalPlan pruneColumnsInEsRelation(EsRelation esr, AttributeSet
200202
return p;
201203
}
202204

205+
private static LogicalPlan pruneColumnsInFork(Fork fork, AttributeSet.Builder used) {
206+
// prune the output attributes of fork based on usage from the rest of the plan
207+
// this does not consider the inner usage within each branch of the fork
208+
// as those will be handled when traversing down each branch in PruneColumnsInForkBranches
209+
LogicalPlan p = fork;
210+
211+
// should exit early for UnionAll
212+
if (fork instanceof UnionAll) {
213+
return p;
214+
}
215+
boolean changed = false;
216+
AttributeSet.Builder builder = AttributeSet.builder();
217+
// if any of the fork outputs are used, keep them
218+
// otherwise, prune them based on the rest of the plan's usage
219+
Set<String> names = new HashSet<>(used.build().names());
220+
for (var attr : fork.output()) {
221+
// we should also ensure to keep any synthetic attributes around as those could still be used for internal processing
222+
if (attr.synthetic() || names.contains(attr.name())) {
223+
builder.add(attr);
224+
} else {
225+
changed = true;
226+
}
227+
}
228+
if (changed) {
229+
List<Attribute> attrs = builder.build().stream().toList();
230+
p = new Fork(fork.source(), fork.children(), attrs);
231+
}
232+
return p;
233+
}
234+
203235
private static LogicalPlan emptyLocalRelation(UnaryPlan plan) {
204236
// create an empty local relation with no attributes
205237
return skipPlan(plan);
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
11+
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
12+
import org.elasticsearch.xpack.esql.core.util.Holder;
13+
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
14+
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
15+
import org.elasticsearch.xpack.esql.plan.logical.Eval;
16+
import org.elasticsearch.xpack.esql.plan.logical.Fork;
17+
import org.elasticsearch.xpack.esql.plan.logical.Limit;
18+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
19+
import org.elasticsearch.xpack.esql.plan.logical.Project;
20+
import org.elasticsearch.xpack.esql.plan.logical.Sample;
21+
import org.elasticsearch.xpack.esql.plan.logical.UnionAll;
22+
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
23+
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
24+
import org.elasticsearch.xpack.esql.rule.Rule;
25+
26+
import java.util.ArrayList;
27+
import java.util.List;
28+
import java.util.Set;
29+
import java.util.stream.Collectors;
30+
31+
import static org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumns.pruneColumnsInAggregate;
32+
import static org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumns.pruneColumnsInEsRelation;
33+
import static org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumns.pruneColumnsInEval;
34+
import static org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneColumns.pruneColumnsInInlineJoinRight;
35+
36+
/**
37+
* This is used to prune unused columns and expressions in each branch of a Fork.
38+
* The output for each fork branch has already been pruned in {@code PruneColumns#pruneColumnsInFork}, so here we only need to
39+
* remove unused columns and expressions in the sub-plans of each branch, similarly to independently running {@code PruneColumns}.
40+
*/
41+
public final class PruneColumnsInForkBranches extends Rule<LogicalPlan, LogicalPlan> {
42+
43+
@Override
44+
public LogicalPlan apply(LogicalPlan plan) {
45+
46+
// collect used attributes from the plan above fork
47+
var used = AttributeSet.forkBuilder();
48+
var forkFound = new Holder<>(false);
49+
50+
// traverse down to the fork node
51+
return plan.transformDown(p -> {
52+
// if fork is not found yet, keep collecting used attributes from everything above.
53+
// Once fork is found, just return the rest of the plan as is, as any pruning/transformation will have
54+
// taken place in pruneSubPlan for each of the fork branches.
55+
if (false == (p instanceof Fork) || forkFound.get()) {
56+
if (false == forkFound.get()) {
57+
used.addAll(p.references());
58+
}
59+
return p;
60+
}
61+
62+
// only do this for fork
63+
if (p instanceof UnionAll) {
64+
return p;
65+
}
66+
67+
used.addAll(p.output());
68+
forkFound.set(true);
69+
var forkOutputNames = p.output().stream().map(NamedExpression::name).collect(Collectors.toSet());
70+
boolean changed = false;
71+
List<LogicalPlan> newChildren = new ArrayList<>();
72+
for (var child : p.children()) {
73+
var clonedUsed = AttributeSet.forkBuilder().addAll(used);
74+
var newChild = pruneSubPlan(child, clonedUsed, forkOutputNames);
75+
newChildren.add(newChild);
76+
if (false == newChild.equals(child)) {
77+
changed = true;
78+
}
79+
}
80+
if (changed) {
81+
return new Fork(p.source(), newChildren, p.output());
82+
} else {
83+
return p;
84+
}
85+
});
86+
}
87+
88+
private static LogicalPlan pruneSubPlan(LogicalPlan plan, AttributeSet.Builder usedAttrs, Set<String> forkOutput) {
89+
if (plan instanceof LocalRelation localRelation) {
90+
var outputAttrs = localRelation.output().stream().filter(usedAttrs::contains).collect(Collectors.toList());
91+
return new LocalRelation(localRelation.source(), outputAttrs, localRelation.supplier());
92+
}
93+
94+
var projectHolder = new Holder<>(false);
95+
return plan.transformDown(p -> {
96+
if (p instanceof Limit || p instanceof Sample) {
97+
return p;
98+
}
99+
100+
var recheck = new Holder<Boolean>();
101+
do {
102+
// we operate using the names of the fields, rather than comparing the attributes directly,
103+
// as attributes may have been recreated during the transformations of fork branches.
104+
recheck.set(false);
105+
p = switch (p) {
106+
case Aggregate agg -> pruneColumnsInAggregate(agg, usedAttrs, false);
107+
case InlineJoin inj -> pruneColumnsInInlineJoinRight(inj, usedAttrs, recheck);
108+
case Eval eval -> pruneColumnsInEval(eval, usedAttrs, recheck);
109+
case Project project -> {
110+
// process only the direct Project after Fork, but skip any subsequent instances
111+
if (projectHolder.get()) {
112+
yield p;
113+
} else {
114+
projectHolder.set(true);
115+
var prunedAttrs = project.projections().stream().filter(x -> forkOutput.contains(x.name())).toList();
116+
yield new Project(project.source(), project.child(), prunedAttrs);
117+
}
118+
}
119+
case EsRelation esr -> pruneColumnsInEsRelation(esr, usedAttrs);
120+
default -> p;
121+
};
122+
} while (recheck.get());
123+
usedAttrs.addAll(p.references());
124+
return p;
125+
});
126+
}
127+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Fork.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ public static Set<String> outputUnsupportedAttributeNames(List<LogicalPlan> subp
162162

163163
@Override
164164
public int hashCode() {
165-
return Objects.hash(Fork.class, children());
165+
return Objects.hash(Fork.class, output, children());
166166
}
167167

168168
@Override
@@ -175,7 +175,7 @@ public boolean equals(Object o) {
175175
}
176176
Fork other = (Fork) o;
177177

178-
return Objects.equals(children(), other.children());
178+
return Objects.equals(output, other.output) && Objects.equals(children(), other.children());
179179
}
180180

181181
@Override

0 commit comments

Comments
 (0)