Skip to content

Commit f7f78e2

Browse files
authored
fix(cubesql): Correctly calculate eclass cost for recursive nodes (#9947)
1 parent a178744 commit f7f78e2

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17655,4 +17655,47 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1765517655
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
1765617656
assert!(sql.contains("TRUNCATE(EXTRACT(month FROM "));
1765717657
}
17658+
17659+
#[tokio::test]
17660+
async fn test_top_down_extractor_cache() {
17661+
if !Rewriter::sql_push_down_enabled() {
17662+
return;
17663+
}
17664+
init_testing_logger();
17665+
17666+
let logical_plan = convert_select_to_query_plan(
17667+
r#"
17668+
SELECT
17669+
id::integer AS id,
17670+
customer_gender
17671+
FROM KibanaSampleDataEcommerce
17672+
WHERE id = 5
17673+
"#
17674+
.to_string(),
17675+
DatabaseProtocol::PostgreSQL,
17676+
)
17677+
.await
17678+
.as_logical_plan();
17679+
17680+
assert_eq!(
17681+
logical_plan.find_cube_scan().request,
17682+
V1LoadRequestQuery {
17683+
measures: Some(vec![]),
17684+
dimensions: Some(vec![
17685+
"KibanaSampleDataEcommerce.id".to_string(),
17686+
"KibanaSampleDataEcommerce.customer_gender".to_string(),
17687+
]),
17688+
segments: Some(vec![]),
17689+
order: Some(vec![]),
17690+
filters: Some(vec![V1LoadRequestQueryFilterItem {
17691+
member: Some("KibanaSampleDataEcommerce.id".to_string()),
17692+
operator: Some("equals".to_string()),
17693+
values: Some(vec!["5".to_string()]),
17694+
..Default::default()
17695+
}]),
17696+
ungrouped: Some(true),
17697+
..Default::default()
17698+
}
17699+
)
17700+
}
1765817701
}

rust/cubesql/cubesql/src/compile/rewrite/cost.rs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, sync::Arc};
1+
use std::{
2+
collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, mem::take, sync::Arc,
3+
};
24

35
use crate::{
46
compile::rewrite::{
@@ -531,6 +533,9 @@ where
531533
egraph: &'a EGraph<L, A>,
532534
// Caches results. `None` for nodes in progress to prevent recursion
533535
extract_map: HashMap<IdWithState<L, S>, Option<(usize, C)>>,
536+
has_deep_recursion: bool,
537+
// Cache for second pass to calculate recursive nodes cost.
538+
extract_map_recursive_cache: Option<HashMap<IdWithState<L, S>, Option<(usize, C)>>>,
534539
cost_fn: Arc<CF>,
535540
root_state: Arc<S>,
536541
}
@@ -547,6 +552,8 @@ where
547552
Self {
548553
egraph,
549554
extract_map: HashMap::new(),
555+
has_deep_recursion: false,
556+
extract_map_recursive_cache: None,
550557
cost_fn: Arc::new(cost_fn),
551558
root_state: Arc::new(root_state),
552559
}
@@ -555,8 +562,17 @@ where
555562
/// Returns cost and path for best plan for provided root eclass.
556563
///
557564
/// If all nodes happen to be recursive, returns `None`.
565+
///
566+
/// If there were any nodes with deep recursion, the cost is calculated in two passes;
567+
/// the second pass fetches the cost from the extract map obtained on the first pass
568+
/// for recursive nodes only.
558569
pub fn find_best(&mut self, root: Id) -> Option<(C, RecExpr<L>)> {
559-
let cost = self.extract(root, Arc::clone(&self.root_state))?;
570+
let mut cost = self.extract(root, Arc::clone(&self.root_state))?;
571+
if self.has_deep_recursion {
572+
self.extract_map_recursive_cache = Some(take(&mut self.extract_map));
573+
cost = self.extract(root, Arc::clone(&self.root_state))?;
574+
}
575+
560576
let root_id_with_state = IdWithState::new(root, Arc::clone(&self.root_state));
561577
let root_node = self.choose_node(&root_id_with_state)?;
562578
let recexpr =
@@ -569,14 +585,33 @@ where
569585
/// Recursively extracts the cost of each node in the eclass
570586
/// and returns cost of the node with least cost based on the passed state,
571587
/// caching the cost together with node index inside eclass in `extract_map`.
588+
/// If `extract_map_recursive_cache` is available, fetches the costs
589+
/// of deep recursion nodes from there.
572590
///
573591
/// Yields `None` if eclass is already in progress
574592
/// or all its nodes happen to be recursive.
575593
fn extract(&mut self, eclass: Id, state: Arc<S>) -> Option<C> {
576594
let id_with_state = IdWithState::new(eclass, state);
577595
if let Some(cached_index_and_cost) = self.extract_map.get(&id_with_state) {
578-
// TODO: avoid cloning here?
579-
return cached_index_and_cost.as_ref().map(|(_, cost)| cost.clone());
596+
// If the cost has been computed, return it
597+
if let Some((_, cached_cost)) = cached_index_and_cost {
598+
// TODO: avoid cloning here?
599+
return Some(cached_cost.clone());
600+
}
601+
602+
// If the cost is recursive, fetch from recursive cache if available
603+
if let Some(extract_map_recursive_cache) = &self.extract_map_recursive_cache {
604+
if let Some(Some((_, cached_cost))) =
605+
extract_map_recursive_cache.get(&id_with_state)
606+
{
607+
// TODO: avoid cloning here?
608+
return Some(cached_cost.clone());
609+
}
610+
}
611+
612+
// Otherwise, mark this extractor as having deep recursion
613+
self.has_deep_recursion = true;
614+
return None;
580615
}
581616

582617
// Mark this eclass as in progress
@@ -595,6 +630,12 @@ where
595630
// Recursively get children cost
596631
let mut total_node_cost = this_node_cost;
597632
for child in node.children() {
633+
// If a child is recursive to self, skip this node, as it will never compute
634+
// the cost
635+
if child == &eclass {
636+
continue 'nodes;
637+
}
638+
598639
let Some(child_cost) = self.extract(*child, Arc::clone(&new_state)) else {
599640
// This path is inevitably recursive, try the next node
600641
continue 'nodes;

0 commit comments

Comments
 (0)