@@ -360,8 +360,7 @@ impl WrapperRules {
360360 PATCH_MEASURE_UDAF_NAME ,
361361 vec![
362362 column_expr( "?measure_column" ) ,
363- // TODO support doing both: changing agg type and adding filters
364- literal_null( ) ,
363+ "?replace_agg_type" . to_string( ) ,
365364 wrapper_pushdown_replacer(
366365 // = is a proper way to filter here:
367366 // CASE NULL WHEN ... will return null
@@ -374,7 +373,15 @@ impl WrapperRules {
374373 ) ,
375374 "?out_measure_alias" ,
376375 ) ,
377- self . transform_filtered_measure( "?aggr_expr" , "?literal" , "?out_measure_alias" ) ,
376+ self . transform_filtered_measure(
377+ "?aggr_expr" ,
378+ "?literal" ,
379+ "?measure_column" ,
380+ "?fun" ,
381+ "?cube_members" ,
382+ "?replace_agg_type" ,
383+ "?out_measure_alias" ,
384+ ) ,
378385 ) ,
379386 rewrite(
380387 "wrapper-pull-up-aggregation-over-filtered-measure" ,
@@ -936,20 +943,31 @@ impl WrapperRules {
936943 egraph : & mut CubeEGraph ,
937944 subst : & mut Subst ,
938945 column : Column ,
939- call_agg_type : String ,
946+ call_agg_type : Option < String > ,
940947 alias : String ,
941- out_expr_var : Var ,
948+ out_expr_var : Option < Var > ,
949+ out_replace_agg_type : Option < Var > ,
942950 out_alias_var : Var ,
943951 ) {
944952 let column_expr_column = egraph. add ( LogicalPlanLanguage :: ColumnExprColumn (
945- ColumnExprColumn ( column. clone ( ) ) ,
953+ ColumnExprColumn ( column) ,
946954 ) ) ;
947955 let column_expr = egraph. add ( LogicalPlanLanguage :: ColumnExpr ( [ column_expr_column] ) ) ;
948- let new_aggregation_value = egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
949- LiteralExprValue ( ScalarValue :: Utf8 ( Some ( call_agg_type) ) ) ,
950- ) ) ;
956+ let new_aggregation_value = match call_agg_type {
957+ Some ( call_agg_type) => egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
958+ LiteralExprValue ( ScalarValue :: Utf8 ( Some ( call_agg_type) ) ) ,
959+ ) ) ,
960+ None => egraph. add ( LogicalPlanLanguage :: LiteralExprValue ( LiteralExprValue (
961+ ScalarValue :: Null ,
962+ ) ) ) ,
963+ } ;
951964 let new_aggregation_expr =
952965 egraph. add ( LogicalPlanLanguage :: LiteralExpr ( [ new_aggregation_value] ) ) ;
966+
967+ if let Some ( out_replace_agg_type) = out_replace_agg_type {
968+ subst. insert ( out_replace_agg_type, new_aggregation_expr) ;
969+ }
970+
953971 let add_filters_value = egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
954972 LiteralExprValue ( ScalarValue :: Null ) ,
955973 ) ) ;
@@ -967,7 +985,9 @@ impl WrapperRules {
967985 udaf_args_expr,
968986 ] ) ) ;
969987
970- subst. insert ( out_expr_var, udaf_expr) ;
988+ if let Some ( out_expr_var) = out_expr_var {
989+ subst. insert ( out_expr_var, udaf_expr) ;
990+ }
971991
972992 let alias_expr_alias = egraph. add ( LogicalPlanLanguage :: AliasExprAlias ( AliasExprAlias (
973993 alias. clone ( ) ,
@@ -1068,9 +1088,10 @@ impl WrapperRules {
10681088 egraph,
10691089 subst,
10701090 column,
1071- call_agg_type. clone ( ) ,
1091+ Some ( call_agg_type. clone ( ) ) ,
10721092 alias,
1073- out_expr_var,
1093+ Some ( out_expr_var) ,
1094+ None ,
10741095 out_alias_var,
10751096 ) ;
10761097
@@ -1129,12 +1150,23 @@ impl WrapperRules {
11291150 & self ,
11301151 aggr_expr_var : & ' static str ,
11311152 literal_var : & ' static str ,
1153+ column_var : & ' static str ,
1154+ fun_name_var : & ' static str ,
1155+ cube_members_var : & ' static str ,
1156+ replace_agg_type_var : & ' static str ,
11321157 out_measure_alias_var : & ' static str ,
11331158 ) -> impl Fn ( & mut CubeEGraph , & mut Subst ) -> bool {
11341159 let aggr_expr_var = var ! ( aggr_expr_var) ;
11351160 let literal_var = var ! ( literal_var) ;
1161+ let column_var = var ! ( column_var) ;
1162+ let fun_name_var = var ! ( fun_name_var) ;
1163+ let cube_members_var = var ! ( cube_members_var) ;
1164+ let replace_agg_type_var = var ! ( replace_agg_type_var) ;
11361165 let out_measure_alias_var = var ! ( out_measure_alias_var) ;
11371166
1167+ let meta = self . meta_context . clone ( ) ;
1168+ let disable_strict_agg_type_match = self . config_obj . disable_strict_agg_type_match ( ) ;
1169+
11381170 move |egraph, subst| {
11391171 match & egraph[ subst[ literal_var] ] . data . constant {
11401172 Some ( ConstantFolding :: Scalar ( _) ) => {
@@ -1145,17 +1177,96 @@ impl WrapperRules {
11451177 }
11461178 }
11471179
1148- // TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1149- // TODO support both changing agg fun and add filter
1150-
11511180 let Some ( alias) = original_expr_name ( egraph, subst[ aggr_expr_var] ) else {
11521181 return false ;
11531182 } ;
1154- let alias_expr_alias =
1155- egraph. add ( LogicalPlanLanguage :: AliasExprAlias ( AliasExprAlias ( alias) ) ) ;
1156- subst. insert ( out_measure_alias_var, alias_expr_alias) ;
11571183
1158- true
1184+ for fun in var_iter ! ( egraph[ subst[ fun_name_var] ] , AggregateFunctionExprFun )
1185+ . cloned ( )
1186+ . collect :: < Vec < _ > > ( )
1187+ {
1188+ let call_agg_type = MemberRules :: get_agg_type ( Some ( & fun) , false ) ;
1189+
1190+ let column_iter = var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1191+ . cloned ( )
1192+ . collect :: < Vec < _ > > ( ) ;
1193+
1194+ if let Some ( member_names_to_expr) = & mut egraph
1195+ . index_mut ( subst[ cube_members_var] )
1196+ . data
1197+ . member_name_to_expr
1198+ {
1199+ for column in column_iter {
1200+ if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1201+ LogicalPlanData :: do_find_member_by_alias (
1202+ member_names_to_expr,
1203+ & column. name ,
1204+ )
1205+ {
1206+ if let Some ( measure) = meta. find_measure_with_name ( member) {
1207+ if !measure. allow_add_filter ( call_agg_type. as_deref ( ) ) {
1208+ continue ;
1209+ }
1210+
1211+ let Some ( call_agg_type) = & call_agg_type else {
1212+ // call_agg_type is None, rewrite as is
1213+ Self :: insert_patch_measure (
1214+ egraph,
1215+ subst,
1216+ column,
1217+ None ,
1218+ alias,
1219+ None ,
1220+ Some ( replace_agg_type_var) ,
1221+ out_measure_alias_var,
1222+ ) ;
1223+
1224+ return true ;
1225+ } ;
1226+
1227+ if measure
1228+ . is_same_agg_type ( call_agg_type, disable_strict_agg_type_match)
1229+ {
1230+ Self :: insert_patch_measure (
1231+ egraph,
1232+ subst,
1233+ column,
1234+ None ,
1235+ alias,
1236+ None ,
1237+ Some ( replace_agg_type_var) ,
1238+ out_measure_alias_var,
1239+ ) ;
1240+
1241+ return true ;
1242+ }
1243+
1244+ if measure. allow_replace_agg_type (
1245+ call_agg_type,
1246+ disable_strict_agg_type_match,
1247+ ) {
1248+ Self :: insert_patch_measure (
1249+ egraph,
1250+ subst,
1251+ column,
1252+ Some ( call_agg_type. clone ( ) ) ,
1253+ alias,
1254+ None ,
1255+ Some ( replace_agg_type_var) ,
1256+ out_measure_alias_var,
1257+ ) ;
1258+
1259+ return true ;
1260+ }
1261+ }
1262+ }
1263+ }
1264+ }
1265+ }
1266+
1267+ false
1268+
1269+ // TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
11591270 }
11601271 }
11611272}
0 commit comments