Skip to content

Commit 6e09cd3

Browse files
idegtiarenkojfreden
authored andcommitted
Make OptimizerExpressionRule conditional (elastic#127500)
1 parent 714d30b commit 6e09cd3

File tree

5 files changed

+89
-36
lines changed

5 files changed

+89
-36
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/esql/QueryPlanningBenchmark.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ private LogicalPlan plan(String query) {
123123
}
124124

125125
@Benchmark
126-
public void run(Blackhole blackhole) {
126+
public void manyFields(Blackhole blackhole) {
127127
blackhole.consume(plan("FROM test | LIMIT 10"));
128128
}
129129
}

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,19 @@ protected void doCollectFirst(Predicate<? super T> predicate, List<T> matches) {
184184
public T transformDown(Function<? super T, ? extends T> rule) {
185185
T root = rule.apply((T) this);
186186
Node<T> node = this.equals(root) ? this : root;
187-
188187
return node.transformChildren(child -> child.transformDown(rule));
189188
}
190189

191190
@SuppressWarnings("unchecked")
192191
public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) {
193-
// type filtering function
194192
return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
195193
}
196194

195+
@SuppressWarnings("unchecked")
196+
public <E extends T> T transformDown(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
197+
return transformDown((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
198+
}
199+
197200
@SuppressWarnings("unchecked")
198201
public T transformUp(Function<? super T, ? extends T> rule) {
199202
T transformed = transformChildren(child -> child.transformUp(rule));
@@ -203,10 +206,14 @@ public T transformUp(Function<? super T, ? extends T> rule) {
203206

204207
@SuppressWarnings("unchecked")
205208
public <E extends T> T transformUp(Class<E> typeToken, Function<E, ? extends T> rule) {
206-
// type filtering function
207209
return transformUp((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));
208210
}
209211

212+
@SuppressWarnings("unchecked")
213+
public <E extends T> T transformUp(Predicate<Node<?>> nodePredicate, Function<E, ? extends T> rule) {
214+
return transformUp((t) -> (nodePredicate.test(t) ? rule.apply((E) t) : t));
215+
}
216+
210217
@SuppressWarnings("unchecked")
211218
protected <R extends Function<? super T, ? extends T>> T transformChildren(Function<T, ? extends T> traversalOperation) {
212219
boolean childrenChanged = false;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
88

99
import org.elasticsearch.xpack.esql.core.expression.Expression;
10+
import org.elasticsearch.xpack.esql.core.tree.Node;
1011
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
1112
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
13+
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
14+
import org.elasticsearch.xpack.esql.plan.logical.Limit;
1215
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
16+
import org.elasticsearch.xpack.esql.plan.logical.Project;
1317
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
1418
import org.elasticsearch.xpack.esql.rule.Rule;
1519

@@ -55,12 +59,26 @@ public OptimizerExpressionRule(TransformDirection direction) {
5559
@Override
5660
public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
5761
return direction == TransformDirection.DOWN
58-
? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx))
59-
: plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx));
62+
? plan.transformExpressionsDown(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx))
63+
: plan.transformExpressionsUp(this::shouldVisit, expressionTypeToken, e -> rule(e, ctx));
6064
}
6165

6266
protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
6367

68+
/**
69+
* Defines if a node should be visited or not.
70+
* Allows to skip nodes that are not applicable for the rule even if they contain expressions.
71+
* By default that skips FROM, LIMIT, PROJECT, KEEP and DROP but this list could be extended or replaced in subclasses.
72+
*/
73+
protected boolean shouldVisit(Node<?> node) {
74+
return switch (node) {
75+
case EsRelation relation -> false;
76+
case Project project -> false;// this covers project, keep and drop
77+
case Limit limit -> false;
78+
default -> true;
79+
};
80+
}
81+
6482
public Class<E> expressionToken() {
6583
return expressionTypeToken;
6684
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.function.Consumer;
2020
import java.util.function.Function;
21+
import java.util.function.Predicate;
2122

2223
/**
2324
* There are two main types of plans, {@code LogicalPlan} and {@code PhysicalPlan}
@@ -109,22 +110,36 @@ public <E extends Expression> PlanType transformExpressionsOnlyUp(Class<E> typeT
109110
return transformPropertiesOnly(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
110111
}
111112

112-
public PlanType transformExpressionsDown(Function<Expression, ? extends Expression> rule) {
113-
return transformExpressionsDown(Expression.class, rule);
114-
}
115-
116113
public <E extends Expression> PlanType transformExpressionsDown(Class<E> typeToken, Function<E, ? extends Expression> rule) {
117114
return transformPropertiesDown(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)));
118115
}
119116

120-
public PlanType transformExpressionsUp(Function<Expression, ? extends Expression> rule) {
121-
return transformExpressionsUp(Expression.class, rule);
117+
public <E extends Expression> PlanType transformExpressionsDown(
118+
Predicate<Node<?>> shouldVisit,
119+
Class<E> typeToken,
120+
Function<E, ? extends Expression> rule
121+
) {
122+
return transformDown(
123+
shouldVisit,
124+
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformDown(typeToken, rule)))
125+
);
122126
}
123127

124128
public <E extends Expression> PlanType transformExpressionsUp(Class<E> typeToken, Function<E, ? extends Expression> rule) {
125129
return transformPropertiesUp(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)));
126130
}
127131

132+
public <E extends Expression> PlanType transformExpressionsUp(
133+
Predicate<Node<?>> shouldVisit,
134+
Class<E> typeToken,
135+
Function<E, ? extends Expression> rule
136+
) {
137+
return transformUp(
138+
shouldVisit,
139+
t -> t.transformNodeProps(Object.class, e -> doTransformExpression(e, exp -> exp.transformUp(typeToken, rule)))
140+
);
141+
}
142+
128143
@SuppressWarnings("unchecked")
129144
private static Object doTransformExpression(Object arg, Function<Expression, ? extends Expression> traversal) {
130145
if (arg instanceof Expression exp) {
@@ -184,18 +199,10 @@ public <E extends Expression> void forEachExpression(Class<E> typeToken, Consume
184199
forEachPropertyOnly(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
185200
}
186201

187-
public void forEachExpressionDown(Consumer<? super Expression> rule) {
188-
forEachExpressionDown(Expression.class, rule);
189-
}
190-
191202
public <E extends Expression> void forEachExpressionDown(Class<? extends E> typeToken, Consumer<? super E> rule) {
192203
forEachPropertyDown(Object.class, e -> doForEachExpression(e, exp -> exp.forEachDown(typeToken, rule)));
193204
}
194205

195-
public void forEachExpressionUp(Consumer<? super Expression> rule) {
196-
forEachExpressionUp(Expression.class, rule);
197-
}
198-
199206
public <E extends Expression> void forEachExpressionUp(Class<E> typeToken, Consumer<? super E> rule) {
200207
forEachPropertyUp(Object.class, e -> doForEachExpression(e, exp -> exp.forEachUp(typeToken, rule)));
201208
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,38 @@
88

99
import org.elasticsearch.common.io.stream.StreamOutput;
1010
import org.elasticsearch.test.ESTestCase;
11+
import org.elasticsearch.xpack.esql.core.expression.Alias;
1112
import org.elasticsearch.xpack.esql.core.expression.Expression;
1213
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
1314
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
1415
import org.elasticsearch.xpack.esql.core.expression.Literal;
1516
import org.elasticsearch.xpack.esql.core.expression.Nullability;
17+
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
1618
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1719
import org.elasticsearch.xpack.esql.core.tree.Source;
1820
import org.elasticsearch.xpack.esql.core.type.DataType;
19-
import org.elasticsearch.xpack.esql.core.util.TestUtils;
2021
import org.elasticsearch.xpack.esql.expression.predicate.Range;
22+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
23+
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
24+
import org.elasticsearch.xpack.esql.parser.EsqlParser;
2125

2226
import java.io.IOException;
27+
import java.util.ArrayList;
2328
import java.util.Collections;
2429
import java.util.List;
2530

2631
import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
2732
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
33+
import static org.elasticsearch.xpack.esql.core.util.TestUtils.getFieldAttribute;
2834
import static org.elasticsearch.xpack.esql.core.util.TestUtils.of;
35+
import static org.hamcrest.Matchers.contains;
2936

3037
public class OptimizerRulesTests extends ESTestCase {
3138

32-
private static final Literal FIVE = L(5);
33-
private static final Literal SIX = L(6);
39+
private static final Literal FIVE = of(5);
40+
private static final Literal SIX = of(6);
3441

35-
public static class DummyBooleanExpression extends Expression {
42+
public static final class DummyBooleanExpression extends Expression {
3643

3744
private final int id;
3845

@@ -87,21 +94,13 @@ public boolean equals(Object obj) {
8794
}
8895
}
8996

90-
private static Literal L(Object value) {
91-
return of(value);
92-
}
93-
94-
private static FieldAttribute getFieldAttribute() {
95-
return TestUtils.getFieldAttribute("a");
96-
}
97-
9897
//
9998
// Range optimization
10099
//
101100

102101
// 6 < a <= 5 -> FALSE
103102
public void testFoldExcludingRangeToFalse() {
104-
FieldAttribute fa = getFieldAttribute();
103+
FieldAttribute fa = getFieldAttribute("a");
105104

106105
Range r = rangeOf(fa, SIX, false, FIVE, true);
107106
assertTrue(r.foldable());
@@ -110,13 +109,35 @@ public void testFoldExcludingRangeToFalse() {
110109

111110
// 6 < a <= 5.5 -> FALSE
112111
public void testFoldExcludingRangeWithDifferentTypesToFalse() {
113-
FieldAttribute fa = getFieldAttribute();
112+
FieldAttribute fa = getFieldAttribute("a");
114113

115-
Range r = rangeOf(fa, SIX, false, L(5.5d), true);
114+
Range r = rangeOf(fa, SIX, false, of(5.5d), true);
116115
assertTrue(r.foldable());
117116
assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
118117
}
119118

120-
// Conjunction
119+
public void testOptimizerExpressionRuleShouldNotVisitExcludedNodes() {
120+
var rule = new OptimizerRules.OptimizerExpressionRule<>(randomFrom(OptimizerRules.TransformDirection.values())) {
121+
private final List<Expression> appliedTo = new ArrayList<>();
121122

123+
@Override
124+
protected Expression rule(Expression e, LogicalOptimizerContext ctx) {
125+
appliedTo.add(e);
126+
return e;
127+
}
128+
};
129+
130+
rule.apply(
131+
new EsqlParser().createStatement("FROM index | EVAL x=f1+1 | KEEP x, f2 | LIMIT 1"),
132+
new LogicalOptimizerContext(null, FoldContext.small())
133+
);
134+
135+
var literal = new Literal(new Source(1, 25, "1"), 1, DataType.INTEGER);
136+
var attribute = new UnresolvedAttribute(new Source(1, 20, "f1"), "f1");
137+
var add = new Add(new Source(1, 20, "f1+1"), attribute, literal);
138+
var alias = new Alias(new Source(1, 18, "x=f1+1"), "x", add);
139+
140+
// contains expressions only from EVAL
141+
assertThat(rule.appliedTo, contains(alias, add, attribute, literal));
142+
}
122143
}

0 commit comments

Comments
 (0)