From d90a6158f479320d4b89c7f7c4222c4e7ec9319e Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Wed, 28 Aug 2024 18:49:15 +0400 Subject: [PATCH] feat(cubesql): Top-down extractor for rewrites --- packages/cubejs-backend-native/Cargo.lock | 1 + rust/cubenativeutils/Cargo.lock | 1 + rust/cubesql/Cargo.lock | 21 +- rust/cubesql/cubesql/Cargo.toml | 1 + rust/cubesql/cubesql/src/compile/mod.rs | 55 +- .../cubesql/src/compile/query_engine.rs | 8 +- .../cubesql/src/compile/rewrite/cost.rs | 734 +++++++++++++----- .../cubesql/src/compile/rewrite/rewriter.rs | 35 +- rust/cubesql/cubesql/src/config/mod.rs | 9 + rust/cubesqlplanner/Cargo.lock | 1 + 10 files changed, 650 insertions(+), 216 deletions(-) diff --git a/packages/cubejs-backend-native/Cargo.lock b/packages/cubejs-backend-native/Cargo.lock index 2b03e9de8ee5e..5eb487efaf687 100644 --- a/packages/cubejs-backend-native/Cargo.lock +++ b/packages/cubejs-backend-native/Cargo.lock @@ -804,6 +804,7 @@ dependencies = [ "futures-core", "futures-util", "hashbrown 0.14.3", + "indexmap 1.9.3", "itertools", "log", "lru", diff --git a/rust/cubenativeutils/Cargo.lock b/rust/cubenativeutils/Cargo.lock index 66b4d9227c5b0..8e60af93afa27 100644 --- a/rust/cubenativeutils/Cargo.lock +++ b/rust/cubenativeutils/Cargo.lock @@ -698,6 +698,7 @@ dependencies = [ "futures-core", "futures-util", "hashbrown 0.14.5", + "indexmap 1.9.3", "itertools", "log", "lru", diff --git a/rust/cubesql/Cargo.lock b/rust/cubesql/Cargo.lock index a832246ac7504..58b0377de0dd2 100644 --- a/rust/cubesql/Cargo.lock +++ b/rust/cubesql/Cargo.lock @@ -136,7 +136,7 @@ dependencies = [ "flatbuffers", "half", "hex", - "indexmap 1.8.1", + "indexmap 1.9.3", "lazy_static", "lexical-core", "multiversion", @@ -473,7 +473,7 @@ checksum = "86447ad904c7fb335a790c9d7fe3d0d971dc523b8ccd1561a520de9a85302750" dependencies = [ "bitflags 1.3.2", "clap_lex", - "indexmap 1.8.1", + "indexmap 1.9.3", "textwrap", ] @@ -773,6 +773,7 @@ dependencies = [ "futures-core", "futures-util", "hashbrown 0.14.3", + "indexmap 1.9.3", "insta", "itertools", "log", @@ -995,7 +996,7 @@ dependencies = [ "env_logger", "fxhash", "hashbrown 0.12.1", - "indexmap 1.8.1", + "indexmap 1.9.3", "instant", "log", "num-bigint", @@ -1276,12 +1277,6 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" -[[package]] -name = "hashbrown" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" - [[package]] name = "hashbrown" version = "0.12.1" @@ -1499,12 +1494,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.8.1" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f647032dfaa1f8b6dc29bd3edb7bbef4861b8b8007ebb118d6db284fd59f6ee" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown 0.11.2", + "hashbrown 0.12.1", ] [[package]] @@ -2762,7 +2757,7 @@ version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4a521f2940385c165a24ee286aa8599633d162077a54bdcae2a6fd5a7bfa7a0" dependencies = [ - "indexmap 1.8.1", + "indexmap 1.9.3", "ryu", "serde", "yaml-rust", diff --git a/rust/cubesql/cubesql/Cargo.toml b/rust/cubesql/cubesql/Cargo.toml index 0e32b9568d358..60c6d5284eafd 100644 --- a/rust/cubesql/cubesql/Cargo.toml +++ b/rust/cubesql/cubesql/Cargo.toml @@ -58,6 +58,7 @@ minijinja = { version = "1", features = ["json", "loader"] } lru = "0.12.1" sha2 = "0.10.8" bigdecimal = "0.4.2" +indexmap = "1.9.3" [dev-dependencies] diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 761e413ae433c..b0d27d43edf48 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -14369,7 +14369,11 @@ ORDER BY "source"."str0" ASC .wrapped_sql .unwrap() .sql; - assert!(sql.contains("\"limit\":1000")); + if Rewriter::top_down_extractor_enabled() { + assert!(sql.contains("LIMIT 1000")); + } else { + assert!(sql.contains("\"limit\":1000")); + } assert!(sql.contains("% 7")); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -16547,7 +16551,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), measure(count) AS cnt, date_trunc('month', order_date) AS dt FROM KibanaSampleDataEcommerce - WHERE order_date IN (to_timestamp('2019-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')) + WHERE date_trunc('month', order_date) IN (to_timestamp('2019-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US')) GROUP BY 2 ;"# .to_string(), @@ -16565,10 +16569,18 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension { dimension: "KibanaSampleDataEcommerce.order_date".to_string(), granularity: Some("month".to_string()), - date_range: Some(json!(vec![ - "2019-01-01T00:00:00.000Z".to_string(), - "2019-01-01T00:00:00.000Z".to_string() - ])) + date_range: if Rewriter::top_down_extractor_enabled() { + Some(json!(vec![ + "2019-01-01T00:00:00.000Z".to_string(), + "2019-01-31T23:59:59.999Z".to_string() + ])) + } else { + // Non-optimal variant with top down extractor disabled + Some(json!(vec![ + "2019-01-01 00:00:00.000".to_string(), + "2019-01-31 23:59:59.999".to_string() + ])) + } }]), order: Some(vec![]), limit: None, @@ -16915,4 +16927,35 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), Ok(()) } + + #[tokio::test] + async fn test_wrapper_limit_zero() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + r#" + SELECT MAX(order_date) FROM KibanaSampleDataEcommerce LIMIT 0 + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + let sql = logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql; + assert!(sql.contains("LIMIT 0")); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } } diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index c7ba608a36e3c..2532fa12efc48 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -190,7 +190,13 @@ pub trait QueryEngine { let mut rewriter = Rewriter::new(finalized_graph, cube_ctx.clone()); let result = rewriter - .find_best_plan(root, state.auth_context().unwrap(), qtrace, span_id.clone()) + .find_best_plan( + root, + state.auth_context().unwrap(), + qtrace, + span_id.clone(), + self.config_ref().top_down_extractor(), + ) .await .map_err(|e| match e.cause { CubeErrorCauseType::Internal(_) => CompilationError::Internal( diff --git a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs index 9a573390a7f0d..83f23e6dedef1 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs @@ -1,3 +1,5 @@ +use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc}; + use crate::{ compile::rewrite::{ rules::utils::granularity_str_to_int_order, CubeScanUngrouped, CubeScanWrapped, @@ -6,9 +8,10 @@ use crate::{ }, transport::{MetaContext, V1CubeMetaDimensionExt}, }; -use egg::{CostFunction, Id, Language}; -use std::sync::Arc; +use egg::{Analysis, CostFunction, EGraph, Id, Language, RecExpr}; +use indexmap::IndexSet; +#[derive(Debug)] pub struct BestCubePlan { meta_context: Arc, } @@ -17,6 +20,200 @@ impl BestCubePlan { pub fn new(meta_context: Arc) -> Self { Self { meta_context } } + + pub fn initial_cost(&self, enode: &LogicalPlanLanguage, top_down: bool) -> CubePlanCost { + let table_scans = match enode { + LogicalPlanLanguage::TableScan(_) => 1, + _ => 0, + }; + + let non_detected_cube_scans = match enode { + LogicalPlanLanguage::CubeScan(_) => 1, + _ => 0, + }; + + let cube_scan_nodes = match enode { + LogicalPlanLanguage::CubeScan(_) => 1, + _ => 0, + }; + + let non_pushed_down_window = match enode { + LogicalPlanLanguage::Window(_) => 1, + _ => 0, + }; + + let non_pushed_down_grouping_sets = match enode { + LogicalPlanLanguage::GroupingSetExpr(_) => 1, + _ => 0, + }; + + let non_pushed_down_limit_sort = match enode { + LogicalPlanLanguage::Limit(_) if !top_down => 1, + LogicalPlanLanguage::Sort(_) if top_down => 1, + _ => 0, + }; + + let ast_size_inside_wrapper = match enode { + LogicalPlanLanguage::WrappedSelect(_) => 1, + _ => 0, + }; + + let wrapper_nodes = match enode { + LogicalPlanLanguage::CubeScanWrapper(_) => 1, + _ => 0, + }; + + let filter_members = match enode { + LogicalPlanLanguage::FilterMember(_) => 1, + _ => 0, + }; + + let filters = match enode { + LogicalPlanLanguage::Filter(_) => 1, + _ => 0, + }; + + let member_errors = match enode { + LogicalPlanLanguage::MemberError(_) => 1, + _ => 0, + }; + + let cube_members = match enode { + LogicalPlanLanguage::Measure(_) => 1, + LogicalPlanLanguage::Dimension(_) => 1, + LogicalPlanLanguage::ChangeUser(_) => 1, + LogicalPlanLanguage::VirtualField(_) => 1, + LogicalPlanLanguage::LiteralMember(_) => 1, + LogicalPlanLanguage::TimeDimensionGranularity(TimeDimensionGranularity(Some(_))) => 1, + // MemberError must be present here as well in order to preserve error priority + LogicalPlanLanguage::MemberError(_) => 1, + _ => 0, + }; + + let this_replacers = match enode { + LogicalPlanLanguage::OrderReplacer(_) => 1, + LogicalPlanLanguage::MemberReplacer(_) => 1, + LogicalPlanLanguage::FilterReplacer(_) => 1, + LogicalPlanLanguage::TimeDimensionDateRangeReplacer(_) => 1, + LogicalPlanLanguage::InnerAggregateSplitReplacer(_) => 1, + LogicalPlanLanguage::OuterProjectionSplitReplacer(_) => 1, + LogicalPlanLanguage::OuterAggregateSplitReplacer(_) => 1, + LogicalPlanLanguage::GroupExprSplitReplacer(_) => 1, + LogicalPlanLanguage::GroupAggregateSplitReplacer(_) => 1, + LogicalPlanLanguage::MemberPushdownReplacer(_) => 1, + LogicalPlanLanguage::EventNotification(_) => 1, + LogicalPlanLanguage::MergedMembersReplacer(_) => 1, + LogicalPlanLanguage::CaseExprReplacer(_) => 1, + LogicalPlanLanguage::WrapperPushdownReplacer(_) => 1, + LogicalPlanLanguage::WrapperPullupReplacer(_) => 1, + LogicalPlanLanguage::FlattenPushdownReplacer(_) => 1, + LogicalPlanLanguage::AggregateSplitPushDownReplacer(_) => 1, + LogicalPlanLanguage::AggregateSplitPullUpReplacer(_) => 1, + LogicalPlanLanguage::ProjectionSplitPushDownReplacer(_) => 1, + LogicalPlanLanguage::ProjectionSplitPullUpReplacer(_) => 1, + LogicalPlanLanguage::QueryParam(_) => 1, + // Not really replacers but those should be deemed as mandatory rewrites and as soon as + // there's always rewrite rule it's fine to have replacer cost. + // Needs to be added as alias rewrite always more expensive than original function. + LogicalPlanLanguage::ScalarUDFExprFun(ScalarUDFExprFun(fun)) + if fun.as_str() == "current_timestamp" => + { + 1 + } + LogicalPlanLanguage::ScalarUDFExprFun(ScalarUDFExprFun(fun)) + if fun.as_str() == "localtimestamp" => + { + 1 + } + _ => 0, + }; + + let time_dimensions_used_as_dimensions = match enode { + LogicalPlanLanguage::DimensionName(DimensionName(name)) => { + if let Some(dimension) = self.meta_context.find_dimension_with_name(name.clone()) { + if dimension.is_time() { + 1 + } else { + 0 + } + } else { + 0 + } + } + _ => 0, + }; + + let max_time_dimensions_granularity = match enode { + LogicalPlanLanguage::TimeDimensionGranularity(TimeDimensionGranularity(Some( + granularity, + ))) => (8 - granularity_str_to_int_order(granularity, Some(false)).unwrap_or(0)) as i64, + _ => 0, + }; + + let this_errors = match enode { + LogicalPlanLanguage::MemberErrorPriority(MemberErrorPriority(priority)) => { + (100 - priority) as i64 + } + _ => 0, + }; + + let structure_points = match enode { + // TODO needed to get rid of FilterOpFilters on upper level + LogicalPlanLanguage::FilterOpFilters(_) => 1, + LogicalPlanLanguage::Join(_) => 1, + LogicalPlanLanguage::CrossJoin(_) => 1, + _ => 0, + }; + + let ast_size_without_alias = match enode { + LogicalPlanLanguage::AliasExpr(_) => 0, + LogicalPlanLanguage::AliasExprAlias(_) => 0, + _ => 1, + }; + + let ungrouped_nodes = match enode { + LogicalPlanLanguage::CubeScanUngrouped(CubeScanUngrouped(true)) => 1, + _ => 0, + }; + + let wrapped_select_ungrouped_scan = match enode { + LogicalPlanLanguage::WrappedSelectUngroupedScan(WrappedSelectUngroupedScan(true)) => 1, + _ => 0, + }; + + let unwrapped_subqueries = match enode { + LogicalPlanLanguage::Subquery(_) => 1, + _ => 0, + }; + + CubePlanCost { + replacers: this_replacers, + table_scans, + filters, + filter_members, + non_detected_cube_scans, + member_errors, + non_pushed_down_window, + non_pushed_down_grouping_sets, + non_pushed_down_limit_sort, + cube_members, + errors: this_errors, + time_dimensions_used_as_dimensions, + max_time_dimensions_granularity, + structure_points, + ungrouped_aggregates: 0, + wrapper_nodes, + wrapped_select_ungrouped_scan, + empty_wrappers: 0, + ast_size_outside_wrapper: 0, + ast_size_inside_wrapper, + cube_scan_nodes, + ast_size_without_alias, + ast_size: 1, + ungrouped_nodes, + unwrapped_subqueries, + } + } } /// This cost struct maintains following structural relationships: @@ -61,7 +258,7 @@ pub struct CubePlanCost { ungrouped_nodes: usize, } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, Hash, PartialEq)] pub enum CubePlanState { Wrapped, Unwrapped(usize), @@ -79,7 +276,7 @@ impl CubePlanState { } } -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, Hash, PartialEq)] pub enum SortState { None, Current, @@ -126,7 +323,9 @@ impl CubePlanCostAndState { pub fn finalize(&self, enode: &LogicalPlanLanguage) -> Self { Self { - cost: self.cost.finalize(&self.state, &self.sort_state, enode), + cost: self + .cost + .finalize(&self.state, &self.sort_state, enode, false), state: self.state.clone(), sort_state: self.sort_state.clone(), } @@ -180,6 +379,7 @@ impl CubePlanCost { state: &CubePlanState, sort_state: &SortState, enode: &LogicalPlanLanguage, + top_down: bool, ) -> Self { Self { replacers: self.replacers, @@ -200,6 +400,7 @@ impl CubePlanCost { }, non_pushed_down_limit_sort: match sort_state { SortState::DirectChild => self.non_pushed_down_limit_sort, + SortState::Current if top_down => self.non_pushed_down_limit_sort, _ => 0, }, cube_members: self.cube_members, @@ -256,21 +457,6 @@ impl CostFunction for BestCubePlan { where C: FnMut(Id) -> Self::Cost, { - let table_scans = match enode { - LogicalPlanLanguage::TableScan(_) => 1, - _ => 0, - }; - - let non_detected_cube_scans = match enode { - LogicalPlanLanguage::CubeScan(_) => 1, - _ => 0, - }; - - let cube_scan_nodes = match enode { - LogicalPlanLanguage::CubeScan(_) => 1, - _ => 0, - }; - let ast_size_outside_wrapper = match enode { LogicalPlanLanguage::Aggregate(_) => 1, LogicalPlanLanguage::Projection(_) => 1, @@ -285,202 +471,370 @@ impl CostFunction for BestCubePlan { _ => 0, }; - let non_pushed_down_window = match enode { - LogicalPlanLanguage::Window(_) => 1, - _ => 0, + let cost = self.initial_cost(enode, false); + let initial_cost = CubePlanCostAndState { + cost, + state: match enode { + LogicalPlanLanguage::CubeScanWrapped(CubeScanWrapped(true)) => { + CubePlanState::Wrapped + } + LogicalPlanLanguage::CubeScanWrapper(_) => CubePlanState::Wrapper, + _ => CubePlanState::Unwrapped(ast_size_outside_wrapper), + }, + sort_state: match enode { + LogicalPlanLanguage::Sort(_) => SortState::Current, + _ => SortState::None, + }, }; + let res = enode + .children() + .iter() + .fold(initial_cost.clone(), |cost, id| { + let child = costs(*id); + cost.add_child(&child) + }) + .finalize(enode); + res + } +} - let non_pushed_down_grouping_sets = match enode { - LogicalPlanLanguage::GroupingSetExpr(_) => 1, - _ => 0, - }; +pub trait TopDownCost: Clone + Debug + PartialOrd { + fn add(&self, other: &Self) -> Self; +} - let non_pushed_down_limit_sort = match enode { - LogicalPlanLanguage::Limit(_) => 1, - _ => 0, - }; +pub trait TopDownState: Clone + Debug + Eq + Hash +where + L: Language, +{ + /// Transforms the current state based on node's contents. + fn transform(&self, node: &L, egraph: &EGraph) -> Self + where + A: Analysis; +} - let ast_size_inside_wrapper = match enode { - LogicalPlanLanguage::WrappedSelect(_) => 1, - _ => 0, - }; +/// Simple implementation of TopDownState for lack of state. +impl TopDownState for () +where + L: Language, +{ + fn transform(&self, _: &L, _: &EGraph) -> Self + where + A: Analysis, + { + () + } +} - let wrapper_nodes = match enode { - LogicalPlanLanguage::CubeScanWrapper(_) => 1, - _ => 0, - }; +pub trait TopDownCostFunction: Debug +where + L: Language, + S: TopDownState, + C: TopDownCost, +{ + /// Returns the cost for the current node. + fn cost(&self, node: &L) -> C; + + // Finalize the cost based on node and state. + fn finalize(&self, cost: C, node: &L, state: &S) -> C; +} - let filter_members = match enode { - LogicalPlanLanguage::FilterMember(_) => 1, - _ => 0, - }; +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +struct IdWithState +where + L: Language, + S: TopDownState, +{ + id: Id, + state: Arc, + phantom: PhantomData, +} - let filters = match enode { - LogicalPlanLanguage::Filter(_) => 1, - _ => 0, - }; +impl IdWithState +where + L: Language, + S: TopDownState, +{ + pub fn new(id: Id, state: Arc) -> Self { + Self { + id, + state, + phantom: PhantomData, + } + } +} - let member_errors = match enode { - LogicalPlanLanguage::MemberError(_) => 1, - _ => 0, - }; +#[derive(Clone, Debug)] +pub struct TopDownExtractor<'a, L, A, C, S, CF> +where + L: Language, + A: Analysis, + C: TopDownCost, + S: TopDownState, + CF: TopDownCostFunction, +{ + egraph: &'a EGraph, + // Caches results. `None` for nodes in progress to prevent recursion + extract_map: HashMap, Option<(usize, C)>>, + cost_fn: Arc, + root_state: Arc, +} - let cube_members = match enode { - LogicalPlanLanguage::Measure(_) => 1, - LogicalPlanLanguage::Dimension(_) => 1, - LogicalPlanLanguage::ChangeUser(_) => 1, - LogicalPlanLanguage::VirtualField(_) => 1, - LogicalPlanLanguage::LiteralMember(_) => 1, - LogicalPlanLanguage::TimeDimensionGranularity(TimeDimensionGranularity(Some(_))) => 1, - // MemberError must be present here as well in order to preserve error priority - LogicalPlanLanguage::MemberError(_) => 1, - _ => 0, - }; +impl<'a, L, A, C, S, CF> TopDownExtractor<'a, L, A, C, S, CF> +where + L: Language, + A: Analysis, + C: TopDownCost, + S: TopDownState, + CF: TopDownCostFunction, +{ + pub fn new(egraph: &'a EGraph, cost_fn: CF, root_state: S) -> Self { + Self { + egraph, + extract_map: HashMap::new(), + cost_fn: Arc::new(cost_fn), + root_state: Arc::new(root_state), + } + } - let this_replacers = match enode { - LogicalPlanLanguage::OrderReplacer(_) => 1, - LogicalPlanLanguage::MemberReplacer(_) => 1, - LogicalPlanLanguage::FilterReplacer(_) => 1, - LogicalPlanLanguage::TimeDimensionDateRangeReplacer(_) => 1, - LogicalPlanLanguage::InnerAggregateSplitReplacer(_) => 1, - LogicalPlanLanguage::OuterProjectionSplitReplacer(_) => 1, - LogicalPlanLanguage::OuterAggregateSplitReplacer(_) => 1, - LogicalPlanLanguage::GroupExprSplitReplacer(_) => 1, - LogicalPlanLanguage::GroupAggregateSplitReplacer(_) => 1, - LogicalPlanLanguage::MemberPushdownReplacer(_) => 1, - LogicalPlanLanguage::EventNotification(_) => 1, - LogicalPlanLanguage::MergedMembersReplacer(_) => 1, - LogicalPlanLanguage::CaseExprReplacer(_) => 1, - LogicalPlanLanguage::WrapperPushdownReplacer(_) => 1, - LogicalPlanLanguage::WrapperPullupReplacer(_) => 1, - LogicalPlanLanguage::FlattenPushdownReplacer(_) => 1, - LogicalPlanLanguage::AggregateSplitPushDownReplacer(_) => 1, - LogicalPlanLanguage::AggregateSplitPullUpReplacer(_) => 1, - LogicalPlanLanguage::ProjectionSplitPushDownReplacer(_) => 1, - LogicalPlanLanguage::ProjectionSplitPullUpReplacer(_) => 1, - LogicalPlanLanguage::QueryParam(_) => 1, - // Not really replacers but those should be deemed as mandatory rewrites and as soon as - // there's always rewrite rule it's fine to have replacer cost. - // Needs to be added as alias rewrite always more expensive than original function. - LogicalPlanLanguage::ScalarUDFExprFun(ScalarUDFExprFun(fun)) - if fun.as_str() == "current_timestamp" => - { - 1 - } - LogicalPlanLanguage::ScalarUDFExprFun(ScalarUDFExprFun(fun)) - if fun.as_str() == "localtimestamp" => - { - 1 + /// Returns cost and path for best plan for provided root eclass. + /// + /// If all nodes happen to be recursive, returns `None`. + pub fn find_best(&mut self, root: Id) -> Option<(C, RecExpr)> { + let cost = self.extract(root, Arc::clone(&self.root_state))?; + let root_id_with_state = IdWithState::new(root, Arc::clone(&self.root_state)); + let root_node = self.choose_node(&root_id_with_state)?; + let recexpr = + self.build_recexpr(&root_node, root_id_with_state.state, |id_with_state| { + self.choose_node(id_with_state) + })?; + Some((cost, recexpr)) + } + + /// Recursively extracts the cost of each node in the eclass + /// and returns cost of the node with least cost based on the passed state, + /// caching the cost together with node index inside eclass in `extract_map`. + /// + /// Yields `None` if eclass is already in progress + /// or all its nodes happen to be recursive. + fn extract(&mut self, eclass: Id, state: Arc) -> Option { + let id_with_state = IdWithState::new(eclass, state); + if let Some(cached_index_and_cost) = self.extract_map.get(&id_with_state) { + // TODO: avoid cloning here? + return cached_index_and_cost.as_ref().map(|(_, cost)| cost.clone()); + } + + // Mark this eclass as in progress + self.extract_map.insert(id_with_state.clone(), None); + + // Compute the cost of each node, take the minimum + let mut min_index = None; + let mut min_cost = None; + 'nodes: for (index, node) in self.egraph[eclass].nodes.iter().enumerate() { + // Compute the cost of this node + let this_node_cost = self.cost_fn.cost(node); + + // Get state for this node and its children + let new_state = Arc::new(id_with_state.state.transform(node, self.egraph)); + + // Recursively get children cost + let mut total_node_cost = this_node_cost; + for child in node.children() { + let Some(child_cost) = self.extract(*child, Arc::clone(&new_state)) else { + // This path is inevitably recursive, try the next node + continue 'nodes; + }; + total_node_cost = total_node_cost.add(&child_cost); } - _ => 0, - }; + total_node_cost = self.cost_fn.finalize(total_node_cost, node, &new_state); - let time_dimensions_used_as_dimensions = match enode { - LogicalPlanLanguage::DimensionName(DimensionName(name)) => { - if let Some(dimension) = self.meta_context.find_dimension_with_name(name.clone()) { - if dimension.is_time() { - 1 - } else { - 0 - } - } else { - 0 + // Now that we've finalized the cost, check if it's lower than the minimum + if let Some(min_cost) = &min_cost { + if &total_node_cost > min_cost { + continue; } } - _ => 0, - }; - let max_time_dimensions_granularity = match enode { - LogicalPlanLanguage::TimeDimensionGranularity(TimeDimensionGranularity(Some( - granularity, - ))) => (8 - granularity_str_to_int_order(granularity, Some(false)).unwrap_or(0)) as i64, - _ => 0, + min_index = Some(index); + min_cost = Some(total_node_cost); + } + + let (Some(min_index), Some(min_cost)) = (min_index, min_cost) else { + // All nodes were recursive + self.extract_map.remove(&id_with_state); + return None; }; - let this_errors = match enode { - LogicalPlanLanguage::MemberErrorPriority(MemberErrorPriority(priority)) => { - (100 - priority) as i64 + self.extract_map + .insert(id_with_state, Some((min_index, min_cost.clone()))); + Some(min_cost) + } + + /// A custom version of [`egg::Language::build_recexpr`], accepting state + /// in addition to [`egg::Id`]. + fn build_recexpr(&self, node: &L, start_state: Arc, get_node: F) -> Option> + where + F: Fn(&IdWithState) -> Option, + { + let state = Arc::new(start_state.transform(node, self.egraph)); + let mut set = IndexSet::::default(); + let mut ids = HashMap::, Id>::default(); + let mut todo = node + .children() + .into_iter() + .map(|id| IdWithState::new(*id, Arc::clone(&state))) + .collect::>(); + + while let Some(id_with_state) = todo.last().cloned() { + if ids.contains_key(&id_with_state) { + todo.pop(); + continue; } - _ => 0, - }; - let structure_points = match enode { - // TODO needed to get rid of FilterOpFilters on upper level - LogicalPlanLanguage::FilterOpFilters(_) => 1, - LogicalPlanLanguage::Join(_) => 1, - LogicalPlanLanguage::CrossJoin(_) => 1, - _ => 0, - }; + let node = get_node(&id_with_state)?; + let node_state = Arc::new(id_with_state.state.transform(&node, self.egraph)); - let ast_size_without_alias = match enode { - LogicalPlanLanguage::AliasExpr(_) => 0, - LogicalPlanLanguage::AliasExprAlias(_) => 0, - _ => 1, - }; + // Check to see if we can do this node yet + let mut ids_has_all_children = true; + for child in node.children() { + let child_id_with_state = IdWithState::new(*child, Arc::clone(&node_state)); + if !ids.contains_key(&child_id_with_state) { + ids_has_all_children = false; + todo.push(child_id_with_state); + } + } - let ungrouped_nodes = match enode { - LogicalPlanLanguage::CubeScanUngrouped(CubeScanUngrouped(true)) => 1, - _ => 0, - }; + // All children are processed, so we can lookup this node safely + if ids_has_all_children { + let node = node.map_children(|id| { + let id_with_state = IdWithState::new(id, Arc::clone(&node_state)); + ids[&id_with_state] + }); + let (new_id, _) = set.insert_full(node); + ids.insert(id_with_state, Id::from(new_id)); + todo.pop(); + } + } - let wrapped_select_ungrouped_scan = match enode { - LogicalPlanLanguage::WrappedSelectUngroupedScan(WrappedSelectUngroupedScan(true)) => 1, - _ => 0, + // Finally, add the root node and create the expression + let mut nodes = set.into_iter().collect::>(); + nodes.push(node.clone().map_children(|id| { + let id_with_state = IdWithState::new(id, Arc::clone(&state)); + ids[&id_with_state] + })); + Some(RecExpr::from(nodes)) + } + + fn choose_node(&self, id_with_state: &IdWithState) -> Option { + let index = *self + .extract_map + .get(&id_with_state)? + .as_ref() + .map(|(index, _)| index)?; + Some(self.egraph[id_with_state.id].nodes[index].clone()) + } +} + +impl TopDownCost for CubePlanCost { + fn add(&self, other: &Self) -> Self { + self.add_child(other) + } +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct CubePlanTopDownState { + wrapped: CubePlanState, + limit: SortState, +} + +impl CubePlanTopDownState { + pub fn new() -> Self { + Self { + wrapped: CubePlanState::Unwrapped(0), + limit: SortState::None, + } + } + + pub fn is_wrapped( + &self, + node: &LogicalPlanLanguage, + egraph: &EGraph, + ) -> bool + where + A: Analysis, + { + let LogicalPlanLanguage::CubeScan(cube_scan) = node else { + return false; }; + let wrapped_index = 8; + let wrapped_id = cube_scan[wrapped_index]; + for node in &egraph[wrapped_id].nodes { + if !matches!( + node, + LogicalPlanLanguage::CubeScanWrapped(CubeScanWrapped(true)) + ) { + return false; + } + } + return true; + } +} - let unwrapped_subqueries = match enode { - LogicalPlanLanguage::Subquery(_) => 1, - _ => 0, +impl TopDownState for CubePlanTopDownState { + fn transform( + &self, + node: &LogicalPlanLanguage, + egraph: &EGraph, + ) -> Self + where + A: Analysis, + { + let wrapped = match node { + LogicalPlanLanguage::CubeScanWrapper(_) => CubePlanState::Wrapper, + _ if self.wrapped == CubePlanState::Wrapped => CubePlanState::Wrapped, + LogicalPlanLanguage::CubeScan(_) if self.is_wrapped(node, egraph) => { + CubePlanState::Wrapped + } + _ => { + let ast_size_outside_wrapper = match node { + LogicalPlanLanguage::Aggregate(_) => 1, + LogicalPlanLanguage::Projection(_) => 1, + LogicalPlanLanguage::Limit(_) => 1, + LogicalPlanLanguage::Sort(_) => 1, + LogicalPlanLanguage::Filter(_) => 1, + LogicalPlanLanguage::Join(_) => 1, + LogicalPlanLanguage::CrossJoin(_) => 1, + LogicalPlanLanguage::Union(_) => 1, + LogicalPlanLanguage::Window(_) => 1, + LogicalPlanLanguage::Subquery(_) => 1, + _ => 0, + }; + CubePlanState::Unwrapped(ast_size_outside_wrapper) + } }; - let initial_cost = CubePlanCostAndState { - cost: CubePlanCost { - replacers: this_replacers, - table_scans, - filters, - filter_members, - non_detected_cube_scans, - member_errors, - non_pushed_down_window, - non_pushed_down_grouping_sets, - non_pushed_down_limit_sort, - cube_members, - errors: this_errors, - time_dimensions_used_as_dimensions, - max_time_dimensions_granularity, - structure_points, - ungrouped_aggregates: 0, - wrapper_nodes, - wrapped_select_ungrouped_scan, - empty_wrappers: 0, - ast_size_outside_wrapper: 0, - ast_size_inside_wrapper, - cube_scan_nodes, - ast_size_without_alias, - ast_size: 1, - ungrouped_nodes, - unwrapped_subqueries, - }, - state: match enode { - LogicalPlanLanguage::CubeScanWrapped(CubeScanWrapped(true)) => { - CubePlanState::Wrapped - } - LogicalPlanLanguage::CubeScanWrapper(_) => CubePlanState::Wrapper, - _ => CubePlanState::Unwrapped(ast_size_outside_wrapper), - }, - sort_state: match enode { - LogicalPlanLanguage::Sort(_) => SortState::Current, - _ => SortState::None, - }, + let limit = match node { + LogicalPlanLanguage::Limit(_) => SortState::DirectChild, + LogicalPlanLanguage::Sort(_) if self.limit == SortState::DirectChild => { + SortState::Current + } + _ => SortState::None, }; - let res = enode - .children() - .iter() - .fold(initial_cost.clone(), |cost, id| { - let child = costs(*id); - cost.add_child(&child) - }) - .finalize(enode); - res + + Self { wrapped, limit } + } +} + +impl TopDownCostFunction for BestCubePlan { + fn cost(&self, node: &LogicalPlanLanguage) -> CubePlanCost { + self.initial_cost(node, true) + } + + fn finalize( + &self, + cost: CubePlanCost, + node: &LogicalPlanLanguage, + state: &CubePlanTopDownState, + ) -> CubePlanCost { + CubePlanCost::finalize(&cost, &state.wrapped, &state.limit, node, true) } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs index 2f693e8c7dda2..d7cf6cf591a95 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs @@ -4,7 +4,7 @@ use crate::{ rewrite::{ analysis::LogicalPlanAnalysis, converter::LanguageToLogicalPlanConverter, - cost::BestCubePlan, + cost::{BestCubePlan, CubePlanTopDownState, TopDownExtractor}, rules::{ case::CaseRules, common::CommonRules, dates::DateRules, filters::FilterRules, flatten::FlattenRules, members::MemberRules, old_split::OldSplitRules, @@ -314,6 +314,7 @@ impl Rewriter { auth_context: AuthContextRef, qtrace: &mut Option, span_id: Option>, + top_down_extractor: bool, ) -> Result { let cube_context = self.cube_context.clone(); let egraph = self.graph.clone(); @@ -337,9 +338,26 @@ impl Rewriter { let (runner, qtrace_egraph_iterations) = Self::run_rewrites(&cube_context, egraph, rules, "final")?; - let extractor = - Extractor::new(&runner.egraph, BestCubePlan::new(cube_context.meta.clone())); - let (best_cost, best) = extractor.find_best(root); + let best = if top_down_extractor { + let mut extractor = TopDownExtractor::new( + &runner.egraph, + BestCubePlan::new(cube_context.meta.clone()), + CubePlanTopDownState::new(), + ); + let Some((best_cost, best)) = extractor.find_best(root) else { + return Err(CubeError::internal("Unable to find best plan".to_string())); + }; + log::debug!("Best cost: {:?}", best_cost); + best + } else { + let extractor = Extractor::new( + &runner.egraph, + BestCubePlan::new(cube_context.meta.clone()), + ); + let (best_cost, best) = extractor.find_best(root); + log::debug!("Best cost: {:?}", best_cost); + best + }; let qtrace_best_graph = if Qtrace::is_enabled() { best.as_ref().iter().cloned().collect() } else { @@ -354,14 +372,13 @@ impl Rewriter { .map(|(i, n)| format!("{}: {:?}", i, n)) .join(", ") ); - log::debug!("Best cost: {:?}", best_cost); let converter = LanguageToLogicalPlanConverter::new( best, cube_context.clone(), auth_context, span_id.clone(), ); - Ok::<_, CubeError>(( + Ok(( converter.to_logical_plan(new_root), qtrace_egraph_iterations, qtrace_best_graph, @@ -501,6 +518,12 @@ impl Rewriter { .unwrap_or(true) } + pub fn top_down_extractor_enabled() -> bool { + env::var("CUBESQL_TOP_DOWN_EXTRACTOR") + .map(|v| v.to_lowercase() != "false") + .unwrap_or(true) + } + pub fn rewrite_rules( meta_context: Arc, config_obj: Arc, diff --git a/rust/cubesql/cubesql/src/config/mod.rs b/rust/cubesql/cubesql/src/config/mod.rs index a3ed43d99077d..9f49884f26488 100644 --- a/rust/cubesql/cubesql/src/config/mod.rs +++ b/rust/cubesql/cubesql/src/config/mod.rs @@ -115,6 +115,8 @@ pub trait ConfigObj: DIService + Debug { fn max_sessions(&self) -> usize; fn no_implicit_order(&self) -> bool; + + fn top_down_extractor(&self) -> bool; } #[derive(Debug, Clone)] @@ -135,6 +137,7 @@ pub struct ConfigObjImpl { pub non_streaming_query_max_row_limit: i32, pub max_sessions: usize, pub no_implicit_order: bool, + pub top_down_extractor: bool, } impl ConfigObjImpl { @@ -172,6 +175,7 @@ impl ConfigObjImpl { non_streaming_query_max_row_limit: env_parse("CUBEJS_DB_QUERY_LIMIT", 50000), max_sessions: env_parse("CUBEJS_MAX_SESSIONS", 1024), no_implicit_order: env_parse("CUBESQL_SQL_NO_IMPLICIT_ORDER", true), + top_down_extractor: env_parse("CUBESQL_TOP_DOWN_EXTRACTOR", true), } } } @@ -238,6 +242,10 @@ impl ConfigObj for ConfigObjImpl { fn max_sessions(&self) -> usize { self.max_sessions } + + fn top_down_extractor(&self) -> bool { + self.top_down_extractor + } } impl Config { @@ -270,6 +278,7 @@ impl Config { non_streaming_query_max_row_limit: 50000, max_sessions: 1024, no_implicit_order: true, + top_down_extractor: true, }), } } diff --git a/rust/cubesqlplanner/Cargo.lock b/rust/cubesqlplanner/Cargo.lock index 73e0b3b849fdc..834e1ed12c0ff 100644 --- a/rust/cubesqlplanner/Cargo.lock +++ b/rust/cubesqlplanner/Cargo.lock @@ -730,6 +730,7 @@ dependencies = [ "futures-core", "futures-util", "hashbrown 0.14.5", + "indexmap 1.9.3", "itertools", "log", "lru",