@@ -2,12 +2,13 @@ use crate::{
22 compile:: {
33 engine:: udf:: { MEASURE_UDAF_NAME , PATCH_MEASURE_UDAF_NAME } ,
44 rewrite:: {
5- aggregate, alias_expr,
5+ agg_fun_expr , aggregate, alias_expr,
66 analysis:: ConstantFolding ,
7- cube_scan_wrapper, grouping_set_expr, original_expr_name, rewrite,
7+ binary_expr, case_expr, column_expr, cube_scan_wrapper, grouping_set_expr,
8+ literal_null, original_expr_name, rewrite,
89 rewriter:: { CubeEGraph , CubeRewrite } ,
910 rules:: { members:: MemberRules , wrapper:: WrapperRules } ,
10- subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
11+ subquery, transforming_chain_rewrite, transforming_rewrite, udaf_expr , wrapped_select,
1112 wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail,
1213 wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail,
1314 wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail,
@@ -321,10 +322,6 @@ impl WrapperRules {
321322 ) ;
322323 }
323324
324- use crate :: compile:: rewrite:: {
325- agg_fun_expr, binary_expr, case_expr, column_expr, literal_null, udaf_expr,
326- } ;
327-
328325 // incoming structure: agg_fun(?name, case(?cond, (?when_value, measure_column)))
329326 // optional "else null" is fine
330327 // only single when-then
@@ -363,8 +360,7 @@ impl WrapperRules {
363360 PATCH_MEASURE_UDAF_NAME ,
364361 vec![
365362 column_expr( "?measure_column" ) ,
366- // TODO support doing both: changing agg type and adding filters
367- literal_null( ) ,
363+ "?replace_agg_type" . to_string( ) ,
368364 wrapper_pushdown_replacer(
369365 // = is a proper way to filter here:
370366 // CASE NULL WHEN ... will return null
@@ -377,9 +373,17 @@ impl WrapperRules {
377373 ) ,
378374 "?out_measure_alias" ,
379375 ) ,
380- self . transform_filtered_measure( "?aggr_expr" , "?literal" , "?out_measure_alias" ) ,
376+ self . transform_filtered_measure(
377+ "?aggr_expr" ,
378+ "?literal" ,
379+ "?measure_column" ,
380+ "?fun" ,
381+ "?cube_members" ,
382+ "?replace_agg_type" ,
383+ "?out_measure_alias" ,
384+ ) ,
381385 ) ,
382- transforming_rewrite (
386+ rewrite (
383387 "wrapper-pull-up-aggregation-over-filtered-measure" ,
384388 udaf_expr(
385389 PATCH_MEASURE_UDAF_NAME ,
@@ -400,14 +404,6 @@ impl WrapperRules {
400404 ) ,
401405 "?context" ,
402406 ) ,
403- |_egraph, _subst| {
404- // dbg!("wrapper-pull-up-aggregation-over-filtered-measure call");
405- // dbg!(&egraph[subst[var!("?measure_column")]]);
406- // dbg!(&egraph[subst[var!("?filter_expr")]]);
407-
408- // TODO do we need to check something here? like a SQL generator?
409- true
410- } ,
411407 ) ,
412408 ] ) ;
413409 }
@@ -947,20 +943,31 @@ impl WrapperRules {
947943 egraph : & mut CubeEGraph ,
948944 subst : & mut Subst ,
949945 column : Column ,
950- call_agg_type : String ,
946+ call_agg_type : Option < String > ,
951947 alias : String ,
952- out_expr_var : Var ,
948+ out_expr_var : Option < Var > ,
949+ out_replace_agg_type : Option < Var > ,
953950 out_alias_var : Var ,
954951 ) {
955952 let column_expr_column = egraph. add ( LogicalPlanLanguage :: ColumnExprColumn (
956- ColumnExprColumn ( column. clone ( ) ) ,
953+ ColumnExprColumn ( column) ,
957954 ) ) ;
958955 let column_expr = egraph. add ( LogicalPlanLanguage :: ColumnExpr ( [ column_expr_column] ) ) ;
959- let new_aggregation_value = egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
960- LiteralExprValue ( ScalarValue :: Utf8 ( Some ( call_agg_type) ) ) ,
961- ) ) ;
956+ let new_aggregation_value = match call_agg_type {
957+ Some ( call_agg_type) => egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
958+ LiteralExprValue ( ScalarValue :: Utf8 ( Some ( call_agg_type) ) ) ,
959+ ) ) ,
960+ None => egraph. add ( LogicalPlanLanguage :: LiteralExprValue ( LiteralExprValue (
961+ ScalarValue :: Null ,
962+ ) ) ) ,
963+ } ;
962964 let new_aggregation_expr =
963965 egraph. add ( LogicalPlanLanguage :: LiteralExpr ( [ new_aggregation_value] ) ) ;
966+
967+ if let Some ( out_replace_agg_type) = out_replace_agg_type {
968+ subst. insert ( out_replace_agg_type, new_aggregation_expr) ;
969+ }
970+
964971 let add_filters_value = egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
965972 LiteralExprValue ( ScalarValue :: Null ) ,
966973 ) ) ;
@@ -978,7 +985,9 @@ impl WrapperRules {
978985 udaf_args_expr,
979986 ] ) ) ;
980987
981- subst. insert ( out_expr_var, udaf_expr) ;
988+ if let Some ( out_expr_var) = out_expr_var {
989+ subst. insert ( out_expr_var, udaf_expr) ;
990+ }
982991
983992 let alias_expr_alias = egraph. add ( LogicalPlanLanguage :: AliasExprAlias ( AliasExprAlias (
984993 alias. clone ( ) ,
@@ -1079,9 +1088,10 @@ impl WrapperRules {
10791088 egraph,
10801089 subst,
10811090 column,
1082- call_agg_type. clone ( ) ,
1091+ Some ( call_agg_type. clone ( ) ) ,
10831092 alias,
1084- out_expr_var,
1093+ Some ( out_expr_var) ,
1094+ None ,
10851095 out_alias_var,
10861096 ) ;
10871097
@@ -1140,12 +1150,23 @@ impl WrapperRules {
11401150 & self ,
11411151 aggr_expr_var : & ' static str ,
11421152 literal_var : & ' static str ,
1153+ column_var : & ' static str ,
1154+ fun_name_var : & ' static str ,
1155+ cube_members_var : & ' static str ,
1156+ replace_agg_type_var : & ' static str ,
11431157 out_measure_alias_var : & ' static str ,
11441158 ) -> impl Fn ( & mut CubeEGraph , & mut Subst ) -> bool {
11451159 let aggr_expr_var = var ! ( aggr_expr_var) ;
11461160 let literal_var = var ! ( literal_var) ;
1161+ let column_var = var ! ( column_var) ;
1162+ let fun_name_var = var ! ( fun_name_var) ;
1163+ let cube_members_var = var ! ( cube_members_var) ;
1164+ let replace_agg_type_var = var ! ( replace_agg_type_var) ;
11471165 let out_measure_alias_var = var ! ( out_measure_alias_var) ;
11481166
1167+ let meta = self . meta_context . clone ( ) ;
1168+ let disable_strict_agg_type_match = self . config_obj . disable_strict_agg_type_match ( ) ;
1169+
11491170 move |egraph, subst| {
11501171 match & egraph[ subst[ literal_var] ] . data . constant {
11511172 Some ( ConstantFolding :: Scalar ( _) ) => {
@@ -1156,17 +1177,96 @@ impl WrapperRules {
11561177 }
11571178 }
11581179
1159- // TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1160- // TODO support both changing agg fun and add filter
1161-
11621180 let Some ( alias) = original_expr_name ( egraph, subst[ aggr_expr_var] ) else {
11631181 return false ;
11641182 } ;
1165- let alias_expr_alias =
1166- egraph. add ( LogicalPlanLanguage :: AliasExprAlias ( AliasExprAlias ( alias) ) ) ;
1167- subst. insert ( out_measure_alias_var, alias_expr_alias) ;
11681183
1169- true
1184+ for fun in var_iter ! ( egraph[ subst[ fun_name_var] ] , AggregateFunctionExprFun )
1185+ . cloned ( )
1186+ . collect :: < Vec < _ > > ( )
1187+ {
1188+ let call_agg_type = MemberRules :: get_agg_type ( Some ( & fun) , false ) ;
1189+
1190+ let column_iter = var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1191+ . cloned ( )
1192+ . collect :: < Vec < _ > > ( ) ;
1193+
1194+ if let Some ( member_names_to_expr) = & mut egraph
1195+ . index_mut ( subst[ cube_members_var] )
1196+ . data
1197+ . member_name_to_expr
1198+ {
1199+ for column in column_iter {
1200+ if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1201+ LogicalPlanData :: do_find_member_by_alias (
1202+ member_names_to_expr,
1203+ & column. name ,
1204+ )
1205+ {
1206+ if let Some ( measure) = meta. find_measure_with_name ( member. to_string ( ) ) {
1207+ if !measure. allow_add_filter ( call_agg_type. as_deref ( ) ) {
1208+ continue ;
1209+ }
1210+
1211+ let Some ( call_agg_type) = & call_agg_type else {
1212+ // call_agg_type is None, rewrite as is
1213+ Self :: insert_replaced_measure (
1214+ egraph,
1215+ subst,
1216+ column,
1217+ None ,
1218+ alias,
1219+ None ,
1220+ Some ( replace_agg_type_var) ,
1221+ out_measure_alias_var,
1222+ ) ;
1223+
1224+ return true ;
1225+ } ;
1226+
1227+ if measure
1228+ . is_same_agg_type ( call_agg_type, disable_strict_agg_type_match)
1229+ {
1230+ Self :: insert_replaced_measure (
1231+ egraph,
1232+ subst,
1233+ column,
1234+ None ,
1235+ alias,
1236+ None ,
1237+ Some ( replace_agg_type_var) ,
1238+ out_measure_alias_var,
1239+ ) ;
1240+
1241+ return true ;
1242+ }
1243+
1244+ if measure. allow_replace_agg_type (
1245+ call_agg_type,
1246+ disable_strict_agg_type_match,
1247+ ) {
1248+ Self :: insert_replaced_measure (
1249+ egraph,
1250+ subst,
1251+ column,
1252+ Some ( call_agg_type. clone ( ) ) ,
1253+ alias,
1254+ None ,
1255+ Some ( replace_agg_type_var) ,
1256+ out_measure_alias_var,
1257+ ) ;
1258+
1259+ return true ;
1260+ }
1261+ }
1262+ }
1263+ }
1264+ }
1265+ }
1266+
1267+ false
1268+
1269+ // TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
11701270 }
11711271 }
11721272}
0 commit comments