11use crate :: {
22 compile:: {
3- engine:: udf:: MEASURE_UDAF_NAME ,
3+ engine:: udf:: { MEASURE_UDAF_NAME , PATCH_MEASURE_UDAF_NAME } ,
44 rewrite:: {
55 aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
66 rewrite,
@@ -14,15 +14,16 @@ use crate::{
1414 wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer,
1515 wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct ,
1616 AggregateFunctionExprFun , AggregateUDFExprFun , AliasExprAlias , ColumnExprColumn ,
17- ListType , LogicalPlanData , LogicalPlanLanguage , WrappedSelectPushToCube ,
18- WrapperReplacerContextAliasToCube , WrapperReplacerContextPushToCube ,
17+ ListType , LiteralExprValue , LogicalPlanData , LogicalPlanLanguage ,
18+ WrappedSelectPushToCube , WrapperReplacerContextAliasToCube ,
19+ WrapperReplacerContextPushToCube ,
1920 } ,
2021 } ,
2122 copy_flag,
2223 transport:: V1CubeMetaMeasureExt ,
2324 var, var_iter,
2425} ;
25- use datafusion:: logical_plan:: Column ;
26+ use datafusion:: { logical_plan:: Column , scalar :: ScalarValue } ;
2627use egg:: { Subst , Var } ;
2728use std:: ops:: IndexMut ;
2829
@@ -887,15 +888,17 @@ impl WrapperRules {
887888 if let Some ( measure) =
888889 meta. find_measure_with_name ( member. to_string ( ) )
889890 {
890- if call_agg_type. is_none ( )
891- || measure. is_same_agg_type (
892- call_agg_type. as_ref ( ) . unwrap ( ) ,
893- disable_strict_agg_type_match,
894- )
895- {
891+ fn insert_regular_measure (
892+ egraph : & mut CubeEGraph ,
893+ subst : & mut Subst ,
894+ column : Column ,
895+ alias : String ,
896+ out_expr_var : Var ,
897+ out_alias_var : Var ,
898+ ) {
896899 let column_expr_column =
897900 egraph. add ( LogicalPlanLanguage :: ColumnExprColumn (
898- ColumnExprColumn ( column. clone ( ) ) ,
901+ ColumnExprColumn ( column) ,
899902 ) ) ;
900903 let column_expr =
901904 egraph. add ( LogicalPlanLanguage :: ColumnExpr ( [
@@ -920,11 +923,119 @@ impl WrapperRules {
920923
921924 subst. insert ( out_expr_var, udaf_expr) ;
922925
926+ let alias_expr_alias =
927+ egraph. add ( LogicalPlanLanguage :: AliasExprAlias (
928+ AliasExprAlias ( alias) ,
929+ ) ) ;
930+ subst. insert ( out_alias_var, alias_expr_alias) ;
931+ }
932+
933+ fn insert_replaced_measure (
934+ egraph : & mut CubeEGraph ,
935+ subst : & mut Subst ,
936+ column : Column ,
937+ call_agg_type : String ,
938+ alias : String ,
939+ out_expr_var : Var ,
940+ out_alias_var : Var ,
941+ ) {
942+ let column_expr_column =
943+ egraph. add ( LogicalPlanLanguage :: ColumnExprColumn (
944+ ColumnExprColumn ( column. clone ( ) ) ,
945+ ) ) ;
946+ let column_expr =
947+ egraph. add ( LogicalPlanLanguage :: ColumnExpr ( [
948+ column_expr_column,
949+ ] ) ) ;
950+ let new_aggregation_value =
951+ egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
952+ LiteralExprValue ( ScalarValue :: Utf8 ( Some (
953+ call_agg_type,
954+ ) ) ) ,
955+ ) ) ;
956+ let new_aggregation_expr =
957+ egraph. add ( LogicalPlanLanguage :: LiteralExpr ( [
958+ new_aggregation_value,
959+ ] ) ) ;
960+ let add_filters_value =
961+ egraph. add ( LogicalPlanLanguage :: LiteralExprValue (
962+ LiteralExprValue ( ScalarValue :: Null ) ,
963+ ) ) ;
964+ let add_filters_expr =
965+ egraph. add ( LogicalPlanLanguage :: LiteralExpr ( [
966+ add_filters_value,
967+ ] ) ) ;
968+ let udaf_name_expr = egraph. add (
969+ LogicalPlanLanguage :: AggregateUDFExprFun (
970+ AggregateUDFExprFun (
971+ PATCH_MEASURE_UDAF_NAME . to_string ( ) ,
972+ ) ,
973+ ) ,
974+ ) ;
975+ let udaf_args_expr = egraph. add (
976+ LogicalPlanLanguage :: AggregateUDFExprArgs ( vec ! [
977+ column_expr,
978+ new_aggregation_expr,
979+ add_filters_expr,
980+ ] ) ,
981+ ) ;
982+ let udaf_expr =
983+ egraph. add ( LogicalPlanLanguage :: AggregateUDFExpr (
984+ [ udaf_name_expr, udaf_args_expr] ,
985+ ) ) ;
986+
987+ subst. insert ( out_expr_var, udaf_expr) ;
988+
923989 let alias_expr_alias =
924990 egraph. add ( LogicalPlanLanguage :: AliasExprAlias (
925991 AliasExprAlias ( alias. clone ( ) ) ,
926992 ) ) ;
927993 subst. insert ( out_alias_var, alias_expr_alias) ;
994+ }
995+
996+ let Some ( call_agg_type) = & call_agg_type else {
997+ // call_agg_type is None, rewrite as is
998+ insert_regular_measure (
999+ egraph,
1000+ subst,
1001+ column,
1002+ alias,
1003+ out_expr_var,
1004+ out_alias_var,
1005+ ) ;
1006+
1007+ return true ;
1008+ } ;
1009+
1010+ if measure. is_same_agg_type (
1011+ call_agg_type,
1012+ disable_strict_agg_type_match,
1013+ ) {
1014+ insert_regular_measure (
1015+ egraph,
1016+ subst,
1017+ column,
1018+ alias,
1019+ out_expr_var,
1020+ out_alias_var,
1021+ ) ;
1022+
1023+ return true ;
1024+ }
1025+
1026+ if measure. allow_replace_agg_type (
1027+ call_agg_type,
1028+ disable_strict_agg_type_match,
1029+ ) {
1030+ insert_replaced_measure (
1031+ egraph,
1032+ subst,
1033+ column,
1034+ call_agg_type. clone ( ) ,
1035+ alias,
1036+ out_expr_var,
1037+ out_alias_var,
1038+ ) ;
9281039
9291040 return true ;
9301041 }
0 commit comments