diff --git a/src/query/service/src/pipelines/processors/transforms/transform_recursive_cte_source.rs b/src/query/service/src/pipelines/processors/transforms/transform_recursive_cte_source.rs
index 5b197d07e21e6..0cbb2752a118d 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_recursive_cte_source.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_recursive_cte_source.rs
@@ -15,6 +15,7 @@
 use std::any::Any;
 use std::collections::BTreeMap;
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::sync::Arc;
 use std::sync::atomic::AtomicU64;
 use std::sync::atomic::Ordering;
@@ -84,9 +85,47 @@ impl TransformRecursiveCteSource {
         union_plan: UnionAll,
     ) -> Result<ProcessorPtr> {
         let mut union_plan = union_plan;
+
+        // Recursive CTE uses internal MEMORY tables addressed by name in the current database.
+        // If we keep using the stable scan name (cte name/alias), concurrent queries can interfere
+        // by creating/dropping/recreating the same table name, leading to wrong or flaky results.
+        //
+        // Make the internal table names query-unique by prefixing them with the query id.
+        // This is purely internal and does not change user-visible semantics.
+        let rcte_prefix = make_rcte_prefix(&ctx.get_id());
+        let local_cte_scan_names = {
+            let names = collect_local_recursive_scan_names(&union_plan.right);
+            if names.is_empty() {
+                union_plan.cte_scan_names.clone()
+            } else {
+                names
+            }
+        };
+        if union_plan.cte_scan_names != local_cte_scan_names {
+            union_plan.cte_scan_names = local_cte_scan_names;
+        }
+        let local_cte_scan_name_set: HashSet<&str> = union_plan
+            .cte_scan_names
+            .iter()
+            .map(String::as_str)
+            .collect();
+
         let mut exec_ids: HashMap<String, Vec<u64>> = HashMap::new();
-        assign_exec_ids(&mut union_plan.left, &mut exec_ids);
-        assign_exec_ids(&mut union_plan.right, &mut exec_ids);
+        rewrite_assign_and_strip_recursive_cte(
+            &mut union_plan.left,
+            &local_cte_scan_name_set,
+            &rcte_prefix,
+            &mut exec_ids,
+        );
+        rewrite_assign_and_strip_recursive_cte(
+            &mut union_plan.right,
+            &local_cte_scan_name_set,
+            &rcte_prefix,
+            &mut exec_ids,
+        );
+        for name in union_plan.cte_scan_names.iter_mut() {
+            *name = format!("{rcte_prefix}{name}");
+        }
 
         let left_outputs = union_plan
             .left_outputs
@@ -134,6 +173,8 @@ impl TransformRecursiveCteSource {
         if ctx.get_settings().get_max_cte_recursive_depth()? < recursive_step {
             return Err(ErrorCode::Internal("Recursive depth is reached"));
        }
+        #[cfg(debug_assertions)]
+        crate::test_kits::rcte_hooks::maybe_pause_before_step(&ctx.get_id(), recursive_step).await;
         let mut cte_scan_tables = vec![];
         let plan = if recursive_step == 0 {
             // Find all cte scan in the union right child plan, then create memory table for them.
@@ -172,6 +213,88 @@ impl TransformRecursiveCteSource {
     }
 }
 
+fn make_rcte_prefix(query_id: &str) -> String {
+    // Keep it readable and safe as an identifier.
+    // Use enough entropy to be effectively unique for concurrent queries.
+    let mut short = String::with_capacity(32);
+    for ch in query_id.chars() {
+        if ch.is_ascii_alphanumeric() {
+            short.push(ch);
+        }
+        if short.len() >= 32 {
+            break;
+        }
+    }
+    if short.is_empty() {
+        short.push_str("unknown");
+    }
+    format!("__rcte_{short}_")
+}
+
+fn rewrite_assign_and_strip_recursive_cte(
+    plan: &mut PhysicalPlan,
+    local_cte_scan_name_set: &HashSet<&str>,
+    prefix: &str,
+    exec_ids: &mut HashMap<String, Vec<u64>>,
+) {
+    // Only nested recursive UNION nodes that reference the current recursive CTE should be
+    // downgraded to normal unions to avoid nested recursive sources for the same table.
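+    // Clearing `cte_scan_names` below is what performs that downgrade; the scans underneath
+    // are still renamed and assigned exec ids by the following block as we recurse into children.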
+    if let Some(union_all) = UnionAll::from_mut_physical_plan(plan) {
+        if !union_all.cte_scan_names.is_empty()
+            && union_all
+                .cte_scan_names
+                .iter()
+                .all(|name| local_cte_scan_name_set.contains(name.as_str()))
+        {
+            union_all.cte_scan_names.clear();
+        }
+    }
+
+    if let Some(recursive_cte_scan) = RecursiveCteScan::from_mut_physical_plan(plan) {
+        if local_cte_scan_name_set.contains(recursive_cte_scan.table_name.as_str()) {
+            recursive_cte_scan.table_name = format!("{prefix}{}", recursive_cte_scan.table_name);
+            let id = NEXT_R_CTE_ID.fetch_add(1, Ordering::Relaxed);
+            recursive_cte_scan.exec_id = Some(id);
+            exec_ids
+                .entry(recursive_cte_scan.table_name.clone())
+                .or_default()
+                .push(id);
+        }
+    }
+
+    for child in plan.children_mut() {
+        rewrite_assign_and_strip_recursive_cte(child, local_cte_scan_name_set, prefix, exec_ids);
+    }
+}
+
+fn collect_local_recursive_scan_names(plan: &PhysicalPlan) -> Vec<String> {
+    fn walk(plan: &PhysicalPlan, names: &mut Vec<String>, seen: &mut HashSet<String>) {
+        // Nested recursive unions belong to other recursive CTEs. Leave them to their own
+        // TransformRecursiveCteSource instance.
+        if let Some(union_all) = UnionAll::from_physical_plan(plan) {
+            if !union_all.cte_scan_names.is_empty() {
+                return;
+            }
+        }
+
+        if let Some(recursive_cte_scan) = RecursiveCteScan::from_physical_plan(plan) {
+            if seen.insert(recursive_cte_scan.table_name.clone()) {
+                names.push(recursive_cte_scan.table_name.clone());
+            }
+            return;
+        }
+
+        for child in plan.children() {
+            walk(child, names, seen);
+        }
+    }
+
+    let mut names = Vec::new();
+    let mut seen = HashSet::new();
+    walk(plan, &mut names, &mut seen);
+    names
+}
+
 #[async_trait::async_trait]
 impl AsyncSource for TransformRecursiveCteSource {
     const NAME: &'static str = "TransformRecursiveCteSource";
@@ -236,21 +359,6 @@ impl AsyncSource for TransformRecursiveCteSource {
     }
 }
 
-fn assign_exec_ids(plan: &mut PhysicalPlan, mapping: &mut HashMap<String, Vec<u64>>) {
-    if let Some(recursive_cte_scan) = RecursiveCteScan::from_mut_physical_plan(plan) {
-        let id = NEXT_R_CTE_ID.fetch_add(1, Ordering::Relaxed);
-        recursive_cte_scan.exec_id = Some(id);
-        mapping
-            .entry(recursive_cte_scan.table_name.clone())
-            .or_default()
-            .push(id);
-    }
-
-    for child in plan.children_mut() {
-        assign_exec_ids(child, mapping);
-    }
-}
-
 async fn drop_tables(ctx: Arc<QueryContext>, table_names: Vec<String>) -> Result<()> {
     for table_name in table_names {
         let drop_table_plan = DropTablePlan {
@@ -311,7 +419,6 @@ async fn create_memory_table_for_cte_scan(
 
         let mut options = BTreeMap::new();
         options.insert(OPT_KEY_RECURSIVE_CTE.to_string(), "1".to_string());
-
         self.plans.push(CreateTablePlan {
             schema,
             create_option: CreateOption::CreateIfNotExists,
diff --git a/src/query/service/src/test_kits/mod.rs b/src/query/service/src/test_kits/mod.rs
index d4679b553968c..acdca6d8a9afd 100644
--- a/src/query/service/src/test_kits/mod.rs
+++ b/src/query/service/src/test_kits/mod.rs
@@ -31,3 +31,4 @@ pub use config::config_with_spill;
 pub use context::*;
 pub use fixture::*;
 pub use fuse::*;
+pub mod rcte_hooks;
diff --git a/src/query/service/src/test_kits/rcte_hooks.rs b/src/query/service/src/test_kits/rcte_hooks.rs
new file mode 100644
index 0000000000000..6847014912e8b
--- /dev/null
+++ b/src/query/service/src/test_kits/rcte_hooks.rs
@@ -0,0 +1,180 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Test-only hooks for recursive CTE execution.
+//!
+//! This module is intended to make race conditions reproducible by providing
+//! deterministic pause/resume points in the recursive CTE executor.
+//!
+//! By default no hooks are installed and the hook checks are no-ops.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::sync::OnceLock;
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering;
+
+use tokio::sync::Notify;
+
+static HOOKS: OnceLock<Arc<RcteHookRegistry>> = OnceLock::new();
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+struct GateKey {
+    query_id: String,
+    step: usize,
+}
+
+impl GateKey {
+    fn new(query_id: &str, step: usize) -> Self {
+        Self {
+            query_id: query_id.to_string(),
+            step,
+        }
+    }
+}
+
+#[derive(Default)]
+pub struct RcteHookRegistry {
+    gates: Mutex<HashMap<GateKey, Arc<PauseGate>>>,
+}
+
+impl RcteHookRegistry {
+    pub fn global() -> Arc<RcteHookRegistry> {
+        HOOKS
+            .get_or_init(|| Arc::new(RcteHookRegistry::default()))
+            .clone()
+    }
+
+    pub fn install_pause_before_step(&self, query_id: &str, step: usize) -> Arc<PauseGate> {
+        let mut gates = self.gates.lock().unwrap();
+        let key = GateKey::new(query_id, step);
+        gates
+            .entry(key)
+            .or_insert_with(|| Arc::new(PauseGate::new(step)))
+            .clone()
+    }
+
+    fn get_gate(&self, query_id: &str, step: usize) -> Option<Arc<PauseGate>> {
+        let key = GateKey::new(query_id, step);
+        self.gates.lock().unwrap().get(&key).cloned()
+    }
+}
+
+/// A reusable pause gate for a single step number.
+///
+/// When the code hits the hook point, it increments `arrived` and blocks until
+/// the test releases the same hit index via `release(hit_no)`.
+pub struct PauseGate {
+    step: usize,
+    arrived: AtomicUsize,
+    released: AtomicUsize,
+    arrived_notify: Notify,
+    released_notify: Notify,
+}
+
+impl PauseGate {
+    fn new(step: usize) -> Self {
+        Self {
+            step,
+            arrived: AtomicUsize::new(0),
+            released: AtomicUsize::new(0),
+            arrived_notify: Notify::new(),
+            released_notify: Notify::new(),
+        }
+    }
+
+    pub fn step(&self) -> usize {
+        self.step
+    }
+
+    pub fn arrived(&self) -> usize {
+        self.arrived.load(Ordering::Acquire)
+    }
+
+    pub async fn wait_arrived_at_least(&self, n: usize) {
+        loop {
+            let notified = self.arrived_notify.notified();
+            tokio::pin!(notified);
+            notified.as_mut().enable();
+
+            if self.arrived() >= n {
+                return;
+            }
+
+            // Re-check after registration to avoid missing a notify between
+            // condition check and awaiting.
+            if self.arrived() >= n {
+                return;
+            }
+
+            notified.await;
+        }
+    }
+
+    /// Release the `hit_no`-th arrival (1-based).
+    pub fn release(&self, hit_no: usize) {
+        // Monotonic release.
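+        // CAS loop: only ever raise `released`, never lower it, so a stale or duplicate
+        // release(n) call cannot move the gate backwards.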
+        let mut cur = self.released.load(Ordering::Acquire);
+        while cur < hit_no {
+            match self
+                .released
+                .compare_exchange(cur, hit_no, Ordering::AcqRel, Ordering::Acquire)
+            {
+                Ok(_) => break,
+                Err(v) => cur = v,
+            }
+        }
+        self.released_notify.notify_waiters();
+    }
+
+    async fn hit(&self) {
+        let hit_no = self.arrived.fetch_add(1, Ordering::AcqRel) + 1;
+        self.arrived_notify.notify_waiters();
+
+        loop {
+            let notified = self.released_notify.notified();
+            tokio::pin!(notified);
+            notified.as_mut().enable();
+
+            let released = self.released.load(Ordering::Acquire);
+            if released >= hit_no {
+                return;
+            }
+
+            // Re-check after registration to avoid missing a notify between
+            // condition check and awaiting.
+            let released = self.released.load(Ordering::Acquire);
+            if released >= hit_no {
+                return;
+            }
+
+            notified.await;
+        }
+    }
+}
+
+/// Called from the recursive CTE executor.
+///
+/// If a pause gate is installed for `step`, this call will block until released.
+#[async_backtrace::framed]
+pub async fn maybe_pause_before_step(query_id: &str, step: usize) {
+    let Some(registry) = HOOKS.get() else {
+        return;
+    };
+    let Some(gate) = registry.get_gate(query_id, step) else {
+        return;
+    };
+    gate.hit().await;
+}
diff --git a/src/query/service/tests/it/sql/mod.rs b/src/query/service/tests/it/sql/mod.rs
index 01ebe8a9789fb..227c41634ac2e 100644
--- a/src/query/service/tests/it/sql/mod.rs
+++ b/src/query/service/tests/it/sql/mod.rs
@@ -15,3 +15,4 @@
 mod exec;
 mod expr;
 mod planner;
+mod recursive_cte;
diff --git a/src/query/service/tests/it/sql/recursive_cte.rs b/src/query/service/tests/it/sql/recursive_cte.rs
new file mode 100644
index 0000000000000..ef46db55bdf3c
--- /dev/null
+++ b/src/query/service/tests/it/sql/recursive_cte.rs
@@ -0,0 +1,195 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::BTreeMap;
+use std::sync::Arc;
+
+use databend_common_ast::ast::Engine;
+use databend_common_exception::ErrorCode;
+use databend_common_exception::Result;
+use databend_common_expression::DataBlock;
+use databend_common_expression::ScalarRef;
+use databend_common_expression::TableField;
+use databend_common_expression::TableSchemaRefExt;
+use databend_common_expression::infer_schema_type;
+use databend_common_expression::types::DataType;
+use databend_common_expression::types::number::NumberDataType;
+use databend_common_expression::types::number::NumberScalar;
+use databend_common_meta_app::schema::CreateOption;
+use databend_common_meta_app::tenant::Tenant;
+use databend_common_sql::plans::CreateTablePlan;
+use databend_common_sql::plans::DropTablePlan;
+use databend_query::interpreters::CreateTableInterpreter;
+use databend_query::interpreters::DropTableInterpreter;
+use databend_query::interpreters::Interpreter;
+use databend_query::interpreters::InterpreterFactory;
+use databend_query::sessions::QueryContext;
+use databend_query::sessions::TableContext;
+use databend_query::sql::Planner;
+use databend_query::test_kits::TestFixture;
+use databend_query::test_kits::rcte_hooks::RcteHookRegistry;
+use databend_storages_common_table_meta::table::OPT_KEY_RECURSIVE_CTE;
+use futures_util::TryStreamExt;
+
+fn extract_u64(blocks: Vec<DataBlock>, col: usize) -> u64 {
+    let block = DataBlock::concat(&blocks).expect("concat blocks");
+    assert_eq!(block.num_rows(), 1, "unexpected rows: {}", block.num_rows());
+
+    let scalar = block.get_by_offset(col).index(0).expect("scalar at row 0");
+
+    match scalar {
+        ScalarRef::Number(NumberScalar::UInt64(v)) => v,
+        ScalarRef::Number(NumberScalar::UInt32(v)) => v as u64,
+        ScalarRef::Number(NumberScalar::Int64(v)) => v as u64,
+        other => panic!("unexpected scalar type for col {col}: {other:?}"),
+    }
+}
+
+async fn run_query_single_u64(ctx: Arc<QueryContext>, sql: &str) -> Result<u64> {
+    let mut planner = Planner::new(ctx.clone());
+    let (plan, _) = planner.plan_sql(sql).await?;
+    let executor = InterpreterFactory::get(ctx.clone(), &plan).await?;
+    let stream = executor.execute(ctx).await?;
+    let blocks: Vec<DataBlock> = stream.try_collect().await?;
+    Ok(extract_u64(blocks, 0))
+}
+
+/// Deterministically reproduce *wrong results* caused by recursive CTE internal table name reuse.
+///
+/// This is a stable (non-flaky) repro: it forces the internal MEMORY table (`lines`) to be
+/// corrupted between recursive step=0 and step=1, so step=1 reads no prepared blocks and the
+/// recursion stops early.
+#[test]
+fn recursive_cte_deterministic_wrong_count_repro() -> anyhow::Result<()> {
+    let outer_rt = tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()?;
+
+    outer_rt.block_on(async {
+        use databend_common_base::runtime::Runtime;
+
+        let fixture = Arc::new(TestFixture::setup().await?);
+
+        let db = fixture.default_db_name();
+        fixture
+            .execute_command(&format!("create database if not exists {db}"))
+            .await?;
+
+        let runtime = Runtime::with_worker_threads(2, None)?;
+
+        // Use the same QueryContext for running the query and applying interference, to avoid
+        // uncertainty around per-context table caching.
+        let ctx = fixture.new_query_ctx().await?;
+        ctx.set_current_database(db.clone()).await?;
+        ctx.get_settings().set_max_threads(8)?;
+
+        // Pause the recursive CTE executor before it starts step=1 (only for this query id).
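+        // The gate is keyed by (query_id, step), so it must be installed before the query is
+        // spawned and it does not affect any other query running against the same fixture.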
+        let gate = RcteHookRegistry::global().install_pause_before_step(&ctx.get_id(), 1);
+
+        let ctx_q = ctx.clone();
+        let jh = runtime.spawn(async move {
+            // This query should normally return 1000.
+            // If the internal MEMORY table is recreated between step=0 and step=1,
+            // step=1 reads no prepared blocks and recursion stops early => count becomes 1.
+            let sql = "WITH RECURSIVE\n\
+                lines(x) AS (\n\
+                    SELECT 1::UInt64\n\
+                    UNION ALL\n\
+                    SELECT x + 1\n\
+                    FROM lines\n\
+                    WHERE x < 1000\n\
+                )\n\
+                SELECT count(*) FROM lines";
+
+            run_query_single_u64(ctx_q, sql).await
+        });
+
+        // Wait until the query reaches step=1 and is blocked.
+        gate.wait_arrived_at_least(1).await;
+
+        // Deterministic interference: drop and recreate the internal recursive MEMORY table between
+        // step=0 (write prepared blocks) and step=1 (read prepared blocks).
+        //
+        // Important: QueryContext caches tables for consistency within a query. To make the query
+        // observe the recreated table, we also evict it from *the query's* table cache before resuming.
+        let ctx_ddl = fixture.new_query_ctx().await?;
+        ctx_ddl.set_current_database(db.clone()).await?;
+
+        let drop_table_plan = DropTablePlan {
+            if_exists: true,
+            tenant: Tenant {
+                tenant: ctx_ddl.get_tenant().tenant,
+            },
+            catalog: ctx_ddl.get_current_catalog(),
+            database: db.clone(),
+            table: "lines".to_string(),
+            all: true,
+        };
+        let drop_table_interpreter =
+            DropTableInterpreter::try_create(ctx_ddl.clone(), drop_table_plan)?;
+        drop_table_interpreter.execute2().await?;
+
+        let schema = TableSchemaRefExt::create(vec![TableField::new(
+            "x",
+            infer_schema_type(&DataType::Number(NumberDataType::UInt64))?,
+        )]);
+
+        let mut options = BTreeMap::new();
+        options.insert(OPT_KEY_RECURSIVE_CTE.to_string(), "1".to_string());
+
+        let create_table_plan = CreateTablePlan {
+            schema,
+            create_option: CreateOption::Create,
+            tenant: Tenant {
+                tenant: ctx_ddl.get_tenant().tenant,
+            },
+            catalog: ctx_ddl.get_current_catalog(),
+            database: db.clone(),
+            table: "lines".to_string(),
+            engine: Engine::Memory,
+            engine_options: Default::default(),
+            table_properties: Default::default(),
+            table_partition: None,
+            storage_params: None,
+            options,
+            field_comments: vec![],
+            cluster_key: None,
+            as_select: None,
+            table_indexes: None,
+            table_constraints: None,
+            attached_columns: None,
+        };
+        let create_table_interpreter =
+            CreateTableInterpreter::try_create(ctx_ddl.clone(), create_table_plan)?;
+        let _ = create_table_interpreter.execute(ctx_ddl.clone()).await?;
+
+        // Evict in the *query* context so the resumed recursive step re-fetches the table.
+        ctx.evict_table_from_cache(&ctx.get_current_catalog(), &db, "lines")?;
+        // Allow the query to continue with step=1.
+        gate.release(1);
+
+        let got = jh.await.map_err(|e| ErrorCode::Internal(e.to_string()))??;
+
+        // Without interference, this is 1000. If the bug is present, this becomes 1 (only the seed row).
+        if got != 1000 {
+            return Err(ErrorCode::Internal(format!(
+                "deterministic wrong-result repro: expected 1000, got {got}"
+            )));
+        }
+
+        Ok::<(), ErrorCode>(())
+    })?;
+
+    Ok(())
+}