diff --git a/Cargo.lock b/Cargo.lock index c035bca6..67dddff2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1320,7 +1320,7 @@ dependencies = [ [[package]] name = "kite_sql" -version = "0.1.4" +version = "0.1.5" dependencies = [ "ahash 0.8.12", "async-trait", @@ -1349,7 +1349,6 @@ dependencies = [ "ordered-float", "parking_lot", "paste", - "petgraph", "pgwire", "pprof", "recursive", @@ -1765,16 +1764,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "petgraph" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" -dependencies = [ - "fixedbitset", - "indexmap", -] - [[package]] name = "pgwire" version = "0.28.0" diff --git a/Cargo.toml b/Cargo.toml index b195c947..ad53d1a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "kite_sql" -version = "0.1.4" +version = "0.1.5" edition = "2021" authors = ["Kould ", "Xwg "] description = "SQL as a Function for Rust" @@ -48,7 +48,6 @@ itertools = { version = "0.12" } ordered-float = { version = "4", features = ["serde"] } paste = { version = "1" } parking_lot = { version = "0.12", features = ["arc_lock"] } -petgraph = { version = "0.6" } recursive = { version = "0.1" } regex = { version = "1" } rust_decimal = { version = "1" } diff --git a/Makefile b/Makefile index 2a5fe433..e7b10c88 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ test-wasm: ## Run the sqllogictest harness against the configured .slt suite. test-slt: - $(CARGO) run -p sqllogictest-test -- --path $(SQLLOGIC_PATH) + $(CARGO) run -p sqllogictest-test -- --path "$(SQLLOGIC_PATH)" ## Convenience target to run every suite in sequence. test-all: test test-wasm test-slt diff --git a/examples/hello_world.rs b/examples/hello_world.rs index f0de3e76..eef41674 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -12,81 +12,85 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +mod app { + use kite_sql::db::{DataBaseBuilder, ResultIter}; + use kite_sql::errors::DatabaseError; + use kite_sql::implement_from_tuple; + use kite_sql::types::value::DataValue; -use kite_sql::db::{DataBaseBuilder, ResultIter}; -use kite_sql::errors::DatabaseError; -use kite_sql::implement_from_tuple; -use kite_sql::types::value::DataValue; - -#[derive(Default, Debug, PartialEq)] -struct MyStruct { - pub c1: i32, - pub c2: String, -} + #[derive(Default, Debug, PartialEq)] + pub struct MyStruct { + pub c1: i32, + pub c2: String, + } -implement_from_tuple!( - MyStruct, ( - c1: i32 => |inner: &mut MyStruct, value| { - if let DataValue::Int32(val) = value { - inner.c1 = val; + implement_from_tuple!( + MyStruct, ( + c1: i32 => |inner: &mut MyStruct, value| { + if let DataValue::Int32(val) = value { + inner.c1 = val; + } + }, + c2: String => |inner: &mut MyStruct, value| { + if let DataValue::Utf8 { value, .. } = value { + inner.c2 = value; + } } - }, - c2: String => |inner: &mut MyStruct, value| { - if let DataValue::Utf8 { value, .. 
} = value { - inner.c2 = value; - } - } - ) -); + ) + ); -#[cfg(feature = "macros")] -fn main() -> Result<(), DatabaseError> { - let database = DataBaseBuilder::path("./example_data/hello_world").build()?; + pub fn run() -> Result<(), DatabaseError> { + let database = DataBaseBuilder::path("./example_data/hello_world").build()?; - // 1) Create table and insert multiple rows with mixed types. - database - .run( - "create table if not exists my_struct ( - c1 int primary key, - c2 varchar, - c3 int - )", - )? - .done()?; - database - .run( - r#" - insert into my_struct values - (0, 'zero', 0), - (1, 'one', 1), - (2, 'two', 2) - "#, - )? - .done()?; + database + .run( + "create table if not exists my_struct ( + c1 int primary key, + c2 varchar, + c3 int + )", + )? + .done()?; + database + .run( + r#" + insert into my_struct values + (0, 'zero', 0), + (1, 'one', 1), + (2, 'two', 2) + "#, + )? + .done()?; - // 2) Update and delete demo. - database - .run("update my_struct set c3 = c3 + 10 where c1 = 1")? - .done()?; - database.run("delete from my_struct where c1 = 2")?.done()?; + database + .run("update my_struct set c3 = c3 + 10 where c1 = 1")? + .done()?; + database.run("delete from my_struct where c1 = 2")?.done()?; - // 3) Query and deserialize into Rust struct. - let iter = database.run("select * from my_struct")?; - let schema = iter.schema().clone(); + let iter = database.run("select * from my_struct")?; + let schema = iter.schema().clone(); - for tuple in iter { - println!("{:?}", MyStruct::from((&schema, tuple?))); - } + for tuple in iter { + println!("{:?}", MyStruct::from((&schema, tuple?))); + } + + let mut agg = database.run("select count(*) from my_struct")?; + if let Some(count_row) = agg.next() { + println!("row count = {:?}", count_row?); + } + agg.done()?; - // 4) Aggregate example. - let mut agg = database.run("select count(*) from my_struct")?; - if let Some(count_row) = agg.next() { - println!("row count = {:?}", count_row?); + database.run("drop table my_struct")?.done()?; + + Ok(()) } - agg.done()?; +} - database.run("drop table my_struct")?.done()?; +#[cfg(target_arch = "wasm32")] +fn main() {} - Ok(()) +#[cfg(all(not(target_arch = "wasm32"), feature = "macros"))] +fn main() -> Result<(), kite_sql::errors::DatabaseError> { + app::run() } diff --git a/examples/transaction.rs b/examples/transaction.rs index 50d57f28..0fda4c81 100644 --- a/examples/transaction.rs +++ b/examples/transaction.rs @@ -12,55 +12,62 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +mod app { + use kite_sql::db::{DataBaseBuilder, ResultIter}; + use kite_sql::errors::DatabaseError; + use kite_sql::types::tuple::Tuple; + use kite_sql::types::value::DataValue; -use kite_sql::db::{DataBaseBuilder, ResultIter}; -use kite_sql::errors::DatabaseError; -use kite_sql::types::tuple::Tuple; -use kite_sql::types::value::DataValue; + pub fn run() -> Result<(), DatabaseError> { + let database = DataBaseBuilder::path("./example_data/transaction").build_optimistic()?; + database + .run("create table if not exists t1 (c1 int primary key, c2 int)")? + .done()?; + let mut transaction = database.new_transaction()?; -fn main() -> Result<(), DatabaseError> { - let database = DataBaseBuilder::path("./example_data/transaction").build_optimistic()?; - database - .run("create table if not exists t1 (c1 int primary key, c2 int)")? 
- .done()?; - let mut transaction = database.new_transaction()?; + transaction + .run("insert into t1 values(0, 0), (1, 1)")? + .done()?; - transaction - .run("insert into t1 values(0, 0), (1, 1)")? - .done()?; + assert!(database.run("select * from t1")?.next().is_none()); - assert!(database.run("select * from t1")?.next().is_none()); + transaction.commit()?; - transaction.commit()?; + let mut iter = database.run("select * from t1")?; + assert_eq!( + iter.next().unwrap()?, + Tuple::new(None, vec![DataValue::Int32(0), DataValue::Int32(0)]) + ); + assert_eq!( + iter.next().unwrap()?, + Tuple::new(None, vec![DataValue::Int32(1), DataValue::Int32(1)]) + ); + assert!(iter.next().is_none()); - let mut iter = database.run("select * from t1")?; - assert_eq!( - iter.next().unwrap()?, - Tuple::new(None, vec![DataValue::Int32(0), DataValue::Int32(0)]) - ); - assert_eq!( - iter.next().unwrap()?, - Tuple::new(None, vec![DataValue::Int32(1), DataValue::Int32(1)]) - ); - assert!(iter.next().is_none()); + let mut tx2 = database.new_transaction()?; + tx2.run("update t1 set c2 = 99 where c1 = 0")?.done()?; + assert_eq!( + database + .run("select c2 from t1 where c1 = 0")? + .next() + .unwrap()? + .values[0] + .i32(), + Some(0) + ); + drop(tx2); - // Scenario: another transaction updates but does not commit; changes stay invisible. - let mut tx2 = database.new_transaction()?; - tx2.run("update t1 set c2 = 99 where c1 = 0")?.done()?; - assert_eq!( - database - .run("select c2 from t1 where c1 = 0")? - .next() - .unwrap()? - .values[0] - .i32(), - Some(0) - ); - // rollback - drop(tx2); + database.run("drop table t1")?.done()?; + + Ok(()) + } +} - database.run("drop table t1")?.done()?; +#[cfg(target_arch = "wasm32")] +fn main() {} - Ok(()) +#[cfg(not(target_arch = "wasm32"))] +fn main() -> Result<(), kite_sql::errors::DatabaseError> { + app::run() } diff --git a/src/db.rs b/src/db.rs index f0d6e771..657a55e2 100644 --- a/src/db.rs +++ b/src/db.rs @@ -211,11 +211,8 @@ impl State { /// Limit(1) /// Project(a,b) let source_plan = binder.bind(stmt)?; - // println!("source_plan plan: {:#?}", source_plan); - let best_plan = Self::default_optimizer(source_plan) .find_best(Some(&transaction.meta_loader(meta_cache)))?; - // println!("best_plan plan: {:#?}", best_plan); Ok(best_plan) } @@ -356,7 +353,7 @@ impl Database { self.state.prepare(sql) } - fn execute>( + pub fn execute>( &self, statement: &Statement, params: A, diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 6991da32..792cbc86 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -383,6 +383,7 @@ mod test { use super::*; use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::db::{DataBaseBuilder, ResultIter}; use crate::execution::dql::test::build_integers; use crate::execution::{try_collect, ReadExecutor}; use crate::expression::ScalarExpression; @@ -404,6 +405,18 @@ mod test { use std::sync::Arc; use tempfile::TempDir; + fn tuple_to_strings(tuple: &Tuple) -> Vec> { + tuple + .values + .iter() + .map(|value| match value { + DataValue::Null => None, + DataValue::Utf8 { value, .. 
} => Some(value.clone()), + other => Some(other.to_string()), + }) + .collect() + } + fn build_join_values( eq: bool, ) -> ( @@ -1019,4 +1032,52 @@ mod test { Ok(()) } + + #[test] + fn test_right_join_using_preserves_right_side_values() -> Result<(), DatabaseError> { + let temp_dir = TempDir::new().expect("unable to create temporary working directory"); + let db = DataBaseBuilder::path(temp_dir.path()).build_in_memory()?; + + let setup_sql = [ + "DROP TABLE IF EXISTS str1", + "DROP TABLE IF EXISTS str2", + "CREATE TABLE str1 (aid INT PRIMARY KEY, a INT, s VARCHAR)", + "CREATE TABLE str2 (bid INT PRIMARY KEY, a INT, s VARCHAR)", + "INSERT INTO str1 VALUES (0, 1, 'a'), (1, 2, 'A'), (2, 3, 'c'), (3, 4, 'D')", + "INSERT INTO str2 VALUES (0, 1, 'A'), (1, 2, 'B'), (2, 3, 'C'), (3, 4, 'E')", + ]; + + for sql in setup_sql { + db.run(sql)?.done()?; + } + + let mut iter = db.run( + "SELECT s, str1.s, str2.s \ + FROM str1 RIGHT OUTER JOIN str2 USING(s) \ + ORDER BY str2.s", + )?; + let mut actual = Vec::new(); + + while let Some(row) = iter.next() { + let tuple = row?; + actual.push(tuple_to_strings(&tuple)); + } + iter.done()?; + + assert_eq!( + actual, + vec![ + vec![ + Some("A".to_string()), + Some("A".to_string()), + Some("A".to_string()) + ], + vec![None, None, Some("B".to_string())], + vec![None, None, Some("C".to_string())], + vec![None, None, Some("E".to_string())], + ] + ); + + Ok(()) + } } diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs index d95d45e3..f8a19c2b 100644 --- a/src/optimizer/core/memo.rs +++ b/src/optimizer/core/memo.rs @@ -16,10 +16,10 @@ use crate::errors::DatabaseError; use crate::optimizer::core::pattern::PatternMatcher; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::heuristic::matcher::HepMatcher; +use crate::optimizer::heuristic::matcher::PlanMatcher; use crate::optimizer::rule::implementation::ImplementationRuleImpl; use crate::planner::operator::PhysicalOption; +use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use std::cmp::Ordering; use std::collections::HashMap; @@ -42,42 +42,88 @@ impl GroupExpression { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct NodePath(Vec<usize>); + +impl NodePath { + fn root() -> Self { + Self(Vec::new()) + } + + fn child(&self, idx: usize) -> Self { + let mut path = self.0.clone(); + path.push(idx); + Self(path) + } +} + #[derive(Debug)] pub struct Memo { - groups: HashMap<HepNodeId, GroupExpression>, + groups: HashMap<NodePath, GroupExpression>, } impl Memo { pub(crate) fn new<T: Transaction>( - graph: &HepGraph, + plan: &LogicalPlan, loader: &StatisticMetaLoader<'_, T>, implementations: &[ImplementationRuleImpl], ) -> Result<Self, DatabaseError> { - let node_count = graph.node_count(); let mut groups = HashMap::new(); + Self::collect(plan, NodePath::root(), loader, implementations, &mut groups)?; + Ok(Memo { groups }) + } - if node_count == 0 { - return Err(DatabaseError::EmptyPlan); + fn collect<T: Transaction>( + plan: &LogicalPlan, + path: NodePath, + loader: &StatisticMetaLoader<'_, T>, + implementations: &[ImplementationRuleImpl], + groups: &mut HashMap<NodePath, GroupExpression>, + ) -> Result<(), DatabaseError> { + for rule in implementations { + if PlanMatcher::new(rule.pattern(), plan).match_opt_expr() { + let group_expr = groups + .entry(path.clone()) + .or_insert_with(|| GroupExpression { exprs: vec![] }); + rule.to_expression(&plan.operator, loader, group_expr)?; + } } - for node_id in graph.nodes_iter(None) { - for rule
in implementations { - if HepMatcher::new(rule.pattern(), node_id, graph).match_opt_expr() { - let op = graph.operator(node_id); - let group_expr = groups - .entry(node_id) - .or_insert_with(|| GroupExpression { exprs: vec![] }); - - rule.to_expression(op, loader, group_expr)?; - } + match plan.childrens.as_ref() { + Childrens::Only(child) => { + Self::collect(child, path.child(0), loader, implementations, groups)?; } + Childrens::Twins { left, right } => { + Self::collect(left, path.child(0), loader, implementations, groups)?; + Self::collect(right, path.child(1), loader, implementations, groups)?; + } + Childrens::None => {} } - Ok(Memo { groups }) + Ok(()) + } + + pub(crate) fn annotate_plan(&self, plan: &mut LogicalPlan) { + Self::annotate(plan, &NodePath::root(), self); + } + + fn annotate(plan: &mut LogicalPlan, path: &NodePath, memo: &Memo) { + if let Some(option) = memo.cheapest_physical_option(path) { + plan.physical_option = Some(option); + } + + match plan.childrens.as_mut() { + Childrens::Only(child) => Self::annotate(child, &path.child(0), memo), + Childrens::Twins { left, right } => { + Self::annotate(left, &path.child(0), memo); + Self::annotate(right, &path.child(1), memo); + } + Childrens::None => {} + } } - pub(crate) fn cheapest_physical_option(&self, node_id: &HepNodeId) -> Option<PhysicalOption> { - self.groups.get(node_id).and_then(|exprs| { + pub(crate) fn cheapest_physical_option(&self, path: &NodePath) -> Option<PhysicalOption> { + self.groups.get(path).and_then(|exprs| { exprs .exprs .iter() @@ -94,13 +140,13 @@ impl Memo { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { + use super::NodePath; use crate::binder::{Binder, BinderContext}; use crate::db::{DataBaseBuilder, ResultIter}; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::optimizer::core::memo::Memo; use crate::optimizer::heuristic::batch::HepBatchStrategy; - use crate::optimizer::heuristic::graph::HepGraph; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::implementation::ImplementationRuleImpl; use crate::optimizer::rule::normalization::NormalizationRuleImpl; @@ -110,12 +156,12 @@ mod tests { use crate::types::index::{IndexInfo, IndexMeta, IndexType}; use crate::types::value::DataValue; use crate::types::LogicalType; - use petgraph::stable_graph::NodeIndex; use std::ops::Bound; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tempfile::TempDir; + // Note: this test is occasionally flaky; rerun it if it fails.
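+ // It builds a Memo from the optimized plan of a two-table join and checks the implementation options collected for the scan group.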
#[test] fn test_build_memo() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); @@ -161,7 +207,7 @@ mod tests { "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22", )?; let plan = binder.bind(&stmt[0])?; - let best_plan = HepOptimizer::new(plan) + let mut best_plan = HepOptimizer::new(plan) .batch( "Simplify Filter".to_string(), HepBatchStrategy::once_topdown(), @@ -176,7 +222,6 @@ mod tests { ], ) .find_best::(None)?; - let graph = HepGraph::new(best_plan); let rules = vec![ ImplementationRuleImpl::Projection, ImplementationRuleImpl::Filter, @@ -186,12 +231,15 @@ mod tests { ]; let memo = Memo::new( - &graph, + &best_plan, &transaction.meta_loader(database.state.meta_cache()), &rules, )?; - let best_plan = graph.into_plan(Some(&memo)); - let exprs = &memo.groups.get(&NodeIndex::new(3)).unwrap(); + Memo::annotate_plan(&memo, &mut best_plan); + let exprs = memo + .groups + .get(&NodePath(vec![0, 0, 0])) + .expect("missing group"); assert_eq!(exprs.exprs.len(), 2); assert_eq!(exprs.exprs[0].cost, Some(1000)); @@ -200,7 +248,6 @@ mod tests { assert!(matches!(exprs.exprs[1].op, PhysicalOption::IndexScan(_))); assert_eq!( best_plan - .unwrap() .childrens .pop_only() .childrens diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index a4e768f0..5e2d0d60 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -16,8 +16,8 @@ use crate::errors::DatabaseError; use crate::optimizer::core::memo::GroupExpression; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::storage::Transaction; // TODO: Use indexing and other methods for matching optimization to avoid traversal @@ -26,7 +26,8 @@ pub trait MatchPattern { } pub trait NormalizationRule: MatchPattern { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError>; + /// Returns true when the plan tree is modified. + fn apply(&self, plan: &mut LogicalPlan) -> Result; } pub trait ImplementationRule: MatchPattern { diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs deleted file mode 100644 index 8a5a5f53..00000000 --- a/src/optimizer/heuristic/graph.rs +++ /dev/null @@ -1,419 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::optimizer::core::memo::Memo; -use crate::planner::operator::Operator; -use crate::planner::{Childrens, LogicalPlan}; -use fixedbitset::FixedBitSet; -use itertools::Itertools; -use petgraph::stable_graph::{NodeIndex, StableDiGraph}; -use petgraph::visit::{Bfs, EdgeRef}; -use std::mem; - -/// HepNodeId is used in optimizer to identify a node. 
-pub type HepNodeId = NodeIndex; - -#[derive(Debug, Clone)] -pub struct HepGraph { - graph: StableDiGraph, - root_index: HepNodeId, - pub version: usize, -} - -impl HepGraph { - pub fn new(root: LogicalPlan) -> Self { - fn graph_filling( - graph: &mut StableDiGraph, - LogicalPlan { - operator, - childrens, - .. - }: LogicalPlan, - ) -> HepNodeId { - let index = graph.add_node(operator); - - match *childrens { - Childrens::None => (), - Childrens::Only(child) => { - let child_index = graph_filling(graph, *child); - let _ = graph.add_edge(index, child_index, 0); - } - Childrens::Twins { left, right } => { - let child_index = graph_filling(graph, *left); - let _ = graph.add_edge(index, child_index, 0); - let child_index = graph_filling(graph, *right); - let _ = graph.add_edge(index, child_index, 1); - } - } - index - } - - let mut graph = StableDiGraph::::default(); - - let root_index = graph_filling(&mut graph, root); - - HepGraph { - graph, - root_index, - version: 0, - } - } - - pub fn node_count(&self) -> usize { - self.graph.node_count() - } - - pub fn parent_id(&self, node_id: HepNodeId) -> Option { - self.graph - .neighbors_directed(node_id, petgraph::Direction::Incoming) - .next() - } - - #[allow(dead_code)] - pub fn add_root(&mut self, new_node: Operator) { - let old_root_id = mem::replace(&mut self.root_index, self.graph.add_node(new_node)); - - self.graph.add_edge(self.root_index, old_root_id, 0); - self.version += 1; - } - - pub fn add_node( - &mut self, - source_id: HepNodeId, - children_option: Option, - new_node: Operator, - ) { - let new_index = self.graph.add_node(new_node); - let mut order = self.graph.edges(source_id).count(); - - if let Some((children_id, old_edge_id)) = children_option.and_then(|children_id| { - self.graph - .find_edge(source_id, children_id) - .map(|old_edge_id| (children_id, old_edge_id)) - }) { - order = self.graph.remove_edge(old_edge_id).unwrap_or(0); - - self.graph.add_edge(new_index, children_id, 0); - } - - self.graph.add_edge(source_id, new_index, order); - self.version += 1; - } - - pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) { - self.graph[source_id] = new_node; - self.version += 1; - } - - pub fn swap_node(&mut self, a: HepNodeId, b: HepNodeId) { - let tmp = self.graph[a].clone(); - - self.graph[a] = mem::replace(&mut self.graph[b], tmp); - self.version += 1; - } - - pub fn remove_node(&mut self, source_id: HepNodeId, with_childrens: bool) -> Option { - if !with_childrens { - let children_ids = self - .graph - .edges(source_id) - .sorted_by_key(|edge_ref| edge_ref.weight()) - .map(|edge_ref| edge_ref.target()) - .collect_vec(); - - if let Some(parent_id) = self.parent_id(source_id) { - if let Some(edge) = self.graph.find_edge(parent_id, source_id) { - let weight = *self.graph.edge_weight(edge).unwrap_or(&0); - - for (order, children_id) in children_ids.into_iter().enumerate() { - let _ = self.graph.add_edge(parent_id, children_id, weight + order); - } - } - } else { - debug_assert!(children_ids.len() < 2); - self.root_index = children_ids[0]; - } - } - - self.version += 1; - self.graph.remove_node(source_id) - } - - #[allow(dead_code)] - pub fn node(&self, node_id: HepNodeId) -> Option<&Operator> { - self.graph.node_weight(node_id) - } - - pub fn operator(&self, node_id: HepNodeId) -> &Operator { - &self.graph[node_id] - } - - pub fn operator_mut(&mut self, node_id: HepNodeId) -> &mut Operator { - &mut self.graph[node_id] - } - - /// If input node is join, we use the edge weight to control the join children order. 
- pub fn children_at(&self, id: HepNodeId) -> Box + '_> { - Box::new( - self.graph - .edges(id) - .sorted_by_key(|edge| edge.weight()) - .map(|edge| edge.target()), - ) - } - - pub fn eldest_child_at(&self, id: HepNodeId) -> Option { - self.graph - .edges(id) - .min_by_key(|edge| edge.weight()) - .map(|edge| edge.target()) - } - - pub fn youngest_child_at(&self, id: HepNodeId) -> Option { - self.graph - .edges(id) - .max_by_key(|edge| edge.weight()) - .map(|edge| edge.target()) - } - - pub fn into_plan(mut self, memo: Option<&Memo>) -> Option { - self.build_childrens(self.root_index, memo) - } - - /// Use bfs to traverse the graph and return node ids - pub fn nodes_iter(&self, start_option: Option) -> HepGraphIter<'_> { - let inner = Bfs::new(&self.graph, start_option.unwrap_or(self.root_index)); - HepGraphIter { - inner, - graph: &self.graph, - } - } - - fn build_childrens(&mut self, start: HepNodeId, memo: Option<&Memo>) -> Option { - let physical_option = memo.and_then(|memo| memo.cheapest_physical_option(&start)); - - let mut iter = self.children_at(start); - - let child_0 = iter.next(); - let child_1 = iter.next(); - drop(iter); - - let child_0 = child_0.and_then(|id| self.build_childrens(id, memo)); - let child_1 = child_1.and_then(|id| self.build_childrens(id, memo)); - - let childrens = match (child_0, child_1) { - (Some(child_0), Some(child_1)) => Childrens::Twins { - left: Box::new(child_0), - right: Box::new(child_1), - }, - (Some(child), None) | (None, Some(child)) => Childrens::Only(Box::new(child)), - (None, None) => Childrens::None, - }; - - self.graph.remove_node(start).map(|operator| LogicalPlan { - operator, - childrens: Box::new(childrens), - physical_option, - _output_schema_ref: None, - }) - } -} - -pub struct HepGraphIter<'a> { - inner: Bfs, - graph: &'a StableDiGraph, -} - -impl Iterator for HepGraphIter<'_> { - type Item = HepNodeId; - - fn next(&mut self) -> Option { - self.inner.next(self.graph) - } -} - -#[cfg(all(test, not(target_arch = "wasm32")))] -mod tests { - use crate::binder::test::build_t1_table; - use crate::errors::DatabaseError; - use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; - use crate::planner::operator::Operator; - use crate::planner::{Childrens, LogicalPlan}; - use petgraph::stable_graph::{EdgeIndex, NodeIndex}; - - #[test] - fn test_graph_for_plan() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let graph = HepGraph::new(plan); - - assert!(graph - .graph - .contains_edge(NodeIndex::new(1), NodeIndex::new(2))); - assert!(graph - .graph - .contains_edge(NodeIndex::new(1), NodeIndex::new(3))); - assert!(graph - .graph - .contains_edge(NodeIndex::new(0), NodeIndex::new(1))); - - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(0)), Some(&0)); - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(1)), Some(&1)); - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(2)), Some(&0)); - - Ok(()) - } - - #[test] - fn test_graph_add_node() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let mut graph = HepGraph::new(plan); - - graph.add_node(HepNodeId::new(1), None, Operator::Dummy); - - graph.add_node(HepNodeId::new(1), Some(HepNodeId::new(4)), Operator::Dummy); - - graph.add_node(HepNodeId::new(5), None, Operator::Dummy); - - assert!(graph - .graph - .contains_edge(NodeIndex::new(5), NodeIndex::new(4))); - assert!(graph - .graph - 
.contains_edge(NodeIndex::new(1), NodeIndex::new(5))); - assert!(graph - .graph - .contains_edge(NodeIndex::new(5), NodeIndex::new(6))); - - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(3)), Some(&0)); - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(4)), Some(&2)); - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(5)), Some(&1)); - - Ok(()) - } - - #[test] - fn test_graph_replace_node() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let mut graph = HepGraph::new(plan); - - graph.replace_node(HepNodeId::new(1), Operator::Dummy); - - assert!(matches!(graph.operator(HepNodeId::new(1)), Operator::Dummy)); - - Ok(()) - } - - #[test] - fn test_graph_remove_middle_node_by_single() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let mut graph = HepGraph::new(plan); - - graph.remove_node(HepNodeId::new(1), false); - - assert_eq!(graph.graph.edge_count(), 2); - - assert!(graph - .graph - .contains_edge(NodeIndex::new(0), NodeIndex::new(2))); - assert!(graph - .graph - .contains_edge(NodeIndex::new(0), NodeIndex::new(3))); - - Ok(()) - } - - #[test] - fn test_graph_remove_middle_node_with_childrens() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let mut graph = HepGraph::new(plan); - - graph.remove_node(HepNodeId::new(1), true); - - assert_eq!(graph.graph.edge_count(), 0); - - Ok(()) - } - - #[test] - fn test_graph_swap_node() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let mut graph = HepGraph::new(plan); - - let before_op_0 = graph.operator(HepNodeId::new(0)).clone(); - let before_op_1 = graph.operator(HepNodeId::new(1)).clone(); - - graph.swap_node(HepNodeId::new(0), HepNodeId::new(1)); - - let op_0 = graph.operator(HepNodeId::new(0)); - let op_1 = graph.operator(HepNodeId::new(1)); - - assert_eq!(op_0, &before_op_1); - assert_eq!(op_1, &before_op_0); - - Ok(()) - } - - #[test] - fn test_graph_add_root() -> Result<(), DatabaseError> { - let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let mut graph = HepGraph::new(plan); - - graph.add_root(Operator::Dummy); - - assert_eq!(graph.graph.edge_count(), 4); - assert!(graph - .graph - .contains_edge(NodeIndex::new(4), NodeIndex::new(0))); - assert_eq!(graph.graph.edge_weight(EdgeIndex::new(3)), Some(&0)); - - Ok(()) - } - - #[test] - fn test_graph_to_plan() -> Result<(), DatabaseError> { - fn clear_output_schema_buf(plan: &mut LogicalPlan) { - plan._output_schema_ref = None; - - match plan.childrens.as_mut() { - Childrens::Only(child) => { - clear_output_schema_buf(child); - } - Childrens::Twins { left, right } => { - clear_output_schema_buf(left); - clear_output_schema_buf(right); - } - Childrens::None => (), - } - } - - let table_state = build_t1_table()?; - let mut plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - clear_output_schema_buf(&mut plan); - - let graph = HepGraph::new(plan.clone()); - - let plan_for_graph = graph.into_plan(None).unwrap(); - - assert_eq!(plan, plan_for_graph); - - Ok(()) - } -} diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 4d5abd54..6ef36ab7 100644 --- 
a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -13,52 +13,47 @@ // limitations under the License. use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate, PatternMatcher}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::planner::LogicalPlan; -/// Use pattern to determines which rule can be applied -pub struct HepMatcher<'a, 'b> { +/// Use pattern to determine which rule can be applied on a [`LogicalPlan`] subtree. +pub struct PlanMatcher<'a> { pattern: &'a Pattern, - start_id: HepNodeId, - graph: &'b HepGraph, + plan: &'a LogicalPlan, } -impl<'a, 'b> HepMatcher<'a, 'b> { - pub(crate) fn new(pattern: &'a Pattern, start_id: HepNodeId, graph: &'b HepGraph) -> Self { - Self { - pattern, - start_id, - graph, - } +impl<'a> PlanMatcher<'a> { + pub(crate) fn new(pattern: &'a Pattern, plan: &'a LogicalPlan) -> Self { + Self { pattern, plan } } } -impl PatternMatcher for HepMatcher<'_, '_> { +impl PatternMatcher for PlanMatcher<'_> { fn match_opt_expr(&self) -> bool { - let op = self.graph.operator(self.start_id); - // check the root node predicate - if !(self.pattern.predicate)(op) { + if !(self.pattern.predicate)(&self.plan.operator) { return false; } match &self.pattern.children { PatternChildrenPredicate::Recursive => { - // check - for node_id in self.graph.nodes_iter(Some(self.start_id)) { - if !(self.pattern.predicate)(self.graph.operator(node_id)) { + for child in self.plan.childrens.iter() { + if !(self.pattern.predicate)(&child.operator) { + return false; + } + if !PlanMatcher::new(self.pattern, child).match_opt_expr() { return false; } } } PatternChildrenPredicate::Predicate(patterns) => { - for node_id in self.graph.children_at(self.start_id) { + for child in self.plan.childrens.iter() { for pattern in patterns { - if !HepMatcher::new(pattern, node_id, self.graph).match_opt_expr() { + if !PlanMatcher::new(pattern, child).match_opt_expr() { return false; } } } } - PatternChildrenPredicate::None => (), + PatternChildrenPredicate::None => {} } true @@ -67,11 +62,10 @@ impl PatternMatcher for HepMatcher<'_, '_> { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { + use super::*; use crate::binder::test::build_t1_table; use crate::errors::DatabaseError; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate, PatternMatcher}; - use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; - use crate::optimizer::heuristic::matcher::HepMatcher; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; @@ -79,26 +73,16 @@ mod tests { fn test_predicate() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; let plan = table_state.plan("select * from t1")?; - let graph = HepGraph::new(plan.clone()); let project_into_table_scan_pattern = Pattern { - predicate: |p| match p { - Operator::Project(_) => true, - _ => false, - }, + predicate: |p| matches!(p, Operator::Project(_)), children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |p| match p { - Operator::TableScan(_) => true, - _ => false, - }, + predicate: |p| matches!(p, Operator::TableScan(_)), children: PatternChildrenPredicate::None, }]), }; - assert!( - HepMatcher::new(&project_into_table_scan_pattern, HepNodeId::new(0), &graph) - .match_opt_expr() - ); + assert!(PlanMatcher::new(&project_into_table_scan_pattern, &plan).match_opt_expr()); Ok(()) } @@ -129,16 +113,12 @@ mod tests { physical_option: None, _output_schema_ref: None, }; - let graph = 
HepGraph::new(all_dummy_plan.clone()); let only_dummy_pattern = Pattern { - predicate: |p| match p { - Operator::Dummy => true, - _ => false, - }, + predicate: |p| matches!(p, Operator::Dummy), children: PatternChildrenPredicate::Recursive, }; - assert!(HepMatcher::new(&only_dummy_pattern, HepNodeId::new(0), &graph).match_opt_expr()); + assert!(PlanMatcher::new(&only_dummy_pattern, &all_dummy_plan).match_opt_expr()); } } diff --git a/src/optimizer/heuristic/mod.rs b/src/optimizer/heuristic/mod.rs index 311098b6..491d1353 100644 --- a/src/optimizer/heuristic/mod.rs +++ b/src/optimizer/heuristic/mod.rs @@ -13,6 +13,5 @@ // limitations under the License. pub(crate) mod batch; -pub(crate) mod graph; pub(crate) mod matcher; pub mod optimizer; diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index 4b978ca1..037624aa 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -18,17 +18,16 @@ use crate::optimizer::core::pattern::PatternMatcher; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::optimizer::heuristic::batch::{HepBatch, HepBatchStrategy}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::heuristic::matcher::HepMatcher; +use crate::optimizer::heuristic::matcher::PlanMatcher; use crate::optimizer::rule::implementation::ImplementationRuleImpl; use crate::optimizer::rule::normalization::NormalizationRuleImpl; -use crate::planner::LogicalPlan; +use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use std::ops::Not; pub struct HepOptimizer { batches: Vec, - pub graph: HepGraph, + plan: LogicalPlan, implementations: Vec, } @@ -36,7 +35,7 @@ impl HepOptimizer { pub fn new(root: LogicalPlan) -> Self { Self { batches: vec![], - graph: HepGraph::new(root), + plan: root, implementations: vec![], } } @@ -60,63 +59,70 @@ impl HepOptimizer { mut self, loader: Option<&StatisticMetaLoader<'_, T>>, ) -> Result { - for ref batch in self.batches { + for batch in &self.batches { match batch.strategy { HepBatchStrategy::MaxTimes(max_iteration) => { for _ in 0..max_iteration { - if !Self::apply_batch(&mut self.graph, batch)? { + if !Self::apply_batch(&mut self.plan, batch)? { break; } } } HepBatchStrategy::LoopIfApplied => { - while Self::apply_batch(&mut self.graph, batch)? {} + while Self::apply_batch(&mut self.plan, batch)? {} } } } - let memo = loader - .and_then(|loader| { - self.implementations - .is_empty() - .not() - .then(|| Memo::new(&self.graph, loader, &self.implementations)) - }) - .transpose()?; - self.graph - .into_plan(memo.as_ref()) - .ok_or(DatabaseError::EmptyPlan) - } + if let Some(loader) = loader { + if self.implementations.is_empty().not() { + let memo = Memo::new(&self.plan, loader, &self.implementations)?; + Memo::annotate_plan(&memo, &mut self.plan); + } + } - fn apply_batch( - graph: *mut HepGraph, - HepBatch { rules, .. }: &HepBatch, - ) -> Result { - let before_version = unsafe { &*graph }.version; + Ok(self.plan) + } - for rule in rules { - // SAFETY: after successfully modifying the graph, the iterator is no longer used. - for node_id in unsafe { &*graph }.nodes_iter(None) { - if Self::apply_rule(unsafe { &mut *graph }, rule, node_id)? { - break; - } + fn apply_batch(plan: &mut LogicalPlan, batch: &HepBatch) -> Result { + let mut applied = false; + for rule in &batch.rules { + if Self::apply_rule(plan, rule)? 
{ + applied = true; } } - - Ok(before_version != unsafe { &*graph }.version) + Ok(applied) } fn apply_rule( - graph: &mut HepGraph, + plan: &mut LogicalPlan, rule: &NormalizationRuleImpl, - node_id: HepNodeId, ) -> Result { - let before_version = graph.version; - - if HepMatcher::new(rule.pattern(), node_id, graph).match_opt_expr() { - rule.apply(node_id, graph)?; + if PlanMatcher::new(rule.pattern(), plan).match_opt_expr() && rule.apply(plan)? { + plan.reset_output_schema_cache_recursive(); + return Ok(true); } - Ok(before_version != graph.version) + match plan.childrens.as_mut() { + Childrens::Only(child) => { + if Self::apply_rule(child, rule)? { + plan.reset_output_schema_cache(); + return Ok(true); + } + Ok(false) + } + Childrens::Twins { left, right } => { + if Self::apply_rule(left, rule)? { + plan.reset_output_schema_cache(); + return Ok(true); + } + if Self::apply_rule(right, rule)? { + plan.reset_output_schema_cache(); + return Ok(true); + } + Ok(false) + } + Childrens::None => Ok(false), + } } } diff --git a/src/optimizer/mod.rs b/src/optimizer/mod.rs index 225c5884..9360f9d6 100644 --- a/src/optimizer/mod.rs +++ b/src/optimizer/mod.rs @@ -16,4 +16,5 @@ /// such as (/core) are referenced from sqlrs pub mod core; pub mod heuristic; +pub mod plan_utils; pub mod rule; diff --git a/src/optimizer/plan_utils.rs b/src/optimizer/plan_utils.rs new file mode 100644 index 00000000..8e16739e --- /dev/null +++ b/src/optimizer/plan_utils.rs @@ -0,0 +1,132 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::planner::operator::Operator; +use crate::planner::{Childrens, LogicalPlan}; +use std::mem; + +#[allow(dead_code)] +pub fn child_count(plan: &LogicalPlan) -> usize { + match plan.childrens.as_ref() { + Childrens::None => 0, + Childrens::Only(_) => 1, + Childrens::Twins { .. } => 2, + } +} + +#[allow(dead_code)] +pub fn only_child(plan: &LogicalPlan) -> Option<&LogicalPlan> { + match plan.childrens.as_ref() { + Childrens::Only(child) => Some(child.as_ref()), + _ => None, + } +} + +pub fn only_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { + match plan.childrens.as_mut() { + Childrens::Only(child) => Some(child.as_mut()), + _ => None, + } +} + +pub fn left_child(plan: &LogicalPlan) -> Option<&LogicalPlan> { + match plan.childrens.as_ref() { + Childrens::Only(child) => Some(child.as_ref()), + Childrens::Twins { left, .. } => Some(left.as_ref()), + Childrens::None => None, + } +} + +#[allow(dead_code)] +pub fn left_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { + match plan.childrens.as_mut() { + Childrens::Only(child) => Some(child.as_mut()), + Childrens::Twins { left, .. } => Some(left.as_mut()), + Childrens::None => None, + } +} + +pub fn right_child(plan: &LogicalPlan) -> Option<&LogicalPlan> { + match plan.childrens.as_ref() { + Childrens::Twins { right, .. 
} => Some(right.as_ref()), + _ => None, + } +} + +#[allow(dead_code)] +pub fn right_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { + match plan.childrens.as_mut() { + Childrens::Twins { right, .. } => Some(right.as_mut()), + _ => None, + } +} + +#[allow(dead_code)] +pub fn child(plan: &LogicalPlan, idx: usize) -> Option<&LogicalPlan> { + match (plan.childrens.as_ref(), idx) { + (Childrens::Only(child), 0) => Some(child.as_ref()), + (Childrens::Twins { left, .. }, 0) => Some(left.as_ref()), + (Childrens::Twins { right, .. }, 1) => Some(right.as_ref()), + _ => None, + } +} + +pub fn child_mut(plan: &mut LogicalPlan, idx: usize) -> Option<&mut LogicalPlan> { + match (plan.childrens.as_mut(), idx) { + (Childrens::Only(child), 0) => Some(child.as_mut()), + (Childrens::Twins { left, .. }, 0) => Some(left.as_mut()), + (Childrens::Twins { right, .. }, 1) => Some(right.as_mut()), + _ => None, + } +} + +#[allow(dead_code)] +pub fn children(plan: &LogicalPlan) -> Vec<&LogicalPlan> { + match plan.childrens.as_ref() { + Childrens::None => vec![], + Childrens::Only(child) => vec![child.as_ref()], + Childrens::Twins { left, right } => vec![left.as_ref(), right.as_ref()], + } +} + +pub fn replace_with_only_child(plan: &mut LogicalPlan) -> bool { + if let Childrens::Only(child) = take_childrens(plan) { + *plan = *child; + true + } else { + false + } +} + +#[allow(dead_code)] +pub fn replace_child_with_only_child(plan: &mut LogicalPlan, child_idx: usize) -> bool { + if let Some(child_plan) = child_mut(plan, child_idx) { + return replace_with_only_child(child_plan); + } + false +} + +pub fn wrap_child_with(plan: &mut LogicalPlan, child_idx: usize, operator: Operator) -> bool { + if let Some(slot) = child_mut(plan, child_idx) { + let previous = mem::replace(slot, LogicalPlan::new(operator, Childrens::None)); + slot.childrens = Box::new(Childrens::Only(Box::new(previous))); + true + } else { + false + } +} + +fn take_childrens(plan: &mut LogicalPlan) -> Childrens { + mem::replace(&mut *plan.childrens, Childrens::None) +} diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index 803f974c..f3403714 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -19,11 +19,10 @@ use crate::expression::visitor::Visitor; use crate::expression::{HasCountStar, ScalarExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::planner::operator::Operator; +use crate::planner::{Childrens, LogicalPlan}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::LogicalType; -use itertools::Itertools; use sqlparser::ast::CharLengthUnits; use std::collections::HashSet; use std::sync::LazyLock; @@ -61,15 +60,19 @@ impl ColumnPruning { fn _apply( column_references: HashSet<&ColumnSummary>, all_referenced: bool, - node_id: HepNodeId, - graph: &mut HepGraph, - ) -> Result<(), DatabaseError> { - let operator = graph.operator_mut(node_id); + plan: &mut LogicalPlan, + ) -> Result { + let mut changed = false; + let operator = &mut plan.operator; match operator { Operator::Aggregate(op) => { if !all_referenced { + let before = op.agg_calls.len(); Self::clear_exprs(&column_references, &mut op.agg_calls); + if op.agg_calls.len() != before { + changed = true; + } if op.agg_calls.is_empty() && op.groupby_exprs.is_empty() 
{ let value = DataValue::Utf8 { @@ -84,7 +87,8 @@ impl ColumnPruning { kind: AggKind::Count, args: vec![ScalarExpression::Constant(value)], ty: LogicalType::Integer, - }) + }); + changed = true; } } let is_distinct = op.is_distinct; @@ -97,7 +101,7 @@ impl ColumnPruning { } } - Self::recollect_apply(new_column_references, false, node_id, graph)?; + changed |= Self::recollect_apply(new_column_references, false, plan)?; } Operator::Project(op) => { let mut has_count_star = HasCountStar::default(); @@ -106,18 +110,26 @@ impl ColumnPruning { } if !has_count_star.value { if !all_referenced { + let before = op.exprs.len(); Self::clear_exprs(&column_references, &mut op.exprs); + if op.exprs.len() != before { + changed = true; + } } let referenced_columns = operator.referenced_columns(false); let new_column_references = trans_references!(&referenced_columns); - Self::recollect_apply(new_column_references, false, node_id, graph)?; + changed |= Self::recollect_apply(new_column_references, false, plan)?; } } Operator::TableScan(op) => { if !all_referenced { + let before = op.columns.len(); op.columns .retain(|_, column| column_references.contains(column.summary())); + if op.columns.len() != before { + changed = true; + } } } Operator::Sort(_) @@ -128,25 +140,17 @@ impl ColumnPruning { | Operator::Except(_) | Operator::TopK(_) => { let temp_columns = operator.referenced_columns(false); - // why? + // this is magic!!! do not delete!!! let mut column_references = column_references; for column in temp_columns.iter() { column_references.insert(column.summary()); } - for child_id in graph.children_at(node_id).collect_vec() { - let copy_references = column_references.clone(); - - Self::_apply(copy_references, all_referenced, child_id, graph)?; - } + changed |= Self::recollect_apply(column_references.clone(), all_referenced, plan)?; } // Last Operator Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => (), Operator::Explain => { - if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::_apply(column_references, true, child_id, graph)?; - } else { - unreachable!() - } + changed |= Self::recollect_apply(column_references, true, plan)?; } // DDL Based on Other Plan Operator::Insert(_) @@ -156,11 +160,7 @@ impl ColumnPruning { let referenced_columns = operator.referenced_columns(false); let new_column_references = trans_references!(&referenced_columns); - if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::recollect_apply(new_column_references, true, child_id, graph)?; - } else { - unreachable!(); - } + changed |= Self::recollect_apply(new_column_references, true, plan)?; } // DDL Single Plan Operator::CreateTable(_) @@ -179,21 +179,35 @@ impl ColumnPruning { | Operator::Describe(_) => (), } - Ok(()) + Ok(changed) } fn recollect_apply( referenced_columns: HashSet<&ColumnSummary>, all_referenced: bool, - node_id: HepNodeId, - graph: &mut HepGraph, - ) -> Result<(), DatabaseError> { - for child_id in graph.children_at(node_id).collect_vec() { - let copy_references: HashSet<&ColumnSummary> = referenced_columns.clone(); + plan: &mut LogicalPlan, + ) -> Result { + Self::for_each_child(plan, |child| { + Self::_apply(referenced_columns.clone(), all_referenced, child) + }) + } - Self::_apply(copy_references, all_referenced, child_id, graph)?; + fn for_each_child( + plan: &mut LogicalPlan, + mut f: impl FnMut(&mut LogicalPlan) -> Result, + ) -> Result { + let mut changed = false; + match plan.childrens.as_mut() { + Childrens::Only(child) => { + changed |= f(child.as_mut())?; + } + 
Childrens::Twins { left, right } => { + changed |= f(left.as_mut())?; + changed |= f(right.as_mut())?; + } + Childrens::None => (), } - Ok(()) + Ok(changed) } } @@ -204,12 +218,8 @@ impl MatchPattern for ColumnPruning { } impl NormalizationRule for ColumnPruning { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - Self::_apply(HashSet::new(), true, node_id, graph)?; - // mark changed to skip this rule batch - graph.version += 1; - - Ok(()) + fn apply(&self, plan: &mut LogicalPlan) -> Result { + Self::_apply(HashSet::new(), true, plan) } } diff --git a/src/optimizer/rule/normalization/combine_operators.rs b/src/optimizer/rule/normalization/combine_operators.rs index d1f012c6..95ca6d01 100644 --- a/src/optimizer/rule/normalization/combine_operators.rs +++ b/src/optimizer/rule/normalization/combine_operators.rs @@ -16,9 +16,11 @@ use crate::errors::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; -use crate::optimizer::rule::normalization::is_subset_exprs; +use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child}; +use crate::optimizer::rule::normalization::{is_subset_exprs, strip_alias}; +use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::types::LogicalType; use std::collections::HashSet; use std::sync::LazyLock; @@ -34,7 +36,7 @@ static COLLAPSE_PROJECT_RULE: LazyLock = LazyLock::new(|| Pattern { static COMBINE_FILTERS_RULE: LazyLock = LazyLock::new(|| Pattern { predicate: |op| matches!(op, Operator::Filter(_)), children: PatternChildrenPredicate::Predicate(vec![Pattern { - predicate: |op| matches!(op, Operator::Filter(_)), + predicate: |op| matches!(op, Operator::Filter(_)) || is_passthrough_project_operator(op), children: PatternChildrenPredicate::None, }]), }); @@ -53,6 +55,16 @@ static COLLAPSE_GROUP_BY_AGG: LazyLock = LazyLock::new(|| Pattern { }]), }); +fn is_passthrough_project(op: &ProjectOperator) -> bool { + op.exprs + .iter() + .all(|expr| matches!(strip_alias(expr), ScalarExpression::ColumnRef { .. })) +} + +fn is_passthrough_project_operator(op: &Operator) -> bool { + matches!(op, Operator::Project(project_op) if is_passthrough_project(project_op)) +} + /// Combine two adjacent project operators into one. 
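+/// e.g. `Project[c1]` above `Project[c1, c2]` collapses into a single `Project[c1]` when the inner projection only forwards column references.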
pub struct CollapseProject; @@ -63,18 +75,26 @@ impl MatchPattern for CollapseProject { } impl NormalizationRule for CollapseProject { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Operator::Project(op) = graph.operator(node_id) { - if let Some(child_id) = graph.eldest_child_at(node_id) { - if let Operator::Project(child_op) = graph.operator(child_id) { - if is_subset_exprs(&op.exprs, &child_op.exprs) { - graph.remove_node(child_id, false); - } + fn apply(&self, plan: &mut LogicalPlan) -> Result { + let parent_exprs = match &plan.operator { + Operator::Project(op) => op.exprs.clone(), + _ => return Ok(false), + }; + + let mut removed = false; + while let Some(child) = only_child_mut(plan) { + match &child.operator { + Operator::Project(child_op) + if is_passthrough_project(child_op) + && is_subset_exprs(&parent_exprs, &child_op.exprs) => + { + removed |= replace_with_only_child(child); } + _ => break, } } - Ok(()) + Ok(removed) } } @@ -88,25 +108,40 @@ impl MatchPattern for CombineFilter { } impl NormalizationRule for CombineFilter { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Operator::Filter(op) = graph.operator(node_id).clone() { - if let Some(child_id) = graph.eldest_child_at(node_id) { - if let Operator::Filter(child_op) = graph.operator_mut(child_id) { + fn apply(&self, plan: &mut LogicalPlan) -> Result { + let (parent_predicate, parent_having) = match &plan.operator { + Operator::Filter(op) => (op.predicate.clone(), op.having), + _ => return Ok(false), + }; + + let cursor = match only_child_mut(plan) { + Some(child) => child, + None => return Ok(false), + }; + + loop { + match &mut cursor.operator { + Operator::Filter(child_op) => { child_op.predicate = ScalarExpression::Binary { op: BinaryOperator::And, - left_expr: Box::new(op.predicate), + left_expr: Box::new(parent_predicate), right_expr: Box::new(child_op.predicate.clone()), evaluator: None, ty: LogicalType::Boolean, }; - child_op.having = op.having || child_op.having; + child_op.having = parent_having || child_op.having; - graph.remove_node(node_id, false); + return Ok(replace_with_only_child(plan)); } + Operator::Project(project_op) if is_passthrough_project(project_op) => { + if replace_with_only_child(cursor) { + continue; + } + return Ok(false); + } + _ => return Ok(false), } } - - Ok(()) } } @@ -119,35 +154,33 @@ impl MatchPattern for CollapseGroupByAgg { } impl NormalizationRule for CollapseGroupByAgg { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Operator::Aggregate(op) = graph.operator(node_id).clone() { - // if it is an aggregation operator containing agg_call + fn apply(&self, plan: &mut LogicalPlan) -> Result { + if let Operator::Aggregate(op) = plan.operator.clone() { if !op.agg_calls.is_empty() { - return Ok(()); + return Ok(false); } - if let Some(Operator::Aggregate(child_op)) = graph - .eldest_child_at(node_id) - .map(|child_id| graph.operator_mut(child_id)) - { - if op.groupby_exprs.len() != child_op.groupby_exprs.len() { - return Ok(()); - } - let mut expr_set = HashSet::new(); + if let Some(child) = only_child_mut(plan) { + if let Operator::Aggregate(child_op) = child.operator.clone() { + if op.groupby_exprs.len() != child_op.groupby_exprs.len() { + return Ok(false); + } + let mut expr_set = HashSet::new(); - for expr in op.groupby_exprs.iter() { - expr_set.insert(expr); - } - for expr in child_op.groupby_exprs.iter() { - expr_set.remove(expr); - 
} - if expr_set.is_empty() { - graph.remove_node(node_id, false); + for expr in op.groupby_exprs.iter() { + expr_set.insert(expr); + } + for expr in child_op.groupby_exprs.iter() { + expr_set.remove(expr); + } + if expr_set.is_empty() { + return Ok(replace_with_only_child(plan)); + } } } } - Ok(()) + Ok(false) } } @@ -155,39 +188,24 @@ impl NormalizationRule for CollapseGroupByAgg { mod tests { use crate::binder::test::build_t1_table; use crate::errors::DatabaseError; - use crate::expression::ScalarExpression::Constant; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::heuristic::batch::HepBatchStrategy; - use crate::optimizer::heuristic::graph::HepNodeId; use crate::optimizer::heuristic::optimizer::HepOptimizer; use crate::optimizer::rule::normalization::NormalizationRuleImpl; use crate::planner::operator::Operator; use crate::planner::Childrens; use crate::storage::rocksdb::RocksTransaction; - use crate::types::value::DataValue; - use crate::types::LogicalType; #[test] fn test_collapse_project() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select c1, c2 from t1")?; + let plan = table_state.plan("select c1 from (select c1, c2 from t1) t")?; - let mut optimizer = HepOptimizer::new(plan.clone()).batch( + let optimizer = HepOptimizer::new(plan).batch( "test_collapse_project".to_string(), HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::CollapseProject], ); - - let mut new_project_op = optimizer.graph.operator(HepNodeId::new(0)).clone(); - - if let Operator::Project(op) = &mut new_project_op { - op.exprs.remove(0); - } else { - unreachable!("Should be a project operator") - } - - optimizer.graph.add_root(new_project_op); - let best_plan = optimizer.find_best::(None)?; if let Operator::Project(op) = &best_plan.operator { @@ -207,34 +225,49 @@ mod tests { } #[test] - fn test_combine_filter() -> Result<(), DatabaseError> { + fn test_collapse_project_with_alias() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 where c1 > 1")?; - - let mut optimizer = HepOptimizer::new(plan.clone()).batch( - "test_combine_filter".to_string(), + let plan = table_state.plan("select t.x from (select c1 as x from t1) t")?; + let original = plan.clone(); + let original_child = original.childrens.pop_only(); + assert!(matches!(original_child.operator, Operator::Project(_))); + let original_grandchild = original_child.childrens.pop_only(); + assert!(matches!(original_grandchild.operator, Operator::Project(_))); + + let optimizer = HepOptimizer::new(plan).batch( + "test_collapse_project_with_alias".to_string(), HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::CombineFilter], + vec![NormalizationRuleImpl::CollapseProject], ); - - let mut new_filter_op = optimizer.graph.operator(HepNodeId::new(1)).clone(); - - if let Operator::Filter(op) = &mut new_filter_op { - op.predicate = ScalarExpression::Binary { - op: BinaryOperator::Eq, - left_expr: Box::new(Constant(DataValue::Int8(1))), - right_expr: Box::new(Constant(DataValue::Int8(1))), - evaluator: None, - ty: LogicalType::Boolean, - } + let best_plan = optimizer.find_best::(None)?; + if let Operator::Project(op) = &best_plan.operator { + assert_eq!(op.exprs.len(), 1); } else { - unreachable!("Should be a filter operator") + unreachable!("Should be a project operator"); } - optimizer - .graph - .add_node(HepNodeId::new(0), Some(HepNodeId::new(1)), new_filter_op); + let scan_op = 
best_plan.childrens.pop_only(); + assert!(matches!(scan_op.operator, Operator::Project(_))); + let scan_child = scan_op.childrens.pop_only(); + assert!( + !matches!(scan_child.operator, Operator::Project(_)), + "Child project should be collapsed" + ); + + Ok(()) + } + + #[test] + fn test_combine_filter() -> Result<(), DatabaseError> { + let table_state = build_t1_table()?; + let plan = + table_state.plan("select * from (select * from t1 where c1 > 1) t where 1 = 1")?; + let optimizer = HepOptimizer::new(plan).batch( + "test_combine_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::CombineFilter], + ); let best_plan = optimizer.find_best::<RocksTransaction>(None)?; let filter_op = best_plan.childrens.pop_only(); diff --git a/src/optimizer/rule/normalization/compilation_in_advance.rs b/src/optimizer/rule/normalization/compilation_in_advance.rs index 93950112..63585445 100644 --- a/src/optimizer/rule/normalization/compilation_in_advance.rs +++ b/src/optimizer/rule/normalization/compilation_in_advance.rs @@ -17,9 +17,9 @@ use crate::expression::visitor_mut::VisitorMut; use crate::expression::{BindEvaluator, BindPosition, ScalarExpression}; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; +use crate::planner::{Childrens, LogicalPlan}; use std::borrow::Cow; use std::sync::LazyLock; @@ -39,23 +39,26 @@ pub struct BindExpressionPosition; impl BindExpressionPosition { fn _apply( output_exprs: &mut Vec<ScalarExpression>, - node_id: HepNodeId, - graph: &mut HepGraph, + plan: &mut LogicalPlan, ) -> Result<(), DatabaseError> { - if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::_apply(output_exprs, child_id, graph)?; - } - // for join let mut left_len = 0; - if let Operator::Join(_) | Operator::Union(_) | Operator::Except(_) = - graph.operator(node_id) - { - let mut second_output_exprs = Vec::new(); - if let Some(child_id) = graph.youngest_child_at(node_id) { - Self::_apply(&mut second_output_exprs, child_id, graph)?; + match plan.childrens.as_mut() { + Childrens::Only(child) => { + Self::_apply(output_exprs, child)?; + } + Childrens::Twins { left, right } => { + Self::_apply(output_exprs, left)?; + if matches!( + plan.operator, + Operator::Join(_) | Operator::Union(_) | Operator::Except(_) + ) { + let mut second_output_exprs = Vec::new(); + Self::_apply(&mut second_output_exprs, right)?; + left_len = output_exprs.len(); + output_exprs.append(&mut second_output_exprs); + } } - left_len = output_exprs.len(); - output_exprs.append(&mut second_output_exprs); + Childrens::None => {} } let mut bind_position = BindPosition::new( || { @@ -65,7 +68,7 @@ impl BindExpressionPosition { }, |a, b| a == b, ); - let operator = graph.operator_mut(node_id); + let operator = &mut plan.operator; match operator { Operator::Join(op) => { match &mut op.on { @@ -172,12 +175,9 @@ impl MatchPattern for BindExpressionPosition { } impl NormalizationRule for BindExpressionPosition { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - Self::_apply(&mut Vec::new(), node_id, graph)?; - // mark changed to skip this rule batch - graph.version += 1; - - Ok(()) + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + Self::_apply(&mut Vec::new(), plan)?; + Ok(true) } } @@ -185,17 +185,19 @@ impl NormalizationRule for
BindExpressionPosition { pub struct EvaluatorBind; impl EvaluatorBind { - fn _apply(node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::_apply(child_id, graph)?; - } - // for join - if let Operator::Join(_) = graph.operator(node_id) { - if let Some(child_id) = graph.youngest_child_at(node_id) { - Self::_apply(child_id, graph)?; + fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { + match plan.childrens.as_mut() { + Childrens::Only(child) => Self::_apply(child)?, + Childrens::Twins { left, right } => { + Self::_apply(left)?; + if matches!(plan.operator, Operator::Join(_)) { + Self::_apply(right)?; + } } + Childrens::None => {} } - let operator = graph.operator_mut(node_id); + + let operator = &mut plan.operator; match operator { Operator::Join(op) => { @@ -284,11 +286,8 @@ impl MatchPattern for EvaluatorBind { } impl NormalizationRule for EvaluatorBind { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - Self::_apply(node_id, graph)?; - // mark changed to skip this rule batch - graph.version += 1; - - Ok(()) + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + Self::_apply(plan)?; + Ok(true) } } diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index 59684a0a..dbc62f55 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -13,10 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::expression::ScalarExpression; +use crate::expression::{AliasType, ScalarExpression}; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::rule::normalization::column_pruning::ColumnPruning; use crate::optimizer::rule::normalization::combine_operators::{ CollapseGroupByAgg, CollapseProject, CombineFilter, @@ -33,6 +32,7 @@ use crate::optimizer::rule::normalization::pushdown_predicates::PushPredicateThr use crate::optimizer::rule::normalization::simplification::ConstantCalculation; use crate::optimizer::rule::normalization::simplification::SimplifyFilter; use crate::optimizer::rule::normalization::top_k::TopK; +use crate::planner::LogicalPlan; mod column_pruning; mod combine_operators; mod compilation_in_advance; @@ -87,39 +87,60 @@ impl MatchPattern for NormalizationRuleImpl { } impl NormalizationRule for NormalizationRuleImpl { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { match self { - NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(node_id, graph), - NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph), - NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph), - NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph), - NormalizationRuleImpl::LimitProjectTranspose => { - LimitProjectTranspose.apply(node_id, graph) - } - NormalizationRuleImpl::PushLimitThroughJoin => { - PushLimitThroughJoin.apply(node_id, graph) - } - NormalizationRuleImpl::PushLimitIntoTableScan => { - PushLimitIntoScan.apply(node_id, graph) - } - NormalizationRuleImpl::PushPredicateThroughJoin => { - PushPredicateThroughJoin.apply(node_id, graph) - } - NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.apply(node_id, graph), -
NormalizationRuleImpl::PushPredicateIntoScan => { - PushPredicateIntoScan.apply(node_id, graph) - } - NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph), - NormalizationRuleImpl::BindExpressionPosition => { - BindExpressionPosition.apply(node_id, graph) - } - NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(node_id, graph), - NormalizationRuleImpl::TopK => TopK.apply(node_id, graph), + NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(plan), + NormalizationRuleImpl::CollapseProject => CollapseProject.apply(plan), + NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(plan), + NormalizationRuleImpl::CombineFilter => CombineFilter.apply(plan), + NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.apply(plan), + NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.apply(plan), + NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.apply(plan), + NormalizationRuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.apply(plan), + NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.apply(plan), + NormalizationRuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(plan), + NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(plan), + NormalizationRuleImpl::BindExpressionPosition => BindExpressionPosition.apply(plan), + NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(plan), + NormalizationRuleImpl::TopK => TopK.apply(plan), } } } /// Return true when left is subset of right +pub(crate) fn strip_alias(expr: &ScalarExpression) -> &ScalarExpression { + match expr { + ScalarExpression::Alias { + expr, + alias: AliasType::Name(_), + } => strip_alias(expr), + ScalarExpression::Alias { + alias: AliasType::Expr(alias_expr), + .. + } => strip_alias(alias_expr), + _ => expr, + } +} + +fn strip_all_alias(expr: &ScalarExpression) -> &ScalarExpression { + match expr { + ScalarExpression::Alias { expr, .. } => strip_all_alias(expr), + _ => expr, + } +} + pub fn is_subset_exprs(left: &[ScalarExpression], right: &[ScalarExpression]) -> bool { - left.iter().all(|l| right.contains(l)) + left.iter().all(|lhs| { + let lhs_stripped = strip_alias(lhs); + right.iter().any(|rhs| { + let rhs_stripped = strip_alias(rhs); + if lhs_stripped == rhs_stripped { + return true; + } + if matches!(lhs, ScalarExpression::ColumnRef { .. 
}) { + return lhs_stripped == strip_all_alias(rhs); + } + false + }) + }) } diff --git a/src/optimizer/rule/normalization/pushdown_limit.rs b/src/optimizer/rule/normalization/pushdown_limit.rs index ff3327d3..65f75126 100644 --- a/src/optimizer/rule/normalization/pushdown_limit.rs +++ b/src/optimizer/rule/normalization/pushdown_limit.rs @@ -16,10 +16,10 @@ use crate::errors::DatabaseError; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child, wrap_child_with}; use crate::planner::operator::join::JoinType; use crate::planner::operator::Operator; -use itertools::Itertools; +use crate::planner::LogicalPlan; use std::sync::LazyLock; static LIMIT_PROJECT_TRANSPOSE_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern { @@ -55,12 +55,35 @@ impl MatchPattern for LimitProjectTranspose { } impl NormalizationRule for LimitProjectTranspose { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Some(child_id) = graph.eldest_child_at(node_id) { - graph.swap_node(node_id, child_id); + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + let operator = std::mem::replace(&mut plan.operator, Operator::Dummy); + + let limit_op = match operator { + Operator::Limit(op) => op, + other => { + plan.operator = other; + return Ok(false); + } + }; + + let mut project_op = None; + + if let Some(child) = only_child_mut(plan) { + if matches!(child.operator, Operator::Project(_)) { + project_op = Some(std::mem::replace( + &mut child.operator, + Operator::Limit(limit_op.clone()), + )); + } } - Ok(()) + if let Some(project_op) = project_op { + plan.operator = project_op; + return Ok(true); + } + + plan.operator = Operator::Limit(limit_op); + Ok(false) } } @@ -79,32 +102,29 @@ impl MatchPattern for PushLimitThroughJoin { } impl NormalizationRule for PushLimitThroughJoin { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Operator::Limit(op) = graph.operator(node_id) { - if let Some(child_id) = graph.eldest_child_at(node_id) { - let join_type = if let Operator::Join(op) = graph.operator(child_id) { - Some(op.join_type) - } else { - None - }; - - if let Some(ty) = join_type { - let children = graph.children_at(child_id).collect_vec(); - - if let Some(grandson_id) = match ty { - JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => { - children.first() - } - JoinType::RightOuter => children.last(), - _ => None, - } { - graph.add_node(child_id, Some(*grandson_id), Operator::Limit(op.clone())); + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + let limit_op = match &plan.operator { + Operator::Limit(op) => op.clone(), + _ => return Ok(false), + }; + + if let Some(child) = only_child_mut(plan) { + if let Operator::Join(join_op) = &child.operator { + let mut applied = false; + match join_op.join_type { + JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => { + applied |= wrap_child_with(child, 0, Operator::Limit(limit_op.clone())); + } + JoinType::RightOuter => { + applied |= wrap_child_with(child, 1, Operator::Limit(limit_op)); } + _ => {} } + return Ok(applied); } } - Ok(()) + Ok(false) } } @@ -118,23 +138,21 @@ impl MatchPattern for PushLimitIntoScan { } impl NormalizationRule for PushLimitIntoScan { - fn apply(&self, node_id: HepNodeId, graph: &mut
HepGraph) -> Result<(), DatabaseError> { - if let Operator::Limit(limit_op) = graph.operator(node_id) { - if let Some(child_index) = graph.eldest_child_at(node_id) { - let mut is_apply = false; - let limit = (limit_op.offset, limit_op.limit); - - if let Operator::TableScan(scan_op) = graph.operator_mut(child_index) { - scan_op.limit = limit; - is_apply = true; - } - if is_apply { - graph.remove_node(node_id, false); - } + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + let (offset, limit) = match &plan.operator { + Operator::Limit(limit_op) => (limit_op.offset, limit_op.limit), + _ => return Ok(false), + }; + + if let Some(child) = only_child_mut(plan) { + if let Operator::TableScan(scan_op) = &mut child.operator { + scan_op.limit = (offset, limit); + let removed = replace_with_only_child(plan); + return Ok(removed); } } - Ok(()) + Ok(false) } } diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index f345059b..f190b96b 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -19,10 +19,13 @@ use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::plan_utils::{ left_child, only_child_mut, replace_with_only_child, right_child, wrap_child_with, }; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::join::JoinType; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use crate::types::index::{IndexInfo, IndexMetaRef, IndexType}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -115,105 +118,111 @@ impl MatchPattern for PushPredicateThroughJoin { } impl NormalizationRule for PushPredicateThroughJoin { - // TODO: pushdown_predicates need to consider output columns - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - let child_id = match graph.eldest_child_at(node_id) { - Some(child_id) => child_id, - None => return Ok(()), + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + let filter_op = match &plan.operator { + Operator::Filter(op) => op.clone(), + _ => return Ok(false), }; - if let Operator::Join(child_op) = graph.operator(child_id) { + + let mut applied = false; + + let parent_replacement = { + let join_plan = match only_child_mut(plan) { + Some(child) => child, + None => return Ok(false), + }; + + let join_op = match &join_plan.operator { + Operator::Join(op) => op, + _ => return Ok(false), + }; + if !matches!( - child_op.join_type, + join_op.join_type, JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightOuter ) { - return Ok(()); + return Ok(false); } - let join_childs = graph.children_at(child_id).collect_vec(); - let left_columns = graph.operator(join_childs[0]).referenced_columns(true); - let right_columns = graph.operator(join_childs[1]).referenced_columns(true); + let left_columns = left_child(join_plan) + .map(|child| child.operator.referenced_columns(true)) + .unwrap_or_default(); + let right_columns = right_child(join_plan) + .map(|child| child.operator.referenced_columns(true)) + .unwrap_or_default(); + + let filter_exprs = split_conjunctive_predicates(&filter_op.predicate); + let (left_filters,
rest): (Vec<_>, Vec<_>) = filter_exprs + .into_iter() + .partition(|f| is_subset_cols(&f.referenced_columns(true), &left_columns)); + let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest + .into_iter() + .partition(|f| is_subset_cols(&f.referenced_columns(true), &right_columns)); let mut new_ops = (None, None, None); - - if let Operator::Filter(op) = graph.operator(node_id) { - let filter_exprs = split_conjunctive_predicates(&op.predicate); - - let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs - .into_iter() - .partition(|f| is_subset_cols(&f.referenced_columns(true), &left_columns)); - let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest - .into_iter() - .partition(|f| is_subset_cols(&f.referenced_columns(true), &right_columns)); - - let replace_filters = match child_op.join_type { - JoinType::Inner => { - if !left_filters.is_empty() { - if let Some(left_filter_op) = reduce_filters(left_filters, op.having) { - new_ops.0 = Some(Operator::Filter(left_filter_op)); - } - } - - if !right_filters.is_empty() { - if let Some(right_filter_op) = reduce_filters(right_filters, op.having) - { - new_ops.1 = Some(Operator::Filter(right_filter_op)); - } - } - - common_filters + let replace_filters = match join_op.join_type { + JoinType::Inner => { + if let Some(left_filter_op) = reduce_filters(left_filters, filter_op.having) { + new_ops.0 = Some(Operator::Filter(left_filter_op)); } - JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => { - if !left_filters.is_empty() { - if let Some(left_filter_op) = reduce_filters(left_filters, op.having) { - new_ops.0 = Some(Operator::Filter(left_filter_op)); - } - } - common_filters - .into_iter() - .chain(right_filters) - .collect_vec() + if let Some(right_filter_op) = reduce_filters(right_filters, filter_op.having) { + new_ops.1 = Some(Operator::Filter(right_filter_op)); } - JoinType::RightOuter => { - if !right_filters.is_empty() { - if let Some(right_filter_op) = reduce_filters(right_filters, op.having) - { - new_ops.1 = Some(Operator::Filter(right_filter_op)); - } - } - common_filters.into_iter().chain(left_filters).collect_vec() + common_filters + } + JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => { + if let Some(left_filter_op) = reduce_filters(left_filters, filter_op.having) { + new_ops.0 = Some(Operator::Filter(left_filter_op)); } - _ => vec![], - }; - if !replace_filters.is_empty() { - if let Some(replace_filter_op) = reduce_filters(replace_filters, op.having) { - new_ops.2 = Some(Operator::Filter(replace_filter_op)); + common_filters + .into_iter() + .chain(right_filters) + .collect_vec() + } + JoinType::RightOuter => { + if let Some(right_filter_op) = reduce_filters(right_filters, filter_op.having) { + new_ops.1 = Some(Operator::Filter(right_filter_op)); } + + common_filters.into_iter().chain(left_filters).collect_vec() } + _ => vec![], + }; + + if let Some(replace_filter_op) = reduce_filters(replace_filters, filter_op.having) { + new_ops.2 = Some(Operator::Filter(replace_filter_op)); } if let Some(left_op) = new_ops.0 { - graph.add_node(child_id, Some(join_childs[0]), left_op); + applied |= wrap_child_with(join_plan, 0, left_op); } if let Some(right_op) = new_ops.1 { - graph.add_node(child_id, Some(join_childs[1]), right_op); + applied |= wrap_child_with(join_plan, 1, right_op); } - if let Some(common_op) = new_ops.2 { - graph.replace_node(node_id, common_op); - } else { - graph.remove_node(node_id, false); + new_ops.2 + }; + + match parent_replacement { + Some(common_op) => { + plan.operator = 
common_op; + applied = true; + } + None if applied => { + applied |= replace_with_only_child(plan); } + _ => {} } - Ok(()) + Ok(applied) } } @@ -226,10 +235,11 @@ impl MatchPattern for PushPredicateIntoScan { } impl NormalizationRule for PushPredicateIntoScan { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Operator::Filter(op) = graph.operator(node_id).clone() { - if let Some(child_id) = graph.eldest_child_at(node_id) { - if let Operator::TableScan(scan_op) = graph.operator_mut(child_id) { + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + if let Operator::Filter(op) = plan.operator.clone() { + if let Some(child) = only_child_mut(plan) { + if let Operator::TableScan(scan_op) = &mut child.operator { + let mut changed = false; for IndexInfo { meta, range, @@ -239,7 +249,6 @@ impl NormalizationRule for PushPredicateIntoScan { if range.is_some() { continue; } - // range detach *range = match meta.ty { IndexType::PrimaryKey { is_multiple: false } | IndexType::Unique @@ -251,7 +260,11 @@ impl NormalizationRule for PushPredicateIntoScan { Self::composite_range(&op, meta)? } }; - // try index covered + if range.is_none() { + continue; + } + changed = true; + let mut deserializers = Vec::with_capacity(meta.column_ids.len()); let mut cover_count = 0; let index_column_types = match &meta.value_ty { @@ -261,7 +274,7 @@ impl NormalizationRule for PushPredicateIntoScan { for (i, column_id) in meta.column_ids.iter().enumerate() { for column in scan_op.columns.values() { deserializers.push( - if column.id().map(|id| &id == column_id).unwrap_or(false) { + if column.id().map(|id| id == *column_id).unwrap_or(false) { cover_count += 1; column.datatype().serializable() } else { @@ -271,14 +284,15 @@ impl NormalizationRule for PushPredicateIntoScan { } } if cover_count == scan_op.columns.len() { - *covered_deserializers = Some(deserializers) + *covered_deserializers = Some(deserializers); } } + return Ok(changed); } } } - Ok(()) + Ok(false) } } diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index c2b34ff2..6ac9a856 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -17,10 +17,9 @@ use crate::expression::simplify::{ConstantCalculator, Simplify}; use crate::expression::visitor_mut::VisitorMut; use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; -use itertools::Itertools; +use crate::planner::{Childrens, LogicalPlan}; use std::sync::LazyLock; static CONSTANT_CALCULATION_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern { @@ -40,8 +39,8 @@ static SIMPLIFY_FILTER_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern { pub struct ConstantCalculation; impl ConstantCalculation { - fn _apply(node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - let operator = graph.operator_mut(node_id); + fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { + let operator = &mut plan.operator; match operator { Operator::Aggregate(op) => { @@ -75,8 +74,13 @@ impl ConstantCalculation { } _ => (), } - for child_id in graph.children_at(node_id).collect_vec() { - Self::_apply(child_id, graph)?; + match plan.childrens.as_mut() { + Childrens::Only(child) => Self::_apply(child.as_mut())?, +
Childrens::Twins { left, right } => { + Self::_apply(left.as_mut())?; + Self::_apply(right.as_mut())?; + } + Childrens::None => (), } Ok(()) @@ -90,12 +94,9 @@ impl MatchPattern for ConstantCalculation { } impl NormalizationRule for ConstantCalculation { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - Self::_apply(node_id, graph)?; - // mark changed to skip this rule batch - graph.version += 1; - - Ok(()) + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + Self::_apply(plan)?; + Ok(true) } } @@ -109,22 +110,18 @@ impl MatchPattern for SimplifyFilter { } impl NormalizationRule for SimplifyFilter { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - let mut is_optimized = false; - if let Operator::Filter(filter_op) = graph.operator_mut(node_id) { + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + if let Operator::Filter(filter_op) = &mut plan.operator { if filter_op.is_optimized { - return Ok(()); + return Ok(false); } ConstantCalculator.visit(&mut filter_op.predicate)?; Simplify::default().visit(&mut filter_op.predicate)?; filter_op.is_optimized = true; - is_optimized = true; - } - if is_optimized { - graph.version += 1; + return Ok(true); } - Ok(()) + Ok(false) } } diff --git a/src/optimizer/rule/normalization/top_k.rs b/src/optimizer/rule/normalization/top_k.rs index 12f97f57..4775024a 100644 --- a/src/optimizer/rule/normalization/top_k.rs +++ b/src/optimizer/rule/normalization/top_k.rs @@ -16,9 +16,10 @@ use crate::errors::DatabaseError; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::pattern::PatternChildrenPredicate; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; -use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child}; use crate::planner::operator::top_k::TopKOperator; use crate::planner::operator::Operator; +use crate::planner::LogicalPlan; use std::sync::LazyLock; static TOP_K_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern { @@ -38,23 +39,37 @@ impl MatchPattern for TopK { } impl NormalizationRule for TopK { - fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { - if let Operator::Limit(op) = graph.operator(node_id) { - if let Some(limit) = op.limit { - let sort_id = graph.eldest_child_at(node_id).unwrap(); - if let Operator::Sort(sort_op) = graph.operator(sort_id) { - graph.replace_node( - node_id, - Operator::TopK(TopKOperator { - sort_fields: sort_op.sort_fields.clone(), - limit, - offset: op.offset, - }), - ); - graph.remove_node(sort_id, false); + fn apply(&self, plan: &mut LogicalPlan) -> Result<bool, DatabaseError> { + let (offset, limit) = match &plan.operator { + Operator::Limit(op) => match op.limit { + Some(limit) => (op.offset, limit), + None => return Ok(false), + }, + _ => return Ok(false), + }; + + let sort_fields = { + let child = match only_child_mut(plan) { + Some(child) => child, + None => return Ok(false), + }; + + match &child.operator { + Operator::Sort(sort_op) => { + let fields = sort_op.sort_fields.clone(); + let removed = replace_with_only_child(child); + debug_assert!(removed); + fields } + _ => return Ok(false), } - } - Ok(()) + }; + + plan.operator = Operator::TopK(TopKOperator { + sort_fields, + limit, + offset, + }); + Ok(true) } } diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 4db30993..f7a2fced 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -125,6 +125,22 @@ impl LogicalPlan { } } + pub(crate)
fn reset_output_schema_cache(&mut self) { + self._output_schema_ref = None; + } + + pub(crate) fn reset_output_schema_cache_recursive(&mut self) { + self.reset_output_schema_cache(); + match self.childrens.as_mut() { + Childrens::Only(child) => child.reset_output_schema_cache_recursive(), + Childrens::Twins { left, right } => { + left.reset_output_schema_cache_recursive(); + right.reset_output_schema_cache_recursive(); + } + Childrens::None => (), + } + } + pub fn referenced_table(&self) -> Vec<TableName> { fn collect_table(plan: &LogicalPlan, results: &mut Vec<TableName>) { if let Operator::TableScan(op) = &plan.operator {