Skip to content

Commit 073c151

Browse files
committed
restore StatementContext.rewrittenCteProducer in CostBasedRewriteJob
1 parent aa2eb58 commit 073c151

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.doris.common.Pair;
2121
import org.apache.doris.nereids.CascadesContext;
22+
import org.apache.doris.nereids.StatementContext;
2223
import org.apache.doris.nereids.cost.Cost;
2324
import org.apache.doris.nereids.hint.Hint;
2425
import org.apache.doris.nereids.hint.UseCboRuleHint;
@@ -28,17 +29,20 @@
2829
import org.apache.doris.nereids.memo.GroupExpression;
2930
import org.apache.doris.nereids.rules.Rule;
3031
import org.apache.doris.nereids.rules.RuleType;
32+
import org.apache.doris.nereids.trees.expressions.CTEId;
3133
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
3234
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
3335
import org.apache.doris.qe.ConnectContext;
3436

3537
import com.google.common.collect.ImmutableList;
38+
import com.google.common.collect.Maps;
3639
import org.apache.logging.log4j.LogManager;
3740
import org.apache.logging.log4j.Logger;
3841

3942
import java.util.ArrayList;
4043
import java.util.Arrays;
4144
import java.util.List;
45+
import java.util.Map;
4246
import java.util.Optional;
4347

4448
/**
@@ -56,6 +60,11 @@ public CostBasedRewriteJob(List<RewriteJob> rewriteJobs) {
5660
// need to generate real rewrite job list
5761
}
5862

63+
private void restoreCteProducerMap(StatementContext context, Map<CTEId, LogicalPlan> currentCteProducers) {
64+
context.getRewrittenCteProducer().clear();
65+
currentCteProducers.forEach(context.getRewrittenCteProducer()::put);
66+
}
67+
5968
@Override
6069
public void execute(JobContext jobContext) {
6170
// checkHint.first means whether it use hint and checkHint.second means what kind of hint it used
@@ -74,9 +83,15 @@ public void execute(JobContext jobContext) {
7483
return;
7584
}
7685

86+
Map<CTEId, LogicalPlan> currentCteProducers = Maps.newHashMap();
87+
// cost based rewrite job may contaminate StatementContext.rewrittenCteProducer
88+
// clone current rewrittenCteProducer, and restore it after getCost(.).
89+
currentCtx.getStatementContext().getRewrittenCteProducer().forEach(currentCteProducers::put);
7790
// compare two candidates
7891
Optional<Pair<Cost, GroupExpression>> skipCboRuleCost = getCost(currentCtx, skipCboRuleCtx, jobContext);
92+
restoreCteProducerMap(currentCtx.getStatementContext(), currentCteProducers);
7993
Optional<Pair<Cost, GroupExpression>> appliedCboRuleCost = getCost(currentCtx, applyCboRuleCtx, jobContext);
94+
restoreCteProducerMap(currentCtx.getStatementContext(), currentCteProducers);
8095
// If one of them optimize failed, just return
8196
if (!skipCboRuleCost.isPresent() || !appliedCboRuleCost.isPresent()) {
8297
LOG.warn("Cbo rewrite execute failed on sql: {}, jobs are {}, plan is {}.",
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
suite("costbasedrewrite_producer") {
18+
sql """
19+
drop table if exists t1;
20+
21+
create table t1(a1 int,b1 int)
22+
properties("replication_num" = "1");
23+
24+
insert into t1 values(1,2);
25+
26+
drop table if exists t2;
27+
28+
create table t2(a2 int,b2 int)
29+
properties("replication_num" = "1");
30+
31+
insert into t2 values(1,3);
32+
"""
33+
34+
sql"""
35+
with cte1 as (
36+
select t1.a1, t1.b1
37+
from t1
38+
where t1.a1 > 0 and not exists (select distinct t2.b2 from t2 where t1.a1 = t2.a2 or t1.b1 = t2.a2)
39+
),
40+
cte2 as (
41+
select * from cte1 union select * from cte1)
42+
select * from cte2 join t1 on cte2.a1 = t1.a1;
43+
44+
"""
45+
}

0 commit comments

Comments
 (0)