diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 87fc8190a9008..26d707cca51f2 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -17655,4 +17655,47 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("TRUNCATE(EXTRACT(month FROM ")); } + + #[tokio::test] + async fn test_top_down_extractor_cache() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let logical_plan = convert_select_to_query_plan( + r#" + SELECT + id::integer AS id, + customer_gender + FROM KibanaSampleDataEcommerce + WHERE id = 5 + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await + .as_logical_plan(); + + assert_eq!( + logical_plan.find_cube_scan().request, + V1LoadRequestQuery { + measures: Some(vec![]), + dimensions: Some(vec![ + "KibanaSampleDataEcommerce.id".to_string(), + "KibanaSampleDataEcommerce.customer_gender".to_string(), + ]), + segments: Some(vec![]), + order: Some(vec![]), + filters: Some(vec![V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.id".to_string()), + operator: Some("equals".to_string()), + values: Some(vec!["5".to_string()]), + ..Default::default() + }]), + ungrouped: Some(true), + ..Default::default() + } + ) + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs index baa49a23dca7f..16ffed9991a53 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs @@ -1,4 +1,6 @@ -use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc}; +use std::{ + collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, mem::take, sync::Arc, +}; use crate::{ compile::rewrite::{ @@ -531,6 +533,9 @@ where egraph: &'a EGraph, // Caches results. `None` for nodes in progress to prevent recursion extract_map: HashMap, Option<(usize, C)>>, + has_deep_recursion: bool, + // Cache for second pass to calculate recursive nodes cost. + extract_map_recursive_cache: Option, Option<(usize, C)>>>, cost_fn: Arc, root_state: Arc, } @@ -547,6 +552,8 @@ where Self { egraph, extract_map: HashMap::new(), + has_deep_recursion: false, + extract_map_recursive_cache: None, cost_fn: Arc::new(cost_fn), root_state: Arc::new(root_state), } @@ -555,8 +562,17 @@ where /// Returns cost and path for best plan for provided root eclass. /// /// If all nodes happen to be recursive, returns `None`. + /// + /// If there were any nodes with deep recursion, the cost is calculated in two passes; + /// the second pass fetches the cost from the extract map obtained on the first pass + /// for recursive nodes only. pub fn find_best(&mut self, root: Id) -> Option<(C, RecExpr)> { - let cost = self.extract(root, Arc::clone(&self.root_state))?; + let mut cost = self.extract(root, Arc::clone(&self.root_state))?; + if self.has_deep_recursion { + self.extract_map_recursive_cache = Some(take(&mut self.extract_map)); + cost = self.extract(root, Arc::clone(&self.root_state))?; + } + let root_id_with_state = IdWithState::new(root, Arc::clone(&self.root_state)); let root_node = self.choose_node(&root_id_with_state)?; let recexpr = @@ -569,14 +585,33 @@ where /// Recursively extracts the cost of each node in the eclass /// and returns cost of the node with least cost based on the passed state, /// caching the cost together with node index inside eclass in `extract_map`. + /// If `extract_map_recursive_cache` is available, fetches the costs + /// of deep recursion nodes from there. /// /// Yields `None` if eclass is already in progress /// or all its nodes happen to be recursive. fn extract(&mut self, eclass: Id, state: Arc) -> Option { let id_with_state = IdWithState::new(eclass, state); if let Some(cached_index_and_cost) = self.extract_map.get(&id_with_state) { - // TODO: avoid cloning here? - return cached_index_and_cost.as_ref().map(|(_, cost)| cost.clone()); + // If the cost has been computed, return it + if let Some((_, cached_cost)) = cached_index_and_cost { + // TODO: avoid cloning here? + return Some(cached_cost.clone()); + } + + // If the cost is recursive, fetch from recursive cache if available + if let Some(extract_map_recursive_cache) = &self.extract_map_recursive_cache { + if let Some(Some((_, cached_cost))) = + extract_map_recursive_cache.get(&id_with_state) + { + // TODO: avoid cloning here? + return Some(cached_cost.clone()); + } + } + + // Otherwise, mark this extractor as having deep recursion + self.has_deep_recursion = true; + return None; } // Mark this eclass as in progress @@ -595,6 +630,12 @@ where // Recursively get children cost let mut total_node_cost = this_node_cost; for child in node.children() { + // If a child is recursive to self, skip this node, as it will never compute + // the cost + if child == &eclass { + continue 'nodes; + } + let Some(child_cost) = self.extract(*child, Arc::clone(&new_state)) else { // This path is inevitably recursive, try the next node continue 'nodes;