@@ -17,7 +17,8 @@ use crate::{
1717 wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct ,
1818 AggregateFunctionExprFun , AggregateUDFExprFun , AliasExprAlias , ColumnExprColumn ,
1919 ListType , LiteralExprValue , LogicalPlanData , LogicalPlanLanguage ,
20- WrappedSelectPushToCube , WrapperReplacerContextPushToCube ,
20+ WrappedSelectPushToCube , WrapperReplacerContextAliasToCube ,
21+ WrapperReplacerContextPushToCube ,
2122 } ,
2223 } ,
2324 copy_flag,
@@ -26,7 +27,7 @@ use crate::{
2627} ;
2728use datafusion:: { logical_plan:: Column , scalar:: ScalarValue } ;
2829use egg:: { Subst , Var } ;
29- use std:: ops:: IndexMut ;
30+ use std:: { collections :: HashSet , ops:: IndexMut } ;
3031
3132impl WrapperRules {
3233 pub fn aggregate_rules ( & self , rules : & mut Vec < CubeRewrite > ) {
@@ -290,6 +291,7 @@ impl WrapperRules {
290291 "?cube_members" ,
291292 "?out_measure_expr" ,
292293 "?out_measure_alias" ,
294+ "?alias_to_cube" ,
293295 ) ,
294296 )
295297 } ,
@@ -1035,97 +1037,119 @@ impl WrapperRules {
10351037 cube_members_var : Var ,
10361038 out_expr_var : Var ,
10371039 out_alias_var : Var ,
1040+ alias_to_cube_var : Var ,
10381041 meta : & MetaContext ,
10391042 disable_strict_agg_type_match : bool ,
10401043 ) -> bool {
10411044 let Some ( alias) = original_expr_name ( egraph, subst[ original_expr_var] ) else {
10421045 return false ;
10431046 } ;
10441047
1045- for fun in fun_name_var
1046- . map ( |fun_var| {
1047- var_iter ! ( egraph[ subst[ fun_var] ] , AggregateFunctionExprFun )
1048- . map ( |fun| Some ( fun. clone ( ) ) )
1049- . collect ( )
1050- } )
1051- . unwrap_or ( vec ! [ None ] )
1048+ for alias_to_cube in var_iter ! (
1049+ egraph[ subst[ alias_to_cube_var] ] ,
1050+ WrapperReplacerContextAliasToCube
1051+ )
1052+ . cloned ( )
1053+ . collect :: < Vec < _ > > ( )
10521054 {
1053- for distinct in distinct_var
1054- . map ( |distinct_var| {
1055- var_iter ! ( egraph[ subst[ distinct_var] ] , AggregateFunctionExprDistinct )
1056- . map ( |d| * d)
1055+ // Do not push down COUNT(*) if there are joined cubes
1056+ let is_count_rows = column_var. is_none ( ) ;
1057+ if is_count_rows {
1058+ let joined_cubes = alias_to_cube
1059+ . iter ( )
1060+ . map ( |( _, cube_name) | cube_name)
1061+ . collect :: < HashSet < _ > > ( ) ;
1062+ if joined_cubes. len ( ) > 1 {
1063+ continue ;
1064+ }
1065+ }
1066+
1067+ for fun in fun_name_var
1068+ . map ( |fun_var| {
1069+ var_iter ! ( egraph[ subst[ fun_var] ] , AggregateFunctionExprFun )
1070+ . map ( |fun| Some ( fun. clone ( ) ) )
10571071 . collect ( )
10581072 } )
1059- . unwrap_or ( vec ! [ false ] )
1073+ . unwrap_or ( vec ! [ None ] )
10601074 {
1061- let call_agg_type = MemberRules :: get_agg_type ( fun. as_ref ( ) , distinct) ;
1075+ for distinct in distinct_var
1076+ . map ( |distinct_var| {
1077+ var_iter ! ( egraph[ subst[ distinct_var] ] , AggregateFunctionExprDistinct )
1078+ . map ( |d| * d)
1079+ . collect ( )
1080+ } )
1081+ . unwrap_or ( vec ! [ false ] )
1082+ {
1083+ let call_agg_type = MemberRules :: get_agg_type ( fun. as_ref ( ) , distinct) ;
10621084
1063- let column_iter = if let Some ( column_var) = column_var {
1064- var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1065- . cloned ( )
1066- . collect ( )
1067- } else {
1068- vec ! [ Column :: from_name( MemberRules :: default_count_measure_name( ) ) ]
1069- } ;
1085+ let column_iter = if let Some ( column_var) = column_var {
1086+ var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1087+ . cloned ( )
1088+ . collect ( )
1089+ } else {
1090+ vec ! [ Column :: from_name( MemberRules :: default_count_measure_name( ) ) ]
1091+ } ;
10701092
1071- if let Some ( member_names_to_expr) = & mut egraph
1072- . index_mut ( subst[ cube_members_var] )
1073- . data
1074- . member_name_to_expr
1075- {
1076- for column in column_iter {
1077- if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1078- LogicalPlanData :: do_find_member_by_alias (
1079- member_names_to_expr,
1080- & column. name ,
1081- )
1082- {
1083- if let Some ( measure) = meta. find_measure_with_name ( member) {
1084- let Some ( call_agg_type) = & call_agg_type else {
1085- // call_agg_type is None, rewrite as is
1086- Self :: insert_regular_measure (
1087- egraph,
1088- subst,
1089- column,
1090- alias,
1091- out_expr_var,
1092- out_alias_var,
1093- ) ;
1093+ if let Some ( member_names_to_expr) = & mut egraph
1094+ . index_mut ( subst[ cube_members_var] )
1095+ . data
1096+ . member_name_to_expr
1097+ {
1098+ for column in column_iter {
1099+ if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1100+ LogicalPlanData :: do_find_member_by_alias (
1101+ member_names_to_expr,
1102+ & column. name ,
1103+ )
1104+ {
1105+ if let Some ( measure) = meta. find_measure_with_name ( member) {
1106+ let Some ( call_agg_type) = & call_agg_type else {
1107+ // call_agg_type is None, rewrite as is
1108+ Self :: insert_regular_measure (
1109+ egraph,
1110+ subst,
1111+ column,
1112+ alias,
1113+ out_expr_var,
1114+ out_alias_var,
1115+ ) ;
10941116
1095- return true ;
1096- } ;
1117+ return true ;
1118+ } ;
10971119
1098- if measure
1099- . is_same_agg_type ( call_agg_type, disable_strict_agg_type_match)
1100- {
1101- Self :: insert_regular_measure (
1102- egraph,
1103- subst,
1104- column,
1105- alias,
1106- out_expr_var,
1107- out_alias_var,
1108- ) ;
1120+ if measure. is_same_agg_type (
1121+ call_agg_type,
1122+ disable_strict_agg_type_match,
1123+ ) {
1124+ Self :: insert_regular_measure (
1125+ egraph,
1126+ subst,
1127+ column,
1128+ alias,
1129+ out_expr_var,
1130+ out_alias_var,
1131+ ) ;
11091132
1110- return true ;
1111- }
1133+ return true ;
1134+ }
11121135
1113- if measure. allow_replace_agg_type (
1114- call_agg_type,
1115- disable_strict_agg_type_match,
1116- ) {
1117- Self :: insert_patch_measure (
1118- egraph,
1119- subst,
1120- column,
1121- Some ( call_agg_type. clone ( ) ) ,
1122- alias,
1123- Some ( out_expr_var) ,
1124- None ,
1125- out_alias_var,
1126- ) ;
1136+ if measure. allow_replace_agg_type (
1137+ call_agg_type,
1138+ disable_strict_agg_type_match,
1139+ ) {
1140+ Self :: insert_patch_measure (
1141+ egraph,
1142+ subst,
1143+ column,
1144+ Some ( call_agg_type. clone ( ) ) ,
1145+ alias,
1146+ Some ( out_expr_var) ,
1147+ None ,
1148+ out_alias_var,
1149+ ) ;
11271150
1128- return true ;
1151+ return true ;
1152+ }
11291153 }
11301154 }
11311155 }
@@ -1148,6 +1172,7 @@ impl WrapperRules {
11481172 cube_members_var : & ' static str ,
11491173 out_expr_var : & ' static str ,
11501174 out_alias_var : & ' static str ,
1175+ alias_to_cube_var : & ' static str ,
11511176 ) -> impl Fn ( & mut CubeEGraph , & mut Subst ) -> bool {
11521177 let original_expr_var = var ! ( original_expr_var) ;
11531178 let column_var = column_var. map ( |v| var ! ( v) ) ;
@@ -1157,6 +1182,7 @@ impl WrapperRules {
11571182 let cube_members_var = var ! ( cube_members_var) ;
11581183 let out_expr_var = var ! ( out_expr_var) ;
11591184 let out_alias_var = var ! ( out_alias_var) ;
1185+ let alias_to_cube_var = var ! ( alias_to_cube_var) ;
11601186 let meta = self . meta_context . clone ( ) ;
11611187 let disable_strict_agg_type_match = self . config_obj . disable_strict_agg_type_match ( ) ;
11621188 move |egraph, subst| {
@@ -1170,6 +1196,7 @@ impl WrapperRules {
11701196 cube_members_var,
11711197 out_expr_var,
11721198 out_alias_var,
1199+ alias_to_cube_var,
11731200 & meta,
11741201 disable_strict_agg_type_match,
11751202 )
0 commit comments