diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs index 1c123c1242fe3..6cc5cc4fcd9e3 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs @@ -2713,77 +2713,14 @@ impl MemberRules { let left_aliases_var = var!(left_aliases_var); let right_aliases_var = var!(right_aliases_var); move |egraph, subst| { - if egraph - .index(subst[left_aliases_var]) - .data - .member_name_to_expr - .is_some() - { - if egraph - .index(subst[right_aliases_var]) - .data - .member_name_to_expr - .is_some() - { - let left_join_ons: Vec> = - var_iter!(egraph[subst[left_on_var]], JoinLeftOn) - .map(|elem| elem.iter().cloned().collect()) - .collect(); - for left_join_on in left_join_ons { - for join_on in left_join_on { - let member_names_to_expr_left = &mut egraph - .index_mut(subst[left_aliases_var]) - .data - .member_name_to_expr - .as_mut() - .unwrap(); - - // TODO: Avoid the join_on.*.clone() calls (should be trivial). - let mut column_name = join_on.name.clone(); - if let Some(name) = find_column_by_alias( - &column_name, - member_names_to_expr_left, - &join_on.relation.clone().unwrap_or_default(), - ) { - column_name = name.split(".").last().unwrap().to_string(); - } - - if column_name == "__cubeJoinField" { - let right_join_ons: Vec> = - var_iter!(egraph[subst[right_on_var]], JoinRightOn) - .map(|elem| elem.iter().cloned().collect()) - .collect(); - for right_join_on in right_join_ons { - for join_on in right_join_on.iter() { - let member_names_to_expr_right = &mut egraph - .index_mut(subst[right_aliases_var]) - .data - .member_name_to_expr - .as_mut() - .unwrap(); - - let mut column_name = join_on.name.clone(); - if let Some(name) = find_column_by_alias( - &column_name, - member_names_to_expr_right, - &join_on.relation.clone().unwrap_or_default(), - ) { - column_name = - name.split(".").last().unwrap().to_string(); - } - - if column_name == "__cubeJoinField" { - return true; - } - } - } - } - } - } - } - } - - false + is_proper_cube_join_condition( + egraph, + subst, + left_aliases_var, + left_on_var, + right_aliases_var, + right_on_var, + ) } } @@ -2961,6 +2898,85 @@ fn find_column_by_alias( None } +fn is_proper_cube_join_condition( + egraph: &mut CubeEGraph, + subst: &Subst, + left_cube_members_var: Var, + left_on_var: Var, + right_cube_members_var: Var, + right_on_var: Var, +) -> bool { + if egraph + .index(subst[left_cube_members_var]) + .data + .member_name_to_expr + .is_some() + { + if egraph + .index(subst[right_cube_members_var]) + .data + .member_name_to_expr + .is_some() + { + let left_join_ons: Vec> = var_iter!(egraph[subst[left_on_var]], JoinLeftOn) + .map(|elem| elem.iter().cloned().collect()) + .collect(); + for left_join_on in left_join_ons { + for join_on in left_join_on { + let member_names_to_expr_left = &mut egraph + .index_mut(subst[left_cube_members_var]) + .data + .member_name_to_expr + .as_mut() + .unwrap(); + + // TODO: Avoid the join_on.*.clone() calls (should be trivial). + let mut column_name = join_on.name.clone(); + if let Some(name) = find_column_by_alias( + &column_name, + member_names_to_expr_left, + &join_on.relation.clone().unwrap_or_default(), + ) { + column_name = name.split(".").last().unwrap().to_string(); + } + + if column_name == "__cubeJoinField" { + let right_join_ons: Vec> = + var_iter!(egraph[subst[right_on_var]], JoinRightOn) + .map(|elem| elem.iter().cloned().collect()) + .collect(); + for right_join_on in right_join_ons { + for join_on in right_join_on.iter() { + let member_names_to_expr_right = &mut egraph + .index_mut(subst[right_cube_members_var]) + .data + .member_name_to_expr + .as_mut() + .unwrap(); + + let mut column_name = join_on.name.clone(); + if let Some(name) = find_column_by_alias( + &column_name, + member_names_to_expr_right, + &join_on.relation.clone().unwrap_or_default(), + ) { + column_name = name.split(".").last().unwrap().to_string(); + } + + if column_name == "__cubeJoinField" { + return true; + } + } + } + } + } + } + } + } + + false +} + #[cfg(test)] mod tests { use super::*;