Skip to content

Commit ce48b2d

Browse files
committed
refactor(cubesql): Rewrite measure in wrapper as MEASURE(column)
1 parent cb2c386 commit ce48b2d

File tree

2 files changed

+98
-30
lines changed

2 files changed

+98
-30
lines changed

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::{
22
compile::{
3-
engine::df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode},
3+
engine::{
4+
df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode},
5+
udf::MEASURE_UDAF_NAME,
6+
},
47
rewrite::{
58
extract_exprlist_from_groupping_set,
69
rules::{
@@ -2788,7 +2791,55 @@ impl CubeScanWrapperNode {
27882791
})?;
27892792
Ok((resulting_sql, sql_query))
27902793
}
2791-
// Expr::AggregateUDF { .. } => {}
2794+
Expr::AggregateUDF { ref fun, ref args } => {
2795+
match fun.name.as_str() {
2796+
// TODO allow this only in agg expr
2797+
MEASURE_UDAF_NAME => {
2798+
let Some(PushToCubeContext {
2799+
ungrouped_scan_node,
2800+
..
2801+
}) = push_to_cube_context
2802+
else {
2803+
return Err(DataFusionError::Internal(format!(
2804+
"Unexpected {} UDAF expression without push-to-Cube context: {expr}",
2805+
fun.name,
2806+
)));
2807+
};
2808+
2809+
let measure_column = match args.as_slice() {
2810+
[Expr::Column(measure_column)] => measure_column,
2811+
_ => {
2812+
return Err(DataFusionError::Internal(format!(
2813+
"Unexpected arguments for {} UDAF: {expr}",
2814+
fun.name,
2815+
)))
2816+
}
2817+
};
2818+
2819+
let member = Self::find_member_in_ungrouped_scan(
2820+
ungrouped_scan_node,
2821+
measure_column,
2822+
)?;
2823+
2824+
let MemberField::Member(member) = member else {
2825+
return Err(DataFusionError::Internal(format!(
2826+
"First argument for {} UDAF should reference member, not literal: {expr}",
2827+
fun.name,
2828+
)));
2829+
};
2830+
2831+
if let Some(used_members) = used_members {
2832+
used_members.insert(member.clone());
2833+
}
2834+
2835+
Ok((format!("${{{member}}}"), sql_query))
2836+
}
2837+
_ => Err(DataFusionError::Internal(format!(
2838+
"Can't generate SQL for UDAF: {}",
2839+
fun.name
2840+
))),
2841+
}
2842+
}
27922843
Expr::InList {
27932844
expr,
27942845
list,

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
use crate::{
2-
compile::rewrite::{
3-
aggregate,
4-
analysis::LogicalPlanData,
5-
cube_scan_wrapper, grouping_set_expr, original_expr_name, rewrite,
6-
rewriter::{CubeEGraph, CubeRewrite},
7-
rules::{members::MemberRules, wrapper::WrapperRules},
8-
subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
9-
wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail,
10-
wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail,
11-
wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail,
12-
wrapped_select_projection_expr_empty_tail, wrapped_select_subqueries_empty_tail,
13-
wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer, wrapper_pushdown_replacer,
14-
wrapper_replacer_context, AggregateFunctionExprDistinct, AggregateFunctionExprFun,
15-
AliasExprAlias, ColumnExprColumn, ListType, LogicalPlanLanguage, WrappedSelectPushToCube,
16-
WrapperReplacerContextAliasToCube, WrapperReplacerContextPushToCube,
2+
compile::{
3+
engine::udf::MEASURE_UDAF_NAME,
4+
rewrite::{
5+
aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
6+
rewrite,
7+
rewriter::{CubeEGraph, CubeRewrite},
8+
rules::{members::MemberRules, wrapper::WrapperRules},
9+
subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
10+
wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail,
11+
wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail,
12+
wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail,
13+
wrapped_select_projection_expr_empty_tail, wrapped_select_subqueries_empty_tail,
14+
wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer,
15+
wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct,
16+
AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn,
17+
ListType, LogicalPlanData, LogicalPlanLanguage, WrappedSelectPushToCube,
18+
WrapperReplacerContextAliasToCube, WrapperReplacerContextPushToCube,
19+
},
1720
},
1821
copy_flag,
1922
transport::V1CubeMetaMeasureExt,
@@ -250,7 +253,7 @@ impl WrapperRules {
250253
),
251254
vec![("?aggr_expr", aggr_expr)],
252255
wrapper_pullup_replacer(
253-
"?measure",
256+
alias_expr("?out_measure_expr", "?out_measure_alias"),
254257
wrapper_replacer_context(
255258
"?alias_to_cube",
256259
"WrapperReplacerContextPushToCube:true",
@@ -267,7 +270,8 @@ impl WrapperRules {
267270
distinct,
268271
cast_data_type,
269272
"?cube_members",
270-
"?measure",
273+
"?out_measure_expr",
274+
"?out_measure_alias",
271275
),
272276
)
273277
},
@@ -827,15 +831,17 @@ impl WrapperRules {
827831
// TODO support cast push downs
828832
_cast_data_type_var: Option<&'static str>,
829833
cube_members_var: &'static str,
830-
measure_out_var: &'static str,
834+
out_expr_var: &'static str,
835+
out_alias_var: &'static str,
831836
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
832837
let original_expr_var = var!(original_expr_var);
833838
let column_var = column_var.map(|v| var!(v));
834839
let fun_name_var = fun_name_var.map(|v| var!(v));
835840
let distinct_var = distinct_var.map(|v| var!(v));
836841
// let cast_data_type_var = cast_data_type_var.map(|v| var!(v));
837842
let cube_members_var = var!(cube_members_var);
838-
let measure_out_var = var!(measure_out_var);
843+
let out_expr_var = var!(out_expr_var);
844+
let out_alias_var = var!(out_alias_var);
839845
let meta = self.meta_context.clone();
840846
let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match();
841847
move |egraph, subst| {
@@ -889,23 +895,34 @@ impl WrapperRules {
889895
egraph.add(LogicalPlanLanguage::ColumnExprColumn(
890896
ColumnExprColumn(column.clone()),
891897
));
892-
893898
let column_expr =
894899
egraph.add(LogicalPlanLanguage::ColumnExpr([
895900
column_expr_column,
896901
]));
902+
let udaf_name_expr = egraph.add(
903+
LogicalPlanLanguage::AggregateUDFExprFun(
904+
AggregateUDFExprFun(
905+
MEASURE_UDAF_NAME.to_string(),
906+
),
907+
),
908+
);
909+
let udaf_args_expr = egraph.add(
910+
LogicalPlanLanguage::AggregateUDFExprArgs(vec![
911+
column_expr,
912+
]),
913+
);
914+
let udaf_expr =
915+
egraph.add(LogicalPlanLanguage::AggregateUDFExpr(
916+
[udaf_name_expr, udaf_args_expr],
917+
));
918+
919+
subst.insert(out_expr_var, udaf_expr);
920+
897921
let alias_expr_alias =
898922
egraph.add(LogicalPlanLanguage::AliasExprAlias(
899923
AliasExprAlias(alias.clone()),
900924
));
901-
902-
let alias_expr =
903-
egraph.add(LogicalPlanLanguage::AliasExpr([
904-
column_expr,
905-
alias_expr_alias,
906-
]));
907-
908-
subst.insert(measure_out_var, alias_expr);
925+
subst.insert(out_alias_var, alias_expr_alias);
909926

910927
return true;
911928
}

0 commit comments

Comments
 (0)