Skip to content

Commit f9bc496

Browse files
committed
[WIP][FIXUP][SPLIT] rework measure filters rewrite
1 parent 9ee4ef2 commit f9bc496

File tree

3 files changed

+192
-49
lines changed

3 files changed

+192
-49
lines changed

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 135 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ use crate::{
22
compile::{
33
engine::udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME},
44
rewrite::{
5-
aggregate, alias_expr,
5+
agg_fun_expr, aggregate, alias_expr,
66
analysis::ConstantFolding,
7-
cube_scan_wrapper, grouping_set_expr, original_expr_name, rewrite,
7+
binary_expr, case_expr, column_expr, cube_scan_wrapper, grouping_set_expr,
8+
literal_null, original_expr_name, rewrite,
89
rewriter::{CubeEGraph, CubeRewrite},
910
rules::{members::MemberRules, wrapper::WrapperRules},
10-
subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
11+
subquery, transforming_chain_rewrite, transforming_rewrite, udaf_expr, wrapped_select,
1112
wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail,
1213
wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail,
1314
wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail,
@@ -321,10 +322,6 @@ impl WrapperRules {
321322
);
322323
}
323324

324-
use crate::compile::rewrite::{
325-
agg_fun_expr, binary_expr, case_expr, column_expr, literal_null, udaf_expr,
326-
};
327-
328325
// incoming structure: agg_fun(?name, case(?cond, (?when_value, measure_column)))
329326
// optional "else null" is fine
330327
// only single when-then
@@ -363,8 +360,7 @@ impl WrapperRules {
363360
PATCH_MEASURE_UDAF_NAME,
364361
vec![
365362
column_expr("?measure_column"),
366-
// TODO support doing both: changing agg type and adding filters
367-
literal_null(),
363+
"?replace_agg_type".to_string(),
368364
wrapper_pushdown_replacer(
369365
// = is a proper way to filter here:
370366
// CASE NULL WHEN ... will return null
@@ -377,9 +373,17 @@ impl WrapperRules {
377373
),
378374
"?out_measure_alias",
379375
),
380-
self.transform_filtered_measure("?aggr_expr", "?literal", "?out_measure_alias"),
376+
self.transform_filtered_measure(
377+
"?aggr_expr",
378+
"?literal",
379+
"?measure_column",
380+
"?fun",
381+
"?cube_members",
382+
"?replace_agg_type",
383+
"?out_measure_alias",
384+
),
381385
),
382-
transforming_rewrite(
386+
rewrite(
383387
"wrapper-pull-up-aggregation-over-filtered-measure",
384388
udaf_expr(
385389
PATCH_MEASURE_UDAF_NAME,
@@ -400,14 +404,6 @@ impl WrapperRules {
400404
),
401405
"?context",
402406
),
403-
|_egraph, _subst| {
404-
// dbg!("wrapper-pull-up-aggregation-over-filtered-measure call");
405-
// dbg!(&egraph[subst[var!("?measure_column")]]);
406-
// dbg!(&egraph[subst[var!("?filter_expr")]]);
407-
408-
// TODO do we need to check something here? like a SQL generator?
409-
true
410-
},
411407
),
412408
]);
413409
}
@@ -947,20 +943,31 @@ impl WrapperRules {
947943
egraph: &mut CubeEGraph,
948944
subst: &mut Subst,
949945
column: Column,
950-
call_agg_type: String,
946+
call_agg_type: Option<String>,
951947
alias: String,
952-
out_expr_var: Var,
948+
out_expr_var: Option<Var>,
949+
out_replace_agg_type: Option<Var>,
953950
out_alias_var: Var,
954951
) {
955952
let column_expr_column = egraph.add(LogicalPlanLanguage::ColumnExprColumn(
956-
ColumnExprColumn(column.clone()),
953+
ColumnExprColumn(column),
957954
));
958955
let column_expr = egraph.add(LogicalPlanLanguage::ColumnExpr([column_expr_column]));
959-
let new_aggregation_value = egraph.add(LogicalPlanLanguage::LiteralExprValue(
960-
LiteralExprValue(ScalarValue::Utf8(Some(call_agg_type))),
961-
));
956+
let new_aggregation_value = match call_agg_type {
957+
Some(call_agg_type) => egraph.add(LogicalPlanLanguage::LiteralExprValue(
958+
LiteralExprValue(ScalarValue::Utf8(Some(call_agg_type))),
959+
)),
960+
None => egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue(
961+
ScalarValue::Null,
962+
))),
963+
};
962964
let new_aggregation_expr =
963965
egraph.add(LogicalPlanLanguage::LiteralExpr([new_aggregation_value]));
966+
967+
if let Some(out_replace_agg_type) = out_replace_agg_type {
968+
subst.insert(out_replace_agg_type, new_aggregation_expr);
969+
}
970+
964971
let add_filters_value = egraph.add(LogicalPlanLanguage::LiteralExprValue(
965972
LiteralExprValue(ScalarValue::Null),
966973
));
@@ -978,7 +985,9 @@ impl WrapperRules {
978985
udaf_args_expr,
979986
]));
980987

981-
subst.insert(out_expr_var, udaf_expr);
988+
if let Some(out_expr_var) = out_expr_var {
989+
subst.insert(out_expr_var, udaf_expr);
990+
}
982991

983992
let alias_expr_alias = egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(
984993
alias.clone(),
@@ -1079,9 +1088,10 @@ impl WrapperRules {
10791088
egraph,
10801089
subst,
10811090
column,
1082-
call_agg_type.clone(),
1091+
Some(call_agg_type.clone()),
10831092
alias,
1084-
out_expr_var,
1093+
Some(out_expr_var),
1094+
None,
10851095
out_alias_var,
10861096
);
10871097

@@ -1140,12 +1150,23 @@ impl WrapperRules {
11401150
&self,
11411151
aggr_expr_var: &'static str,
11421152
literal_var: &'static str,
1153+
column_var: &'static str,
1154+
fun_name_var: &'static str,
1155+
cube_members_var: &'static str,
1156+
replace_agg_type_var: &'static str,
11431157
out_measure_alias_var: &'static str,
11441158
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
11451159
let aggr_expr_var = var!(aggr_expr_var);
11461160
let literal_var = var!(literal_var);
1161+
let column_var = var!(column_var);
1162+
let fun_name_var = var!(fun_name_var);
1163+
let cube_members_var = var!(cube_members_var);
1164+
let replace_agg_type_var = var!(replace_agg_type_var);
11471165
let out_measure_alias_var = var!(out_measure_alias_var);
11481166

1167+
let meta = self.meta_context.clone();
1168+
let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match();
1169+
11491170
move |egraph, subst| {
11501171
match &egraph[subst[literal_var]].data.constant {
11511172
Some(ConstantFolding::Scalar(_)) => {
@@ -1156,17 +1177,96 @@ impl WrapperRules {
11561177
}
11571178
}
11581179

1159-
// TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1160-
// TODO support both changing agg fun and add filter
1161-
11621180
let Some(alias) = original_expr_name(egraph, subst[aggr_expr_var]) else {
11631181
return false;
11641182
};
1165-
let alias_expr_alias =
1166-
egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(alias)));
1167-
subst.insert(out_measure_alias_var, alias_expr_alias);
11681183

1169-
true
1184+
for fun in var_iter!(egraph[subst[fun_name_var]], AggregateFunctionExprFun)
1185+
.cloned()
1186+
.collect::<Vec<_>>()
1187+
{
1188+
let call_agg_type = MemberRules::get_agg_type(Some(&fun), false);
1189+
1190+
let column_iter = var_iter!(egraph[subst[column_var]], ColumnExprColumn)
1191+
.cloned()
1192+
.collect::<Vec<_>>();
1193+
1194+
if let Some(member_names_to_expr) = &mut egraph
1195+
.index_mut(subst[cube_members_var])
1196+
.data
1197+
.member_name_to_expr
1198+
{
1199+
for column in column_iter {
1200+
if let Some((&(Some(ref member), _, _), _)) =
1201+
LogicalPlanData::do_find_member_by_alias(
1202+
member_names_to_expr,
1203+
&column.name,
1204+
)
1205+
{
1206+
if let Some(measure) = meta.find_measure_with_name(member.to_string()) {
1207+
if !measure.allow_add_filter(call_agg_type.as_deref()) {
1208+
continue;
1209+
}
1210+
1211+
let Some(call_agg_type) = &call_agg_type else {
1212+
// call_agg_type is None, rewrite as is
1213+
Self::insert_replaced_measure(
1214+
egraph,
1215+
subst,
1216+
column,
1217+
None,
1218+
alias,
1219+
None,
1220+
Some(replace_agg_type_var),
1221+
out_measure_alias_var,
1222+
);
1223+
1224+
return true;
1225+
};
1226+
1227+
if measure
1228+
.is_same_agg_type(call_agg_type, disable_strict_agg_type_match)
1229+
{
1230+
Self::insert_replaced_measure(
1231+
egraph,
1232+
subst,
1233+
column,
1234+
None,
1235+
alias,
1236+
None,
1237+
Some(replace_agg_type_var),
1238+
out_measure_alias_var,
1239+
);
1240+
1241+
return true;
1242+
}
1243+
1244+
if measure.allow_replace_agg_type(
1245+
call_agg_type,
1246+
disable_strict_agg_type_match,
1247+
) {
1248+
Self::insert_replaced_measure(
1249+
egraph,
1250+
subst,
1251+
column,
1252+
Some(call_agg_type.clone()),
1253+
alias,
1254+
None,
1255+
Some(replace_agg_type_var),
1256+
out_measure_alias_var,
1257+
);
1258+
1259+
return true;
1260+
}
1261+
}
1262+
}
1263+
}
1264+
}
1265+
}
1266+
1267+
false
1268+
1269+
// TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
11701270
}
11711271
}
11721272
}

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,20 +1573,36 @@ GROUP BY
15731573
.find_cube_scan_wrapped_sql()
15741574
.request,
15751575
TransportLoadRequestQuery {
1576-
measures: Some(vec![json!({
1577-
"cubeName": "MultiTypeCube",
1578-
"alias": "avg_case_cast_tr",
1579-
"expr": {
1580-
"type": "PatchMeasure",
1581-
"sourceMeasure": "MultiTypeCube.avgPrice",
1582-
"replaceAggregationType": null,
1583-
"addFilters": [{
1584-
"cubeParams": ["MultiTypeCube"],
1585-
"sql": "(((CAST(TRUNC(EXTRACT(YEAR FROM ${MultiTypeCube.dim_date0})) AS INTEGER) = 2024) AND (CAST(TRUNC(EXTRACT(MONTH FROM ${MultiTypeCube.dim_date0})) AS INTEGER) <= 11)) = TRUE)"
1586-
}],
1587-
},
1588-
"groupingSet": null,
1589-
}).to_string(),]),
1576+
measures: Some(vec![
1577+
json!({
1578+
"cubeName": "MultiTypeCube",
1579+
"alias": "avg_case_cast_tr",
1580+
"expr": {
1581+
"type": "PatchMeasure",
1582+
"sourceMeasure": "MultiTypeCube.avgPrice",
1583+
"replaceAggregationType": null,
1584+
"addFilters": [{
1585+
"cubeParams": ["MultiTypeCube"],
1586+
"sql": "(((CAST(TRUNC(EXTRACT(YEAR FROM ${MultiTypeCube.dim_date0})) AS INTEGER) = 2024) AND (CAST(TRUNC(EXTRACT(MONTH FROM ${MultiTypeCube.dim_date0})) AS INTEGER) <= 11)) = TRUE)"
1587+
}],
1588+
},
1589+
"groupingSet": null,
1590+
}).to_string(),
1591+
json!({
1592+
"cubeName": "MultiTypeCube",
1593+
"alias": "sum_case_multity",
1594+
"expr": {
1595+
"type": "PatchMeasure",
1596+
"sourceMeasure": "MultiTypeCube.maxPrice",
1597+
"replaceAggregationType": "sum",
1598+
"addFilters": [{
1599+
"cubeParams": ["MultiTypeCube"],
1600+
"sql": "((${MultiTypeCube.dim_str1} = $0$) = TRUE)"
1601+
}],
1602+
},
1603+
"groupingSet": null,
1604+
}).to_string(),
1605+
]),
15901606
dimensions: Some(vec![json!({
15911607
"cubeName": "MultiTypeCube",
15921608
"alias": "dim_str0",

rust/cubesql/cubesql/src/transport/ext.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub trait V1CubeMetaMeasureExt {
1212

1313
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool;
1414

15+
fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool;
16+
1517
fn get_sql_type(&self) -> ColumnType;
1618
}
1719

@@ -72,6 +74,31 @@ impl V1CubeMetaMeasureExt for CubeMetaMeasure {
7274
}
7375
}
7476

77+
// This should be aligned with BaseMeasure.preparePatchedMeasure
78+
// TODO proper reference
79+
// See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16
80+
fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool {
81+
let Some(agg_type) = &self.agg_type else {
82+
return false;
83+
};
84+
85+
let agg_type = match query_agg_type {
86+
Some(query_agg_type) => query_agg_type,
87+
None => agg_type,
88+
};
89+
90+
match agg_type {
91+
"sum"
92+
| "avg"
93+
| "min"
94+
| "max"
95+
| "count"
96+
| "count_distinct"
97+
| "count_distinct_approx" => true,
98+
_ => false,
99+
}
100+
}
101+
75102
fn get_sql_type(&self) -> ColumnType {
76103
let from_type = match &self.r#type.to_lowercase().as_str() {
77104
&"number" => ColumnType::Double,

0 commit comments

Comments
 (0)