Skip to content

Commit 4c845c8

Browse files
committed
refactor(cubesql): Extract cube join condition check for rewrites to function
1 parent 9346f21 commit 4c845c8

File tree

1 file changed

+87
-71
lines changed
  • rust/cubesql/cubesql/src/compile/rewrite/rules

1 file changed

+87
-71
lines changed

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

Lines changed: 87 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,77 +2682,14 @@ impl MemberRules {
26822682
let left_aliases_var = var!(left_aliases_var);
26832683
let right_aliases_var = var!(right_aliases_var);
26842684
move |egraph, subst| {
2685-
if egraph
2686-
.index(subst[left_aliases_var])
2687-
.data
2688-
.member_name_to_expr
2689-
.is_some()
2690-
{
2691-
if egraph
2692-
.index(subst[right_aliases_var])
2693-
.data
2694-
.member_name_to_expr
2695-
.is_some()
2696-
{
2697-
let left_join_ons: Vec<Vec<_>> =
2698-
var_iter!(egraph[subst[left_on_var]], JoinLeftOn)
2699-
.map(|elem| elem.iter().cloned().collect())
2700-
.collect();
2701-
for left_join_on in left_join_ons {
2702-
for join_on in left_join_on {
2703-
let member_names_to_expr_left = &mut egraph
2704-
.index_mut(subst[left_aliases_var])
2705-
.data
2706-
.member_name_to_expr
2707-
.as_mut()
2708-
.unwrap();
2709-
2710-
// TODO: Avoid the join_on.*.clone() calls (should be trivial).
2711-
let mut column_name = join_on.name.clone();
2712-
if let Some(name) = find_column_by_alias(
2713-
&column_name,
2714-
member_names_to_expr_left,
2715-
&join_on.relation.clone().unwrap_or_default(),
2716-
) {
2717-
column_name = name.split(".").last().unwrap().to_string();
2718-
}
2719-
2720-
if column_name == "__cubeJoinField" {
2721-
let right_join_ons: Vec<Vec<_>> =
2722-
var_iter!(egraph[subst[right_on_var]], JoinRightOn)
2723-
.map(|elem| elem.iter().cloned().collect())
2724-
.collect();
2725-
for right_join_on in right_join_ons {
2726-
for join_on in right_join_on.iter() {
2727-
let member_names_to_expr_right = &mut egraph
2728-
.index_mut(subst[right_aliases_var])
2729-
.data
2730-
.member_name_to_expr
2731-
.as_mut()
2732-
.unwrap();
2733-
2734-
let mut column_name = join_on.name.clone();
2735-
if let Some(name) = find_column_by_alias(
2736-
&column_name,
2737-
member_names_to_expr_right,
2738-
&join_on.relation.clone().unwrap_or_default(),
2739-
) {
2740-
column_name =
2741-
name.split(".").last().unwrap().to_string();
2742-
}
2743-
2744-
if column_name == "__cubeJoinField" {
2745-
return true;
2746-
}
2747-
}
2748-
}
2749-
}
2750-
}
2751-
}
2752-
}
2753-
}
2754-
2755-
false
2685+
is_proper_cube_join_condition(
2686+
egraph,
2687+
subst,
2688+
left_aliases_var,
2689+
left_on_var,
2690+
right_aliases_var,
2691+
right_on_var,
2692+
)
27562693
}
27572694
}
27582695

@@ -2930,6 +2867,85 @@ fn find_column_by_alias(
29302867
None
29312868
}
29322869

2870+
fn is_proper_cube_join_condition(
2871+
egraph: &mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
2872+
subst: &Subst,
2873+
left_cube_members_var: Var,
2874+
left_on_var: Var,
2875+
right_cube_members_var: Var,
2876+
right_on_var: Var,
2877+
) -> bool {
2878+
if egraph
2879+
.index(subst[left_cube_members_var])
2880+
.data
2881+
.member_name_to_expr
2882+
.is_some()
2883+
{
2884+
if egraph
2885+
.index(subst[right_cube_members_var])
2886+
.data
2887+
.member_name_to_expr
2888+
.is_some()
2889+
{
2890+
let left_join_ons: Vec<Vec<_>> = var_iter!(egraph[subst[left_on_var]], JoinLeftOn)
2891+
.map(|elem| elem.iter().cloned().collect())
2892+
.collect();
2893+
for left_join_on in left_join_ons {
2894+
for join_on in left_join_on {
2895+
let member_names_to_expr_left = &mut egraph
2896+
.index_mut(subst[left_cube_members_var])
2897+
.data
2898+
.member_name_to_expr
2899+
.as_mut()
2900+
.unwrap();
2901+
2902+
// TODO: Avoid the join_on.*.clone() calls (should be trivial).
2903+
let mut column_name = join_on.name.clone();
2904+
if let Some(name) = find_column_by_alias(
2905+
&column_name,
2906+
member_names_to_expr_left,
2907+
&join_on.relation.clone().unwrap_or_default(),
2908+
) {
2909+
column_name = name.split(".").last().unwrap().to_string();
2910+
}
2911+
2912+
if column_name == "__cubeJoinField" {
2913+
let right_join_ons: Vec<Vec<_>> =
2914+
var_iter!(egraph[subst[right_on_var]], JoinRightOn)
2915+
.map(|elem| elem.iter().cloned().collect())
2916+
.collect();
2917+
for right_join_on in right_join_ons {
2918+
for join_on in right_join_on.iter() {
2919+
let member_names_to_expr_right = &mut egraph
2920+
.index_mut(subst[right_cube_members_var])
2921+
.data
2922+
.member_name_to_expr
2923+
.as_mut()
2924+
.unwrap();
2925+
2926+
let mut column_name = join_on.name.clone();
2927+
if let Some(name) = find_column_by_alias(
2928+
&column_name,
2929+
member_names_to_expr_right,
2930+
&join_on.relation.clone().unwrap_or_default(),
2931+
) {
2932+
column_name = name.split(".").last().unwrap().to_string();
2933+
}
2934+
2935+
if column_name == "__cubeJoinField" {
2936+
return true;
2937+
}
2938+
}
2939+
}
2940+
}
2941+
}
2942+
}
2943+
}
2944+
}
2945+
2946+
false
2947+
}
2948+
29332949
#[cfg(test)]
29342950
mod tests {
29352951
use super::*;

0 commit comments

Comments
 (0)