Skip to content

Commit 4659f8c

Browse files
committed
refactor(cubesql): Extract cube join condition check for rewrites to function
1 parent 5fd13d1 commit 4659f8c

File tree

1 file changed

+87
-71
lines changed
  • rust/cubesql/cubesql/src/compile/rewrite/rules

1 file changed

+87
-71
lines changed

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

Lines changed: 87 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2713,77 +2713,14 @@ impl MemberRules {
27132713
let left_aliases_var = var!(left_aliases_var);
27142714
let right_aliases_var = var!(right_aliases_var);
27152715
move |egraph, subst| {
2716-
if egraph
2717-
.index(subst[left_aliases_var])
2718-
.data
2719-
.member_name_to_expr
2720-
.is_some()
2721-
{
2722-
if egraph
2723-
.index(subst[right_aliases_var])
2724-
.data
2725-
.member_name_to_expr
2726-
.is_some()
2727-
{
2728-
let left_join_ons: Vec<Vec<_>> =
2729-
var_iter!(egraph[subst[left_on_var]], JoinLeftOn)
2730-
.map(|elem| elem.iter().cloned().collect())
2731-
.collect();
2732-
for left_join_on in left_join_ons {
2733-
for join_on in left_join_on {
2734-
let member_names_to_expr_left = &mut egraph
2735-
.index_mut(subst[left_aliases_var])
2736-
.data
2737-
.member_name_to_expr
2738-
.as_mut()
2739-
.unwrap();
2740-
2741-
// TODO: Avoid the join_on.*.clone() calls (should be trivial).
2742-
let mut column_name = join_on.name.clone();
2743-
if let Some(name) = find_column_by_alias(
2744-
&column_name,
2745-
member_names_to_expr_left,
2746-
&join_on.relation.clone().unwrap_or_default(),
2747-
) {
2748-
column_name = name.split(".").last().unwrap().to_string();
2749-
}
2750-
2751-
if column_name == "__cubeJoinField" {
2752-
let right_join_ons: Vec<Vec<_>> =
2753-
var_iter!(egraph[subst[right_on_var]], JoinRightOn)
2754-
.map(|elem| elem.iter().cloned().collect())
2755-
.collect();
2756-
for right_join_on in right_join_ons {
2757-
for join_on in right_join_on.iter() {
2758-
let member_names_to_expr_right = &mut egraph
2759-
.index_mut(subst[right_aliases_var])
2760-
.data
2761-
.member_name_to_expr
2762-
.as_mut()
2763-
.unwrap();
2764-
2765-
let mut column_name = join_on.name.clone();
2766-
if let Some(name) = find_column_by_alias(
2767-
&column_name,
2768-
member_names_to_expr_right,
2769-
&join_on.relation.clone().unwrap_or_default(),
2770-
) {
2771-
column_name =
2772-
name.split(".").last().unwrap().to_string();
2773-
}
2774-
2775-
if column_name == "__cubeJoinField" {
2776-
return true;
2777-
}
2778-
}
2779-
}
2780-
}
2781-
}
2782-
}
2783-
}
2784-
}
2785-
2786-
false
2716+
is_proper_cube_join_condition(
2717+
egraph,
2718+
subst,
2719+
left_aliases_var,
2720+
left_on_var,
2721+
right_aliases_var,
2722+
right_on_var,
2723+
)
27872724
}
27882725
}
27892726

@@ -2961,6 +2898,85 @@ fn find_column_by_alias(
29612898
None
29622899
}
29632900

2901+
fn is_proper_cube_join_condition(
2902+
egraph: &mut CubeEGraph,
2903+
subst: &Subst,
2904+
left_cube_members_var: Var,
2905+
left_on_var: Var,
2906+
right_cube_members_var: Var,
2907+
right_on_var: Var,
2908+
) -> bool {
2909+
if egraph
2910+
.index(subst[left_cube_members_var])
2911+
.data
2912+
.member_name_to_expr
2913+
.is_some()
2914+
{
2915+
if egraph
2916+
.index(subst[right_cube_members_var])
2917+
.data
2918+
.member_name_to_expr
2919+
.is_some()
2920+
{
2921+
let left_join_ons: Vec<Vec<_>> = var_iter!(egraph[subst[left_on_var]], JoinLeftOn)
2922+
.map(|elem| elem.iter().cloned().collect())
2923+
.collect();
2924+
for left_join_on in left_join_ons {
2925+
for join_on in left_join_on {
2926+
let member_names_to_expr_left = &mut egraph
2927+
.index_mut(subst[left_cube_members_var])
2928+
.data
2929+
.member_name_to_expr
2930+
.as_mut()
2931+
.unwrap();
2932+
2933+
// TODO: Avoid the join_on.*.clone() calls (should be trivial).
2934+
let mut column_name = join_on.name.clone();
2935+
if let Some(name) = find_column_by_alias(
2936+
&column_name,
2937+
member_names_to_expr_left,
2938+
&join_on.relation.clone().unwrap_or_default(),
2939+
) {
2940+
column_name = name.split(".").last().unwrap().to_string();
2941+
}
2942+
2943+
if column_name == "__cubeJoinField" {
2944+
let right_join_ons: Vec<Vec<_>> =
2945+
var_iter!(egraph[subst[right_on_var]], JoinRightOn)
2946+
.map(|elem| elem.iter().cloned().collect())
2947+
.collect();
2948+
for right_join_on in right_join_ons {
2949+
for join_on in right_join_on.iter() {
2950+
let member_names_to_expr_right = &mut egraph
2951+
.index_mut(subst[right_cube_members_var])
2952+
.data
2953+
.member_name_to_expr
2954+
.as_mut()
2955+
.unwrap();
2956+
2957+
let mut column_name = join_on.name.clone();
2958+
if let Some(name) = find_column_by_alias(
2959+
&column_name,
2960+
member_names_to_expr_right,
2961+
&join_on.relation.clone().unwrap_or_default(),
2962+
) {
2963+
column_name = name.split(".").last().unwrap().to_string();
2964+
}
2965+
2966+
if column_name == "__cubeJoinField" {
2967+
return true;
2968+
}
2969+
}
2970+
}
2971+
}
2972+
}
2973+
}
2974+
}
2975+
}
2976+
2977+
false
2978+
}
2979+
29642980
#[cfg(test)]
29652981
mod tests {
29662982
use super::*;

0 commit comments

Comments
 (0)