Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17655,4 +17655,47 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
assert!(sql.contains("TRUNCATE(EXTRACT(month FROM "));
}

/// Plans a filtered select with a cast on `id` and checks the request of the
/// resulting cube scan: two dimensions, one `equals` filter on `id`, ungrouped.
#[tokio::test]
async fn test_top_down_extractor_cache() {
    // The assertion below is only exercised when SQL push down is enabled.
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_testing_logger();

    let query = r#"
SELECT
id::integer AS id,
customer_gender
FROM KibanaSampleDataEcommerce
WHERE id = 5
"#
    .to_string();

    let plan = convert_select_to_query_plan(query, DatabaseProtocol::PostgreSQL)
        .await
        .as_logical_plan();

    // Fetch the actual request first, so the plan is inspected before the
    // expected value is built (same order as a direct `assert_eq!` call).
    let actual_request = plan.find_cube_scan().request;

    // The only filter expected in the request: `id` equals "5".
    let id_filter = V1LoadRequestQueryFilterItem {
        member: Some("KibanaSampleDataEcommerce.id".to_string()),
        operator: Some("equals".to_string()),
        values: Some(vec!["5".to_string()]),
        ..Default::default()
    };
    let expected_request = V1LoadRequestQuery {
        measures: Some(vec![]),
        dimensions: Some(vec![
            "KibanaSampleDataEcommerce.id".to_string(),
            "KibanaSampleDataEcommerce.customer_gender".to_string(),
        ]),
        segments: Some(vec![]),
        order: Some(vec![]),
        filters: Some(vec![id_filter]),
        ungrouped: Some(true),
        ..Default::default()
    };

    assert_eq!(actual_request, expected_request)
}
}
49 changes: 45 additions & 4 deletions rust/cubesql/cubesql/src/compile/rewrite/cost.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc};
use std::{
collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, mem::take, sync::Arc,
};

use crate::{
compile::rewrite::{
Expand Down Expand Up @@ -531,6 +533,9 @@ where
egraph: &'a EGraph<L, A>,
// Caches results. `None` for nodes in progress to prevent recursion
extract_map: HashMap<IdWithState<L, S>, Option<(usize, C)>>,
has_deep_recursion: bool,
// Cache for second pass to calculate recursive nodes cost.
extract_map_recursive_cache: Option<HashMap<IdWithState<L, S>, Option<(usize, C)>>>,
cost_fn: Arc<CF>,
root_state: Arc<S>,
}
Expand All @@ -547,6 +552,8 @@ where
Self {
egraph,
extract_map: HashMap::new(),
has_deep_recursion: false,
extract_map_recursive_cache: None,
cost_fn: Arc::new(cost_fn),
root_state: Arc::new(root_state),
}
Expand All @@ -555,8 +562,17 @@ where
/// Returns cost and path for best plan for provided root eclass.
///
/// If all nodes happen to be recursive, returns `None`.
///
/// If any node with deep recursion was encountered, the cost is calculated in two passes:
/// on the second pass, recursive nodes (and only those) take their cost
/// from the extract map obtained on the first pass.
pub fn find_best(&mut self, root: Id) -> Option<(C, RecExpr<L>)> {
let cost = self.extract(root, Arc::clone(&self.root_state))?;
let mut cost = self.extract(root, Arc::clone(&self.root_state))?;
if self.has_deep_recursion {
self.extract_map_recursive_cache = Some(take(&mut self.extract_map));
cost = self.extract(root, Arc::clone(&self.root_state))?;
}

let root_id_with_state = IdWithState::new(root, Arc::clone(&self.root_state));
let root_node = self.choose_node(&root_id_with_state)?;
let recexpr =
Expand All @@ -569,14 +585,33 @@ where
/// Recursively extracts the cost of each node in the eclass
/// and returns cost of the node with least cost based on the passed state,
/// caching the cost together with node index inside eclass in `extract_map`.
/// If `extract_map_recursive_cache` is available, fetches the costs
/// of deep recursion nodes from there.
///
/// Yields `None` if eclass is already in progress
/// or all its nodes happen to be recursive.
fn extract(&mut self, eclass: Id, state: Arc<S>) -> Option<C> {
let id_with_state = IdWithState::new(eclass, state);
if let Some(cached_index_and_cost) = self.extract_map.get(&id_with_state) {
// TODO: avoid cloning here?
return cached_index_and_cost.as_ref().map(|(_, cost)| cost.clone());
// If the cost has been computed, return it
if let Some((_, cached_cost)) = cached_index_and_cost {
// TODO: avoid cloning here?
return Some(cached_cost.clone());
}

// If the cost is recursive, fetch from recursive cache if available
if let Some(extract_map_recursive_cache) = &self.extract_map_recursive_cache {
if let Some(Some((_, cached_cost))) =
extract_map_recursive_cache.get(&id_with_state)
{
// TODO: avoid cloning here?
return Some(cached_cost.clone());
}
}

// Otherwise, mark this extractor as having deep recursion
self.has_deep_recursion = true;
return None;
}

// Mark this eclass as in progress
Expand All @@ -595,6 +630,12 @@ where
// Recursively get children cost
let mut total_node_cost = this_node_cost;
for child in node.children() {
// If a child refers back to this very eclass, skip this node:
// its cost could never be computed
if child == &eclass {
continue 'nodes;
}

let Some(child_cost) = self.extract(*child, Arc::clone(&new_state)) else {
// This path is inevitably recursive, try the next node
continue 'nodes;
Expand Down
Loading