@@ -2,11 +2,13 @@ use crate::{
22 compile:: {
33 engine:: udf:: { MEASURE_UDAF_NAME , PATCH_MEASURE_UDAF_NAME } ,
44 rewrite:: {
5- aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
6- rewrite,
5+ agg_fun_expr, aggregate, alias_expr,
6+ analysis:: ConstantFolding ,
7+ binary_expr, case_expr, column_expr, cube_scan_wrapper, grouping_set_expr,
8+ literal_null, original_expr_name, rewrite,
79 rewriter:: { CubeEGraph , CubeRewrite } ,
810 rules:: { members:: MemberRules , wrapper:: WrapperRules } ,
9- subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
11+ subquery, transforming_chain_rewrite, transforming_rewrite, udaf_expr , wrapped_select,
1012 wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail,
1113 wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail,
1214 wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail,
@@ -319,6 +321,91 @@ impl WrapperRules {
319321 "GroupingSetExprMembers" ,
320322 ) ;
321323 }
324+
325+ // incoming structure: agg_fun(?name, case(?cond, (?when_value, measure_column)))
326+ // optional "else null" is fine
327+ // only single when-then
328+ rules. extend ( vec ! [
329+ transforming_chain_rewrite(
330+ "wrapper-push-down-aggregation-over-filtered-measure" ,
331+ wrapper_pushdown_replacer( "?aggr_expr" , "?context" ) ,
332+ vec![
333+ (
334+ "?aggr_expr" ,
335+ agg_fun_expr(
336+ "?fun" ,
337+ vec![ case_expr(
338+ Some ( "?case_expr" . to_string( ) ) ,
339+ vec![ ( "?literal" . to_string( ) , column_expr( "?measure_column" ) ) ] ,
340+ // TODO make `ELSE NULL` optional and/or add generic rewrite to normalize it
341+ Some ( literal_null( ) ) ,
342+ ) ] ,
343+ "?distinct" ,
344+ ) ,
345+ ) ,
346+ (
347+ "?context" ,
348+ wrapper_replacer_context(
349+ "?alias_to_cube" ,
350+ "WrapperReplacerContextPushToCube:true" ,
351+ "?in_projection" ,
352+ "?cube_members" ,
353+ "?grouped_subqueries" ,
354+ "?ungrouped_scan" ,
355+ ) ,
356+ ) ,
357+ ] ,
358+ alias_expr(
359+ udaf_expr(
360+ PATCH_MEASURE_UDAF_NAME ,
361+ vec![
362+ column_expr( "?measure_column" ) ,
363+ "?replace_agg_type" . to_string( ) ,
364+ wrapper_pushdown_replacer(
365+ // = is a proper way to filter here:
366+ // CASE NULL WHEN ... will return null
367+ // So NULL in ?case_expr is equivalent to hitting ELSE branch
368+ // TODO add "is not null" to cond? just to make is always boolean
369+ binary_expr( "?case_expr" , "=" , "?literal" ) ,
370+ "?context" ,
371+ ) ,
372+ ] ,
373+ ) ,
374+ "?out_measure_alias" ,
375+ ) ,
376+ self . transform_filtered_measure(
377+ "?aggr_expr" ,
378+ "?literal" ,
379+ "?measure_column" ,
380+ "?fun" ,
381+ "?cube_members" ,
382+ "?replace_agg_type" ,
383+ "?out_measure_alias" ,
384+ ) ,
385+ ) ,
386+ rewrite(
387+ "wrapper-pull-up-aggregation-over-filtered-measure" ,
388+ udaf_expr(
389+ PATCH_MEASURE_UDAF_NAME ,
390+ vec![
391+ column_expr( "?measure_column" ) ,
392+ "?new_agg_type" . to_string( ) ,
393+ wrapper_pullup_replacer( "?filter_expr" , "?context" ) ,
394+ ] ,
395+ ) ,
396+ wrapper_pullup_replacer(
397+ udaf_expr(
398+ PATCH_MEASURE_UDAF_NAME ,
399+ vec![
400+ column_expr( "?measure_column" ) ,
401+ "?new_agg_type" . to_string( ) ,
402+ "?filter_expr" . to_string( ) ,
403+ ] ,
404+ ) ,
405+ "?context" ,
406+ ) ,
407+ ) ,
408+ ] ) ;
322409 }
323410
324411 pub fn aggregate_rules_subquery ( & self , rules : & mut Vec < CubeRewrite > ) {
@@ -1058,4 +1145,128 @@ impl WrapperRules {
10581145 )
10591146 }
10601147 }
1148+
1149+ fn transform_filtered_measure (
1150+ & self ,
1151+ aggr_expr_var : & ' static str ,
1152+ literal_var : & ' static str ,
1153+ column_var : & ' static str ,
1154+ fun_name_var : & ' static str ,
1155+ cube_members_var : & ' static str ,
1156+ replace_agg_type_var : & ' static str ,
1157+ out_measure_alias_var : & ' static str ,
1158+ ) -> impl Fn ( & mut CubeEGraph , & mut Subst ) -> bool {
1159+ let aggr_expr_var = var ! ( aggr_expr_var) ;
1160+ let literal_var = var ! ( literal_var) ;
1161+ let column_var = var ! ( column_var) ;
1162+ let fun_name_var = var ! ( fun_name_var) ;
1163+ let cube_members_var = var ! ( cube_members_var) ;
1164+ let replace_agg_type_var = var ! ( replace_agg_type_var) ;
1165+ let out_measure_alias_var = var ! ( out_measure_alias_var) ;
1166+
1167+ let meta = self . meta_context . clone ( ) ;
1168+ let disable_strict_agg_type_match = self . config_obj . disable_strict_agg_type_match ( ) ;
1169+
1170+ move |egraph, subst| {
1171+ match & egraph[ subst[ literal_var] ] . data . constant {
1172+ Some ( ConstantFolding :: Scalar ( _) ) => {
1173+ // Do nothing
1174+ }
1175+ _ => {
1176+ return false ;
1177+ }
1178+ }
1179+
1180+ let Some ( alias) = original_expr_name ( egraph, subst[ aggr_expr_var] ) else {
1181+ return false ;
1182+ } ;
1183+
1184+ for fun in var_iter ! ( egraph[ subst[ fun_name_var] ] , AggregateFunctionExprFun )
1185+ . cloned ( )
1186+ . collect :: < Vec < _ > > ( )
1187+ {
1188+ let call_agg_type = MemberRules :: get_agg_type ( Some ( & fun) , false ) ;
1189+
1190+ let column_iter = var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1191+ . cloned ( )
1192+ . collect :: < Vec < _ > > ( ) ;
1193+
1194+ if let Some ( member_names_to_expr) = & mut egraph
1195+ . index_mut ( subst[ cube_members_var] )
1196+ . data
1197+ . member_name_to_expr
1198+ {
1199+ for column in column_iter {
1200+ if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1201+ LogicalPlanData :: do_find_member_by_alias (
1202+ member_names_to_expr,
1203+ & column. name ,
1204+ )
1205+ {
1206+ if let Some ( measure) = meta. find_measure_with_name ( member) {
1207+ if !measure. allow_add_filter ( call_agg_type. as_deref ( ) ) {
1208+ continue ;
1209+ }
1210+
1211+ let Some ( call_agg_type) = & call_agg_type else {
1212+ // call_agg_type is None, rewrite as is
1213+ Self :: insert_patch_measure (
1214+ egraph,
1215+ subst,
1216+ column,
1217+ None ,
1218+ alias,
1219+ None ,
1220+ Some ( replace_agg_type_var) ,
1221+ out_measure_alias_var,
1222+ ) ;
1223+
1224+ return true ;
1225+ } ;
1226+
1227+ if measure
1228+ . is_same_agg_type ( call_agg_type, disable_strict_agg_type_match)
1229+ {
1230+ Self :: insert_patch_measure (
1231+ egraph,
1232+ subst,
1233+ column,
1234+ None ,
1235+ alias,
1236+ None ,
1237+ Some ( replace_agg_type_var) ,
1238+ out_measure_alias_var,
1239+ ) ;
1240+
1241+ return true ;
1242+ }
1243+
1244+ if measure. allow_replace_agg_type (
1245+ call_agg_type,
1246+ disable_strict_agg_type_match,
1247+ ) {
1248+ Self :: insert_patch_measure (
1249+ egraph,
1250+ subst,
1251+ column,
1252+ Some ( call_agg_type. clone ( ) ) ,
1253+ alias,
1254+ None ,
1255+ Some ( replace_agg_type_var) ,
1256+ out_measure_alias_var,
1257+ ) ;
1258+
1259+ return true ;
1260+ }
1261+ }
1262+ }
1263+ }
1264+ }
1265+ }
1266+
1267+ false
1268+
1269+ // TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1270+ }
1271+ }
10611272}
0 commit comments