Skip to content

Commit 4438cb1

Browse files
committed
restore StatementContext.rewrittenCteProducer in CostBasedRewriteJob
1 parent aa2eb58 commit 4438cb1

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.doris.common.Pair;
2121
import org.apache.doris.nereids.CascadesContext;
22+
import org.apache.doris.nereids.StatementContext;
2223
import org.apache.doris.nereids.cost.Cost;
2324
import org.apache.doris.nereids.hint.Hint;
2425
import org.apache.doris.nereids.hint.UseCboRuleHint;
@@ -28,17 +29,21 @@
2829
import org.apache.doris.nereids.memo.GroupExpression;
2930
import org.apache.doris.nereids.rules.Rule;
3031
import org.apache.doris.nereids.rules.RuleType;
32+
import org.apache.doris.nereids.trees.expressions.CTEId;
33+
import org.apache.doris.nereids.trees.plans.Plan;
3134
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
3235
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
3336
import org.apache.doris.qe.ConnectContext;
3437

3538
import com.google.common.collect.ImmutableList;
39+
import com.google.common.collect.Maps;
3640
import org.apache.logging.log4j.LogManager;
3741
import org.apache.logging.log4j.Logger;
3842

3943
import java.util.ArrayList;
4044
import java.util.Arrays;
4145
import java.util.List;
46+
import java.util.Map;
4247
import java.util.Optional;
4348

4449
/**
@@ -56,6 +61,11 @@ public CostBasedRewriteJob(List<RewriteJob> rewriteJobs) {
5661
// need to generate real rewrite job list
5762
}
5863

64+
private void restoreCteProducerMap(StatementContext context, Map<CTEId, LogicalPlan> currentCteProducers) {
65+
context.getRewrittenCteProducer().clear();
66+
currentCteProducers.forEach(context.getRewrittenCteProducer()::put);
67+
}
68+
5969
@Override
6070
public void execute(JobContext jobContext) {
6171
// checkHint.first means whether it use hint and checkHint.second means what kind of hint it used
@@ -69,14 +79,21 @@ public void execute(JobContext jobContext) {
6979
CascadesContext applyCboRuleCtx = CascadesContext.newCurrentTreeContext(currentCtx);
7080
// execute cbo rule on one candidate
7181
Rewriter.getCteChildrenRewriter(applyCboRuleCtx, rewriteJobs).execute();
82+
Plan applyCboPlan = applyCboRuleCtx.getRewritePlan();
7283
if (skipCboRuleCtx.getRewritePlan().deepEquals(applyCboRuleCtx.getRewritePlan())) {
7384
// this means rewrite do not do anything
7485
return;
7586
}
7687

88+
Map<CTEId, LogicalPlan> currentCteProducers = Maps.newHashMap();
89+
// cost based rewrite job may contaminate StatementContext.rewrittenCteProducer
90+
// clone current rewrittenCteProducer, and restore it after getCost(.).
91+
currentCtx.getStatementContext().getRewrittenCteProducer().forEach(currentCteProducers::put);
7792
// compare two candidates
7893
Optional<Pair<Cost, GroupExpression>> skipCboRuleCost = getCost(currentCtx, skipCboRuleCtx, jobContext);
94+
restoreCteProducerMap(currentCtx.getStatementContext(), currentCteProducers);
7995
Optional<Pair<Cost, GroupExpression>> appliedCboRuleCost = getCost(currentCtx, applyCboRuleCtx, jobContext);
96+
restoreCteProducerMap(currentCtx.getStatementContext(), currentCteProducers);
8097
// If one of them optimize failed, just return
8198
if (!skipCboRuleCost.isPresent() || !appliedCboRuleCost.isPresent()) {
8299
LOG.warn("Cbo rewrite execute failed on sql: {}, jobs are {}, plan is {}.",
@@ -94,8 +111,7 @@ public void execute(JobContext jobContext) {
94111
}
95112
// If the candidate applied cbo rule is better, replace the original plan with it.
96113
if (appliedCboRuleCost.get().first.getValue() < skipCboRuleCost.get().first.getValue()) {
97-
currentCtx.addPlanProcesses(applyCboRuleCtx.getPlanProcesses());
98-
currentCtx.setRewritePlan(applyCboRuleCtx.getRewritePlan());
114+
currentCtx.setRewritePlan(applyCboPlan);
99115
}
100116
}
101117

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
suite("costbasedrewrite_producer") {
18+
sql """
19+
drop table if exists t1;
20+
21+
create table t1(a1 int,b1 int)
22+
properties("replication_num" = "1");
23+
24+
insert into t1 values(1,2);
25+
26+
drop table if exists t2;
27+
28+
create table t2(a2 int,b2 int)
29+
properties("replication_num" = "1");
30+
31+
insert into t2 values(1,3);
32+
"""
33+
34+
sql"""
35+
with cte1 as (
36+
select t1.a1, t1.b1
37+
from t1
38+
where t1.a1 > 0 and not exists (select distinct t2.b2 from t2 where t1.a1 = t2.a2 or t1.b1 = t2.a2)
39+
),
40+
cte2 as (
41+
select * from cte1 union select * from cte1)
42+
select * from cte2 join t1 on cte2.a1 = t1.a1;
43+
44+
"""
45+
}

0 commit comments

Comments
 (0)