diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index d375e9cb03de0..d2e19faa8e41a 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -17168,4 +17168,59 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), } ) } + + #[tokio::test] + async fn test_push_down_limit_sort_projection() { + init_testing_logger(); + + let logical_plan = convert_select_to_query_plan( + r#" + SELECT + "ta_1"."customer_gender" AS "ca_1", + DATE_TRUNC('MONTH', CAST("ta_1"."order_date" AS date)) AS "ca_2", + COALESCE(sum("ta_1"."sumPrice"), 0) AS "ca_3" + FROM + "db"."public"."KibanaSampleDataEcommerce" AS "ta_1" + WHERE + ( + "ta_1"."order_date" >= TIMESTAMP '2024-01-01 00:00:00.0' + AND "ta_1"."order_date" < TIMESTAMP '2025-01-01 00:00:00.0' + ) + GROUP BY + "ca_1", + "ca_2" + ORDER BY + "ca_2" ASC NULLS LAST + LIMIT + 5000 + ;"# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await + .as_logical_plan(); + + assert_eq!( + logical_plan.find_cube_scan().request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.sumPrice".to_string()]), + dimensions: Some(vec!["KibanaSampleDataEcommerce.customer_gender".to_string()]), + segments: Some(vec![]), + time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension { + dimension: "KibanaSampleDataEcommerce.order_date".to_string(), + granularity: Some("month".to_string()), + date_range: Some(json!(vec![ + "2024-01-01T00:00:00.000Z".to_string(), + "2024-12-31T23:59:59.999Z".to_string() + ])), + },]), + order: Some(vec![vec![ + "KibanaSampleDataEcommerce.order_date".to_string(), + "asc".to_string(), + ]]), + limit: Some(5000), + ..Default::default() + } + ) + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs index 0c338ee1d3b5e..baa49a23dca7f 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs @@ -51,11 +51,6 @@ impl BestCubePlan { _ => 0, }; - let non_pushed_down_limit_sort = match enode { - LogicalPlanLanguage::Sort(_) => 1, - _ => 0, - }; - let ast_size_inside_wrapper = match enode { LogicalPlanLanguage::WrappedSelect(_) => 1, _ => 0, @@ -130,6 +125,8 @@ impl BestCubePlan { LogicalPlanLanguage::JoinCheckStage(_) => 1, LogicalPlanLanguage::JoinCheckPushDown(_) => 1, LogicalPlanLanguage::JoinCheckPullUp(_) => 1, + LogicalPlanLanguage::SortProjectionPushdownReplacer(_) => 1, + LogicalPlanLanguage::SortProjectionPullupReplacer(_) => 1, // Not really replacers but those should be deemed as mandatory rewrites and as soon as // there's always rewrite rule it's fine to have replacer cost. // Needs to be added as alias rewrite always more expensive than original function. @@ -220,7 +217,8 @@ impl BestCubePlan { member_errors, non_pushed_down_window, non_pushed_down_grouping_sets, - non_pushed_down_limit_sort, + // Will be filled in finalize + non_pushed_down_limit_sort: 0, zero_members_wrapper, cube_members, errors: this_errors, @@ -405,9 +403,8 @@ impl CubePlanCost { CubePlanState::Wrapper => 0, }, non_pushed_down_limit_sort: match sort_state { - SortState::DirectChild => self.non_pushed_down_limit_sort, - SortState::Current => self.non_pushed_down_limit_sort, - _ => 0, + SortState::Current => self.non_pushed_down_limit_sort + 1, + _ => self.non_pushed_down_limit_sort, }, // Don't track state here: we want representation that have fewer wrappers with zero members _in total_ zero_members_wrapper: self.zero_members_wrapper, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs index 7697e50090dd1..b9abce1be32fb 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs @@ -508,6 +508,13 @@ crate::plan_to_language! { members: Vec, alias_to_cube: Vec<(String, String)>, }, + SortProjectionPushdownReplacer { + expr: Arc, + column_to_expr: Vec<(Column, Expr)>, + }, + SortProjectionPullupReplacer { + expr: Arc, + }, EventNotification { name: String, members: Vec, @@ -2236,6 +2243,17 @@ fn join_check_pull_up(expr: impl Display, left: impl Display, right: impl Displa format!("(JoinCheckPullUp {expr} {left} {right})") } +fn sort_projection_pushdown_replacer(expr: impl Display, column_to_expr: impl Display) -> String { + format!( + "(SortProjectionPushdownReplacer {} {})", + expr, column_to_expr + ) +} + +fn sort_projection_pullup_replacer(expr: impl Display) -> String { + format!("(SortProjectionPullupReplacer {})", expr) +} + pub fn original_expr_name(egraph: &CubeEGraph, id: Id) -> Option { egraph[id] .data diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs index ae86c55dafb1c..96ae42f568fc3 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs @@ -477,7 +477,7 @@ impl Rewriter { eval_stable_functions, ), &DateRules::new(config_obj.clone()), - &OrderRules::new(), + &OrderRules::new(config_obj.clone()), &CommonRules::new(config_obj.clone()), ]; let mut rewrites = Vec::new(); diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/order.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/order.rs index 9dfa0f5550f0f..18ee390431d6e 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/order.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/order.rs @@ -1,18 +1,33 @@ +use std::{ + ops::{Index, IndexMut}, + sync::Arc, +}; + +use egg::Subst; + use crate::{ - compile::rewrite::{ - analysis::OriginalExpr, - column_name_to_member_vec, cube_scan, cube_scan_order, cube_scan_order_empty_tail, - expr_column_name, order, order_replacer, referenced_columns, rewrite, - rewriter::{CubeEGraph, CubeRewrite, RewriteRules}, - sort, sort_exp, sort_exp_empty_tail, sort_expr, transforming_rewrite, LogicalPlanLanguage, - OrderAsc, OrderMember, OrderReplacerColumnNameToMember, SortExprAsc, + compile::{ + datafusion::logical_plan::Column, + rewrite::{ + analysis::OriginalExpr, + column_expr, column_name_to_member_vec, + converter::LogicalPlanToLanguageConverter, + cube_scan, cube_scan_order, cube_scan_order_empty_tail, expr_column_name, limit, order, + order_replacer, projection, referenced_columns, rewrite, + rewriter::{CubeEGraph, CubeRewrite, RewriteRules}, + sort, sort_exp, sort_exp_empty_tail, sort_expr, sort_projection_pullup_replacer, + sort_projection_pushdown_replacer, transforming_rewrite, ColumnExprColumn, + LogicalPlanLanguage, OrderAsc, OrderMember, OrderReplacerColumnNameToMember, + ProjectionAlias, SortExprAsc, SortProjectionPushdownReplacerColumnToExpr, + }, }, + config::ConfigObj, var, var_iter, }; -use egg::Subst; -use std::ops::{Index, IndexMut}; -pub struct OrderRules {} +pub struct OrderRules { + config_obj: Arc, +} impl RewriteRules for OrderRules { fn rewrite_rules(&self) -> Vec { @@ -70,13 +85,107 @@ impl RewriteRules for OrderRules { order_replacer(sort_exp_empty_tail(), "?aliases"), cube_scan_order_empty_tail(), ), + // TODO: refactor this rule to `push-down-sort-projection`, + // possibly adjust cost function to penalize Limit-...-Sort plan + transforming_rewrite( + "push-down-limit-sort-projection", + limit( + "?skip", + "?fetch", + sort( + "?sort_expr", + projection( + "?projection_expr", + "?input", + "?projection_alias", + "?projection_split", + ), + ), + ), + projection( + "?projection_expr", + limit( + "?skip", + "?fetch", + sort( + sort_projection_pushdown_replacer("?sort_expr", "?column_to_expr"), + "?input", + ), + ), + "?projection_alias", + "?projection_split", + ), + self.push_down_limit_sort_projection( + "?input", + "?projection_expr", + "?projection_alias", + "?column_to_expr", + ), + ), + rewrite( + "sort-projection-replacer-pull-up-sort", + sort(sort_projection_pullup_replacer("?expr"), "?input"), + sort("?expr", "?input"), + ), + rewrite( + "sort-projection-replacer-push-down-sortexp", + sort_projection_pushdown_replacer(sort_exp("?left", "?right"), "?column_to_expr"), + sort_exp( + sort_projection_pushdown_replacer("?left", "?column_to_expr"), + sort_projection_pushdown_replacer("?right", "?column_to_expr"), + ), + ), + rewrite( + "sort-projection-replacer-push-down-sortexp-tail", + sort_projection_pushdown_replacer(sort_exp_empty_tail(), "?column_to_expr"), + sort_projection_pullup_replacer(sort_exp_empty_tail()), + ), + rewrite( + "sort-projection-replacer-pull-up-sortexp", + sort_exp( + sort_projection_pullup_replacer("?left"), + sort_projection_pullup_replacer("?right"), + ), + sort_projection_pullup_replacer(sort_exp("?left", "?right")), + ), + rewrite( + "sort-projection-replacer-push-down-sortexpr", + sort_projection_pushdown_replacer( + sort_expr("?expr", "?asc", "?nulls_first"), + "?column_to_expr", + ), + sort_expr( + sort_projection_pushdown_replacer("?expr", "?column_to_expr"), + "?asc", + "?nulls_first", + ), + ), + rewrite( + "sort-projection-replacer-pull-up-sortexpr", + sort_expr( + sort_projection_pullup_replacer("?expr"), + "?asc", + "?nulls_first", + ), + sort_projection_pullup_replacer(sort_expr("?expr", "?asc", "?nulls_first")), + ), + transforming_rewrite( + "sort-projection-replacer-push-down-column", + sort_projection_pushdown_replacer(column_expr("?column"), "?column_to_expr"), + sort_projection_pullup_replacer("?new_expr"), + self.sort_projection_replacer_push_down_column( + "?column", + "?column_to_expr", + "?new_expr", + ), + ), ] } } impl OrderRules { - pub fn new() -> Self { - Self {} + pub fn new(config_obj: Arc) -> Self { + Self { config_obj } } fn push_down_sort( @@ -173,4 +282,107 @@ impl OrderRules { false } } + + fn push_down_limit_sort_projection( + &self, + input_var: &'static str, + projection_expr_var: &'static str, + projection_alias_var: &'static str, + column_to_expr_var: &'static str, + ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { + let input_var = var!(input_var); + let projection_expr_var = var!(projection_expr_var); + let projection_alias_var = var!(projection_alias_var); + let column_to_expr_var = var!(column_to_expr_var); + move |egraph, subst| { + let input_is_sort_or_limit = egraph[subst[input_var]].nodes.iter().any(|node| { + matches!( + node, + LogicalPlanLanguage::Sort(_) | LogicalPlanLanguage::Limit(_) + ) + }); + if input_is_sort_or_limit { + return false; + } + + let Some(expr_to_alias) = egraph[subst[projection_expr_var]] + .data + .expr_to_alias + .as_deref() + else { + return false; + }; + + for projection_alias in var_iter!(egraph[subst[projection_alias_var]], ProjectionAlias) + { + let mut column_to_expr = vec![]; + for (expr, alias, _) in expr_to_alias { + let column = Column::from_name(alias); + column_to_expr.push((column, expr.clone())); + if let Some(projection_alias) = projection_alias.as_deref() { + let column = Column { + relation: Some(projection_alias.to_string()), + name: alias.to_string(), + }; + column_to_expr.push((column, expr.clone())); + } + } + + subst.insert( + column_to_expr_var, + egraph.add( + LogicalPlanLanguage::SortProjectionPushdownReplacerColumnToExpr( + SortProjectionPushdownReplacerColumnToExpr(column_to_expr), + ), + ), + ); + return true; + } + false + } + } + + fn sort_projection_replacer_push_down_column( + &self, + column_var: &'static str, + column_to_expr_var: &'static str, + new_expr_var: &'static str, + ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { + let column_var = var!(column_var); + let column_to_expr_var = var!(column_to_expr_var); + let new_expr_var = var!(new_expr_var); + let flat_list = self.config_obj.push_down_pull_up_split(); + move |egraph, subst| { + for old_column in var_iter!(egraph[subst[column_var]], ColumnExprColumn).cloned() { + for column_to_expr in var_iter!( + egraph[subst[column_to_expr_var]], + SortProjectionPushdownReplacerColumnToExpr + ) + .cloned() + { + let Some(expr) = column_to_expr.iter().find_map(|(column, expr)| { + if column == &old_column { + Some(expr) + } else { + None + } + }) else { + continue; + }; + + let Ok(new_expr_id) = + LogicalPlanToLanguageConverter::add_expr(egraph, expr, flat_list) + else { + // Insertion failure should never happen as it can be partial, + // so fail right away. + return false; + }; + + subst.insert(new_expr_var, new_expr_id); + return true; + } + } + false + } + } }