diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java index fd5b7a3f04ee80..8a2ef3faa141e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CostBasedRewriteJob.java @@ -19,6 +19,7 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.cost.Cost; import org.apache.doris.nereids.hint.Hint; import org.apache.doris.nereids.hint.UseCboRuleHint; @@ -28,17 +29,21 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.CTEId; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -56,6 +61,11 @@ public CostBasedRewriteJob(List rewriteJobs) { // need to generate real rewrite job list } + private void restoreCteProducerMap(StatementContext context, Map currentCteProducers) { + context.getRewrittenCteProducer().clear(); + currentCteProducers.forEach(context.getRewrittenCteProducer()::put); + } + @Override public void execute(JobContext jobContext) { // checkHint.first means whether it use hint and checkHint.second means what kind of hint it used @@ -69,14 +79,21 @@ public void execute(JobContext jobContext) { CascadesContext applyCboRuleCtx = CascadesContext.newCurrentTreeContext(currentCtx); // execute cbo rule on one candidate Rewriter.getCteChildrenRewriter(applyCboRuleCtx, rewriteJobs).execute(); + Plan applyCboPlan = applyCboRuleCtx.getRewritePlan(); if (skipCboRuleCtx.getRewritePlan().deepEquals(applyCboRuleCtx.getRewritePlan())) { // this means rewrite do not do anything return; } + Map currentCteProducers = Maps.newHashMap(); + // cost based rewrite job may contaminate StatementContext.rewrittenCteProducer + // clone current rewrittenCteProducer, and restore it after getCost(.). + currentCtx.getStatementContext().getRewrittenCteProducer().forEach(currentCteProducers::put); // compare two candidates Optional> skipCboRuleCost = getCost(currentCtx, skipCboRuleCtx, jobContext); + restoreCteProducerMap(currentCtx.getStatementContext(), currentCteProducers); Optional> appliedCboRuleCost = getCost(currentCtx, applyCboRuleCtx, jobContext); + restoreCteProducerMap(currentCtx.getStatementContext(), currentCteProducers); // If one of them optimize failed, just return if (!skipCboRuleCost.isPresent() || !appliedCboRuleCost.isPresent()) { LOG.warn("Cbo rewrite execute failed on sql: {}, jobs are {}, plan is {}.", @@ -94,8 +111,7 @@ public void execute(JobContext jobContext) { } // If the candidate applied cbo rule is better, replace the original plan with it. if (appliedCboRuleCost.get().first.getValue() < skipCboRuleCost.get().first.getValue()) { - currentCtx.addPlanProcesses(applyCboRuleCtx.getPlanProcesses()); - currentCtx.setRewritePlan(applyCboRuleCtx.getRewritePlan()); + currentCtx.setRewritePlan(applyCboPlan); } } diff --git a/regression-test/suites/nereids_p0/cte/costbasedrewrite_producer/costbasedrewrite_producer.groovy b/regression-test/suites/nereids_p0/cte/costbasedrewrite_producer/costbasedrewrite_producer.groovy new file mode 100644 index 00000000000000..128f35d7232756 --- /dev/null +++ b/regression-test/suites/nereids_p0/cte/costbasedrewrite_producer/costbasedrewrite_producer.groovy @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("costbasedrewrite_producer") { + sql """ + drop table if exists t1; + + create table t1(a1 int,b1 int) + properties("replication_num" = "1"); + + insert into t1 values(1,2); + + drop table if exists t2; + + create table t2(a2 int,b2 int) + properties("replication_num" = "1"); + + insert into t2 values(1,3); + """ + + sql""" + with cte1 as ( + select t1.a1, t1.b1 + from t1 + where t1.a1 > 0 and not exists (select distinct t2.b2 from t2 where t1.a1 = t2.a2 or t1.b1 = t2.a2) + ), + cte2 as ( + select * from cte1 union select * from cte1) + select * from cte2 join t1 on cte2.a1 = t1.a1; + + """ +}