Skip to content

Commit 5c7bc7a

Browse files
authored
[8.19] Make OptimizerExpressionRule conditional (#127753)
* Make OptimizerExpressionRule conditional (#127500) (cherry picked from commit 7d466c9) * replace pattern with instanceof * fix flakiness
1 parent e4f7f08 commit 5c7bc7a

File tree

4 files changed

+83
-35
lines changed

4 files changed

+83
-35
lines changed

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,19 @@ protected void doCollectFirst(Predicate<? super T> predicate, List<T> matches) {
184184
public T transformDown(Function<? super T, ? extends T> rule) {
185185
T root = rule.apply((T) this);
186186
Node<T> node = this.equals(root) ? this : root;
187-
188187
return node.transformChildren(child -> child.transformDown(rule));
189188
}
190189

191190
@SuppressWarnings("unchecked")
192191
public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) {
193-
// type filtering function
194192
return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
195193
}
196194

195+
@SuppressWarnings("unchecked")
196+
public <E extends T> T transformDown(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
197+
return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
198+
}
199+
197200
@SuppressWarnings("unchecked")
198201
public T transformUp(Function<? super T, ? extends T> rule) {
199202
T transformed = transformChildren(child -> child.transformUp(rule));
@@ -203,10 +206,14 @@ public T transformUp(Function<? super T, ? extends T> rule) {
203206

204207
@SuppressWarnings("unchecked")
205208
public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T> rule) {
206-
// type filtering function
207209
return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
208210
}
209211

212+
@SuppressWarnings("unchecked")
213+
public <E extends T> T transformUp(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
214+
return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
215+
}
216+
210217
@SuppressWarnings("unchecked")
211218
protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) {
212219
boolean childrenChanged = false;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
88

99
import org.elasticsearch.xpack.esql.core.expression.Expression;
10+
import org.elasticsearch.xpack.esql.core.tree.Node;
1011
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
1112
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
13+
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
14+
import org.elasticsearch.xpack.esql.plan.logical.Limit;
1215
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
16+
import org.elasticsearch.xpack.esql.plan.logical.Project;
1317
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
1418
import org.elasticsearch.xpack.esql.rule.Rule;
1519

@@ -55,12 +59,21 @@ public OptimizerExpressionRule(TransformDirection direction) {
5559
@Override
5660
public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
5761
return direction == TransformDirection.DOWN
58-
? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx))
59-
: plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx));
62+
? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx))
63+
: plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx));
6064
}
6165

6266
protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
6367

68+
/**
69+
* Defines if a node should be visited or not.
70+
* Allows to skip nodes that are not applicable for the rule even if they contain expressions.
71+
* By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses.
72+
*/
73+
protected boolean shouldVisit(Node<?> node) {
74+
return (node instanceof EsRelation || node instanceof Project || node instanceof Limit) == false;
75+
}
76+
6477
public Class<E> expressionToken() {
6578
return expressionTypeToken;
6679
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.function.Consumer;
2020
import java.util.function.Function;
21+
import java.util.function.Predicate;
2122

2223
/**
2324
* There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@@ -113,22 +114,36 @@ public <E extends Expression> PlanType transformExpressionsOnlyUp(Class<E> typeT
113114
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
114115
}
115116

116-
public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
117-
return transformExpressionsDown(Expression.class, rule);
118-
}
119-
120117
public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
121118
return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
122119
}
123120

124-
public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) {
125-
return transformExpressionsUp(Expression.class, rule);
121+
public <E extends Expression> PlanType transformExpressionsDown(
122+
Predicate<Node<?>> shouldVisit,
123+
Class<E> typeToken,
124+
Function<E, ? extends Expression> rule
125+
) {
126+
return transformDown(
127+
shouldVisit,
128+
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
129+
);
126130
}
127131

128132
public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
129133
return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
130134
}
131135

136+
public <E extends Expression> PlanType transformExpressionsUp(
137+
Predicate<Node<?>> shouldVisit,
138+
Class<E> typeToken,
139+
Function<E, ? extends Expression> rule
140+
) {
141+
return transformUp(
142+
shouldVisit,
143+
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
144+
);
145+
}
146+
132147
@SuppressWarnings("unchecked")
133148
private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
134149
if (arg instanceof Expression exp) {
@@ -188,18 +203,10 @@ public <E extends Expression> void forEachExpression(Class<E> typeToken, Consume
188203
forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
189204
}
190205

191-
public void forEachExpressionDown(Consumer<? super Expression> rule) {
192-
forEachExpressionDown(Expression.class, rule);
193-
}
194-
195206
public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) {
196207
forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
197208
}
198209

199-
public void forEachExpressionUp(Consumer<? super Expression> rule) {
200-
forEachExpressionUp(Expression.class, rule);
201-
}
202-
203210
public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
204211
forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
205212
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,38 @@
88

99
import org.elasticsearch.common.io.stream.StreamOutput;
1010
import org.elasticsearch.test.ESTestCase;
11+
import org.elasticsearch.xpack.esql.core.expression.Alias;
1112
import org.elasticsearch.xpack.esql.core.expression.Expression;
1213
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
1314
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1415
import org.elasticsearch.xpack.esql.core.expression.Literal;
1516
import org.elasticsearch.xpack.esql.core.expression.Nullability;
17+
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
1618
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1719
import org.elasticsearch.xpack.esql.core.tree.Source;
1820
import org.elasticsearch.xpack.esql.core.type.DataType;
19-
import org.elasticsearch.xpack.esql.core.util.TestUtils;
2021
import org.elasticsearch.xpack.esql.expression.predicate.Range;
22+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
23+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
24+
import org.elasticsearch.xpack.esql.parser.EsqlParser;
2125

2226
import java.io.IOException;
27+
import java.util.ArrayList;
2328
import java.util.Collections;
2429
import java.util.List;
2530

2631
import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
2732
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
33+
import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute;
2834
import static org.elasticsearch.xpack.esql.core.util.TestUtils.of;
35+
import static org.hamcrest.Matchers.containsInAnyOrder;
2936

3037
public class OptimizerRulesTests extends ESTestCase {
3138

32-
private static final Literal FIVE = L(5);
33-
private static final Literal SIX = L(6);
39+
private static final Literal FIVE = of(5);
40+
private static final Literal SIX = of(6);
3441

35-
public static class DummyBooleanExpression extends Expression {
42+
public static final class DummyBooleanExpression extends Expression {
3643

3744
private final int id;
3845

@@ -87,21 +94,13 @@ public boolean equals(Object obj) {
8794
}
8895
}
8996

90-
private static Literal L(Object value) {
91-
return of(value);
92-
}
93-
94-
private static FieldAttribute getFieldAttribute() {
95-
return TestUtils.getFieldAttribute("a");
96-
}
97-
9897
//
9998
// Range optimization
10099
//
101100

102101
// 6 < a <= 5 -> FALSE
103102
public void testFoldExcludingRangeToFalse() {
104-
FieldAttribute fa = getFieldAttribute();
103+
FieldAttribute fa = getFieldAttribute("a");
105104

106105
Range r = rangeOf(fa, SIX, false, FIVE, true);
107106
assertTrue(r.foldable());
@@ -110,13 +109,35 @@ public void testFoldExcludingRangeToFalse() {
110109

111110
// 6 < a <= 5.5 -> FALSE
112111
public void testFoldExcludingRangeWithDifferentTypesToFalse() {
113-
FieldAttribute fa = getFieldAttribute();
112+
FieldAttribute fa = getFieldAttribute("a");
114113

115-
Range r = rangeOf(fa, SIX, false, L(5.5d), true);
114+
Range r = rangeOf(fa, SIX, false, of(5.5d), true);
116115
assertTrue(r.foldable());
117116
assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
118117
}
119118

120-
// Conjunction
119+
public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() {
120+
var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) {
121+
private final List<Expression> appliedTo = new ArrayList<>();
121122

123+
@Override
124+
protected Expression rule(Expression e, LogicalOptimizerContext ctx) {
125+
appliedTo.add(e);
126+
return e;
127+
}
128+
};
129+
130+
rule.apply(
131+
new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"),
132+
new LogicalOptimizerContext(null, FoldContext.small())
133+
);
134+
135+
var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER);
136+
var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1");
137+
var add = new Add(new Source(1, 20, "f1+1"), attribute, literal);
138+
var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add);
139+
140+
// contains expressions only from EVAL
141+
assertThat(rule.appliedTo, containsInAnyOrder(alias, add, attribute, literal));
142+
}
122143
}

0 commit comments

Comments
 (0)