@@ -22,6 +22,7 @@ use datafusion::{
2222 prelude:: JoinType ,
2323} ;
2424use egg:: { Id , Subst } ;
25+ use itertools:: Itertools ;
2526
2627impl WrapperRules {
2728 pub fn join_rules ( & self , rules : & mut Vec < CubeRewrite > ) {
@@ -831,50 +832,66 @@ impl WrapperRules {
831832 // * Both inputs depend on a single data source
832833 // * SQL generator for that data source have `expressions/subquery` template
833834 // It could be checked later, in WrappedSelect as well
834-
835- let left_columns = egraph[ subst[ left_expr_var] ] . data . referenced_expr . as_ref ( ) ;
836- let Some ( left_columns) = left_columns else {
837- return false ;
835+ // TODO For views: check that each member is coming from same data source (or even cube?)
836+
837+ let prepare_columns = |var| {
838+ let columns = egraph[ subst[ var] ] . data . referenced_expr . as_ref ( ) ;
839+ let Some ( columns) = columns else {
840+ return Err ( "Missing referenced_expr" ) ;
841+ } ;
842+ let columns = columns
843+ . iter ( )
844+ . map ( |column| {
845+ let column = match column {
846+ Expr :: Column ( column) => column. clone ( ) ,
847+ _ => return Err ( "Unexpected expression in referenced_expr" ) ,
848+ } ;
849+ Ok ( column)
850+ } )
851+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
852+ Ok ( columns)
838853 } ;
839- if left_columns. len ( ) != 1 {
840- return false ;
854+
855+ fn prepare_relation ( columns : & [ Column ] ) -> Result < & str , & ' static str > {
856+ let relation = columns
857+ . iter ( )
858+ . map ( |column| & column. relation )
859+ . all_equal_value ( ) ;
860+ let Ok ( Some ( relation) ) = relation else {
861+ // Outer Err means there's either no values at all, or more than one different value
862+ // Inner Err means that all referenced_expr are not columns
863+ // Inner None means that all columns are without relation, don't support that ATM
864+ return Err ( "Relation mismatch" ) ;
865+ } ;
866+ Ok ( relation)
841867 }
842- let left_column = & left_columns[ 0 ] ;
843868
844- let right_columns = egraph[ subst[ right_expr_var] ] . data . referenced_expr . as_ref ( ) ;
845- let Some ( right_columns) = right_columns else {
869+ let Ok ( left_columns) = prepare_columns ( left_expr_var) else {
846870 return false ;
847871 } ;
848- if right_columns . len ( ) != 1 {
872+ let Ok ( left_relation ) = prepare_relation ( & left_columns ) else {
849873 return false ;
850- }
851- let right_column = & right_columns[ 0 ] ;
852-
853- let left_column = match left_column {
854- Expr :: Column ( column) => column,
855- _ => return false ,
856- } ;
857- let right_column = match right_column {
858- Expr :: Column ( column) => column,
859- _ => return false ,
860874 } ;
861875
862- // Simple check that column expressions reference different join sides
863- let Some ( left_relation) = left_column. relation . as_ref ( ) else {
876+ let Ok ( right_columns) = prepare_columns ( right_expr_var) else {
864877 return false ;
865878 } ;
866- let Some ( right_relation) = right_column . relation . as_ref ( ) else {
879+ let Ok ( right_relation) = prepare_relation ( & right_columns ) else {
867880 return false ;
868881 } ;
882+
883+ // Simple check that column expressions reference different join sides
869884 if left_relation == right_relation {
870885 return false ;
871886 }
872887
873- let left_column = left_column. clone ( ) ;
874-
875888 // Don't check right, as it is already grouped
876889
877- if !Self :: are_join_members_supported ( egraph, subst[ left_members_var] , [ & left_column] ) {
890+ if !Self :: are_join_members_supported (
891+ egraph,
892+ subst[ left_members_var] ,
893+ left_columns. iter ( ) ,
894+ ) {
878895 return false ;
879896 }
880897
0 commit comments