Skip to content

Commit 22de40b

Browse files
committed
Make OptimizerExpressionRule conditional
1 parent 85d375c commit 22de40b

File tree

3 files changed

+46
-17
lines changed

3 files changed

+46
-17
lines changed

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ protected void doCollectFirst(Predicate<? super T> predicate, List<T> matches) {
184184
public T transformDown(Function<? super T, ? extends T> rule) {
185185
T root = rule.apply((T) this);
186186
Node<T> node = this.equals(root) ? this : root;
187-
188187
return node.transformChildren(child -> child.transformDown(rule));
189188
}
190189

@@ -194,6 +193,12 @@ public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T
194193
return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
195194
}
196195

196+
@SuppressWarnings("unchecked")
197+
public <E extends T> T transformDown(Predicate<Node<?>> tokenPredicate, Function<E, ? extends T> rule) {
198+
// type filtering function
199+
return transformDown((t) -> (tokenPredicate.test(t) ? rule.apply((E) t) : t));
200+
}
201+
197202
@SuppressWarnings("unchecked")
198203
public T transformUp(Function<? super T, ? extends T> rule) {
199204
T transformed = transformChildren(child -> child.transformUp(rule));
@@ -207,6 +212,12 @@ public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T>
207212
return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
208213
}
209214

215+
@SuppressWarnings("unchecked")
216+
public <E extends T> T transformUp(Predicate<Node<?>> tokenPredicate, Function<E, ? extends T> rule) {
217+
// type filtering function
218+
return transformUp((t) -> (tokenPredicate.test(t) ? rule.apply((E) t) : t));
219+
}
220+
210221
@SuppressWarnings("unchecked")
211222
protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) {
212223
boolean childrenChanged = false;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
88

99
import org.elasticsearch.xpack.esql.core.expression.Expression;
10+
import org.elasticsearch.xpack.esql.core.tree.Node;
1011
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
1112
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
13+
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
1214
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
15+
import org.elasticsearch.xpack.esql.plan.logical.Project;
1316
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
1417
import org.elasticsearch.xpack.esql.rule.Rule;
1518

@@ -55,12 +58,20 @@ public OptimizerExpressionRule(TransformDirection direction) {
5558
@Override
5659
public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
5760
return direction == TransformDirection.DOWN
58-
? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx))
59-
: plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx));
61+
? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx))
62+
: plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx));
6063
}
6164

6265
protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
6366

67+
protected boolean shouldVisit(Node<?> node) {
68+
return switch (node) {
69+
case EsRelation esr -> false;
70+
case Project p -> false;// this covers both keep and project
71+
default -> true;
72+
};
73+
}
74+
6475
public Class<E> expressionToken() {
6576
return expressionTypeToken;
6677
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.function.Consumer;
2020
import java.util.function.Function;
21+
import java.util.function.Predicate;
2122

2223
/**
2324
* There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@@ -109,22 +110,36 @@ public <E extends Expression> PlanType transformExpressionsOnlyUp(Class<E> typeT
109110
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
110111
}
111112

112-
public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
113-
return transformExpressionsDown(Expression.class, rule);
114-
}
115-
116113
public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
117114
return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
118115
}
119116

120-
public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) {
121-
return transformExpressionsUp(Expression.class, rule);
117+
public <E extends Expression> PlanType transformExpressionsDown(
118+
Predicate<Node<?>> shouldVisit,
119+
Class<E> typeToken,
120+
Function<E, ? extends Expression> rule
121+
) {
122+
return transformDown(
123+
shouldVisit,
124+
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
125+
);
122126
}
123127

124128
public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
125129
return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
126130
}
127131

132+
public <E extends Expression> PlanType transformExpressionsUp(
133+
Predicate<Node<?>> shouldVisit,
134+
Class<E> typeToken,
135+
Function<E, ? extends Expression> rule
136+
) {
137+
return transformUp(
138+
shouldVisit,
139+
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
140+
);
141+
}
142+
128143
@SuppressWarnings("unchecked")
129144
private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
130145
if (arg instanceof Expression exp) {
@@ -184,18 +199,10 @@ public <E extends Expression> void forEachExpression(Class<E> typeToken, Consume
184199
forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
185200
}
186201

187-
public void forEachExpressionDown(Consumer<? super Expression> rule) {
188-
forEachExpressionDown(Expression.class, rule);
189-
}
190-
191202
public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) {
192203
forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
193204
}
194205

195-
public void forEachExpressionUp(Consumer<? super Expression> rule) {
196-
forEachExpressionUp(Expression.class, rule);
197-
}
198-
199206
public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
200207
forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
201208
}

0 commit comments

Comments
 (0)