Skip to content

Commit 957990c

Browse files
committed
feat(cubesql): Rewrite measure filters in query as PatchMeasure
1 parent 04cb4c8 commit 957990c

File tree

4 files changed

+341
-4
lines changed

4 files changed

+341
-4
lines changed

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,6 @@ fn literal_bool(literal_bool: bool) -> String {
17801780
format!("(LiteralExpr LiteralExprValue:b:{})", literal_bool)
17811781
}
17821782

1783-
#[allow(dead_code)]
17841783
/// Returns the rewrite-pattern string for a `NULL` literal expression.
///
/// The pattern is constant, so it is built with a plain string conversion
/// instead of going through the formatting machinery.
fn literal_null() -> String {
    "(LiteralExpr LiteralExprValue:null)".to_string()
}

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 214 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ use crate::{
22
compile::{
33
engine::udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME},
44
rewrite::{
5-
aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
6-
rewrite,
5+
agg_fun_expr, aggregate, alias_expr,
6+
analysis::ConstantFolding,
7+
binary_expr, case_expr, column_expr, cube_scan_wrapper, grouping_set_expr,
8+
literal_null, original_expr_name, rewrite,
79
rewriter::{CubeEGraph, CubeRewrite},
810
rules::{members::MemberRules, wrapper::WrapperRules},
9-
subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
11+
subquery, transforming_chain_rewrite, transforming_rewrite, udaf_expr, wrapped_select,
1012
wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail,
1113
wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail,
1214
wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail,
@@ -319,6 +321,91 @@ impl WrapperRules {
319321
"GroupingSetExprMembers",
320322
);
321323
}
324+
325+
// incoming structure: agg_fun(?name, case(?cond, (?when_value, measure_column)))
326+
// optional "else null" is fine
327+
// only single when-then
328+
rules.extend(vec![
329+
transforming_chain_rewrite(
330+
"wrapper-push-down-aggregation-over-filtered-measure",
331+
wrapper_pushdown_replacer("?aggr_expr", "?context"),
332+
vec![
333+
(
334+
"?aggr_expr",
335+
agg_fun_expr(
336+
"?fun",
337+
vec![case_expr(
338+
Some("?case_expr".to_string()),
339+
vec![("?literal".to_string(), column_expr("?measure_column"))],
340+
// TODO make `ELSE NULL` optional and/or add generic rewrite to normalize it
341+
Some(literal_null()),
342+
)],
343+
"?distinct",
344+
),
345+
),
346+
(
347+
"?context",
348+
wrapper_replacer_context(
349+
"?alias_to_cube",
350+
"WrapperReplacerContextPushToCube:true",
351+
"?in_projection",
352+
"?cube_members",
353+
"?grouped_subqueries",
354+
"?ungrouped_scan",
355+
),
356+
),
357+
],
358+
alias_expr(
359+
udaf_expr(
360+
PATCH_MEASURE_UDAF_NAME,
361+
vec![
362+
column_expr("?measure_column"),
363+
"?replace_agg_type".to_string(),
364+
wrapper_pushdown_replacer(
365+
// = is a proper way to filter here:
366+
// CASE NULL WHEN ... will return null
367+
// So NULL in ?case_expr is equivalent to hitting ELSE branch
368+
// TODO add "is not null" to cond? just to make it always boolean
369+
binary_expr("?case_expr", "=", "?literal"),
370+
"?context",
371+
),
372+
],
373+
),
374+
"?out_measure_alias",
375+
),
376+
self.transform_filtered_measure(
377+
"?aggr_expr",
378+
"?literal",
379+
"?measure_column",
380+
"?fun",
381+
"?cube_members",
382+
"?replace_agg_type",
383+
"?out_measure_alias",
384+
),
385+
),
386+
rewrite(
387+
"wrapper-pull-up-aggregation-over-filtered-measure",
388+
udaf_expr(
389+
PATCH_MEASURE_UDAF_NAME,
390+
vec![
391+
column_expr("?measure_column"),
392+
"?new_agg_type".to_string(),
393+
wrapper_pullup_replacer("?filter_expr", "?context"),
394+
],
395+
),
396+
wrapper_pullup_replacer(
397+
udaf_expr(
398+
PATCH_MEASURE_UDAF_NAME,
399+
vec![
400+
column_expr("?measure_column"),
401+
"?new_agg_type".to_string(),
402+
"?filter_expr".to_string(),
403+
],
404+
),
405+
"?context",
406+
),
407+
),
408+
]);
322409
}
323410

324411
pub fn aggregate_rules_subquery(&self, rules: &mut Vec<CubeRewrite>) {
@@ -1058,4 +1145,128 @@ impl WrapperRules {
10581145
)
10591146
}
10601147
}
1148+
1149+
fn transform_filtered_measure(
1150+
&self,
1151+
aggr_expr_var: &'static str,
1152+
literal_var: &'static str,
1153+
column_var: &'static str,
1154+
fun_name_var: &'static str,
1155+
cube_members_var: &'static str,
1156+
replace_agg_type_var: &'static str,
1157+
out_measure_alias_var: &'static str,
1158+
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
1159+
let aggr_expr_var = var!(aggr_expr_var);
1160+
let literal_var = var!(literal_var);
1161+
let column_var = var!(column_var);
1162+
let fun_name_var = var!(fun_name_var);
1163+
let cube_members_var = var!(cube_members_var);
1164+
let replace_agg_type_var = var!(replace_agg_type_var);
1165+
let out_measure_alias_var = var!(out_measure_alias_var);
1166+
1167+
let meta = self.meta_context.clone();
1168+
let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match();
1169+
1170+
move |egraph, subst| {
1171+
match &egraph[subst[literal_var]].data.constant {
1172+
Some(ConstantFolding::Scalar(_)) => {
1173+
// Do nothing
1174+
}
1175+
_ => {
1176+
return false;
1177+
}
1178+
}
1179+
1180+
let Some(alias) = original_expr_name(egraph, subst[aggr_expr_var]) else {
1181+
return false;
1182+
};
1183+
1184+
for fun in var_iter!(egraph[subst[fun_name_var]], AggregateFunctionExprFun)
1185+
.cloned()
1186+
.collect::<Vec<_>>()
1187+
{
1188+
let call_agg_type = MemberRules::get_agg_type(Some(&fun), false);
1189+
1190+
let column_iter = var_iter!(egraph[subst[column_var]], ColumnExprColumn)
1191+
.cloned()
1192+
.collect::<Vec<_>>();
1193+
1194+
if let Some(member_names_to_expr) = &mut egraph
1195+
.index_mut(subst[cube_members_var])
1196+
.data
1197+
.member_name_to_expr
1198+
{
1199+
for column in column_iter {
1200+
if let Some((&(Some(ref member), _, _), _)) =
1201+
LogicalPlanData::do_find_member_by_alias(
1202+
member_names_to_expr,
1203+
&column.name,
1204+
)
1205+
{
1206+
if let Some(measure) = meta.find_measure_with_name(member) {
1207+
if !measure.allow_add_filter(call_agg_type.as_deref()) {
1208+
continue;
1209+
}
1210+
1211+
let Some(call_agg_type) = &call_agg_type else {
1212+
// call_agg_type is None, rewrite as is
1213+
Self::insert_patch_measure(
1214+
egraph,
1215+
subst,
1216+
column,
1217+
None,
1218+
alias,
1219+
None,
1220+
Some(replace_agg_type_var),
1221+
out_measure_alias_var,
1222+
);
1223+
1224+
return true;
1225+
};
1226+
1227+
if measure
1228+
.is_same_agg_type(call_agg_type, disable_strict_agg_type_match)
1229+
{
1230+
Self::insert_patch_measure(
1231+
egraph,
1232+
subst,
1233+
column,
1234+
None,
1235+
alias,
1236+
None,
1237+
Some(replace_agg_type_var),
1238+
out_measure_alias_var,
1239+
);
1240+
1241+
return true;
1242+
}
1243+
1244+
if measure.allow_replace_agg_type(
1245+
call_agg_type,
1246+
disable_strict_agg_type_match,
1247+
) {
1248+
Self::insert_patch_measure(
1249+
egraph,
1250+
subst,
1251+
column,
1252+
Some(call_agg_type.clone()),
1253+
alias,
1254+
None,
1255+
Some(replace_agg_type_var),
1256+
out_measure_alias_var,
1257+
);
1258+
1259+
return true;
1260+
}
1261+
}
1262+
}
1263+
}
1264+
}
1265+
}
1266+
1267+
false
1268+
1269+
// TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1270+
}
1271+
}
10611272
}

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,3 +1801,104 @@ async fn wrapper_min_from_avg_measure() {
18011801
}
18021802
);
18031803
}
1804+
1805+
/// Checks that aggregations over `CASE <condition> WHEN TRUE THEN <measure> ELSE NULL END`
/// end up in the wrapped SQL request as `PatchMeasure` member expressions.
///
/// Expected outcome, as asserted below:
/// - `AVG` over `avgPrice` yields a patch with `replaceAggregationType: null`;
/// - `SUM` over `maxPrice` yields a patch with `replaceAggregationType: "sum"`;
/// - in both cases the `CASE` condition, compared with `= TRUE`, is passed in `addFilters`.
#[tokio::test]
1806+
async fn test_ad_hoc_measure_filter() {
1807+
// The test body is skipped entirely when SQL push down is not enabled.
if !Rewriter::sql_push_down_enabled() {
1808+
return;
1809+
}
1810+
init_testing_logger();
1811+
1812+
let query_plan = convert_select_to_query_plan(
1813+
// language=PostgreSQL
1814+
r#"SELECT
1815+
dim_str0,
1816+
AVG(
1817+
CASE (
1818+
(
1819+
CAST(TRUNC(EXTRACT(YEAR FROM dim_date0)) AS INTEGER) = 2024
1820+
)
1821+
AND
1822+
(
1823+
CAST(TRUNC(EXTRACT(MONTH FROM dim_date0)) AS INTEGER) <= 11
1824+
)
1825+
)
1826+
WHEN TRUE
1827+
THEN avgPrice
1828+
ELSE NULL
1829+
END
1830+
),
1831+
SUM(
1832+
CASE (dim_str1 = 'foo')
1833+
WHEN TRUE
1834+
THEN maxPrice
1835+
ELSE NULL
1836+
END
1837+
)
1838+
FROM MultiTypeCube
1839+
GROUP BY
1840+
1
1841+
;"#
1842+
.to_string(),
1843+
DatabaseProtocol::PostgreSQL,
1844+
)
1845+
.await;
1846+
1847+
// Building the physical plan must succeed; the plan is printed for debugging only.
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1848+
println!(
1849+
"Physical plan: {}",
1850+
displayable(physical_plan.as_ref()).indent()
1851+
);
1852+
1853+
// The whole query is expected to be expressed as a single wrapped request:
// two patched measures plus the `dim_str0` dimension.
assert_eq!(
1854+
query_plan
1855+
.as_logical_plan()
1856+
.find_cube_scan_wrapped_sql()
1857+
.request,
1858+
TransportLoadRequestQuery {
1859+
measures: Some(vec![
1860+
json!({
1861+
"cubeName": "MultiTypeCube",
1862+
"alias": "avg_case_cast_tr",
1863+
"expr": {
1864+
"type": "PatchMeasure",
1865+
"sourceMeasure": "MultiTypeCube.avgPrice",
1866+
"replaceAggregationType": null,
1867+
"addFilters": [{
1868+
"cubeParams": ["MultiTypeCube"],
1869+
"sql": "(((CAST(TRUNC(EXTRACT(YEAR FROM ${MultiTypeCube.dim_date0})) AS INTEGER) = 2024) AND (CAST(TRUNC(EXTRACT(MONTH FROM ${MultiTypeCube.dim_date0})) AS INTEGER) <= 11)) = TRUE)"
1870+
}],
1871+
},
1872+
"groupingSet": null,
1873+
}).to_string(),
1874+
json!({
1875+
"cubeName": "MultiTypeCube",
1876+
"alias": "sum_case_multity",
1877+
"expr": {
1878+
"type": "PatchMeasure",
1879+
"sourceMeasure": "MultiTypeCube.maxPrice",
1880+
"replaceAggregationType": "sum",
1881+
"addFilters": [{
1882+
"cubeParams": ["MultiTypeCube"],
1883+
"sql": "((${MultiTypeCube.dim_str1} = $0$) = TRUE)"
1884+
}],
1885+
},
1886+
"groupingSet": null,
1887+
}).to_string(),
1888+
]),
1889+
dimensions: Some(vec![json!({
1890+
"cubeName": "MultiTypeCube",
1891+
"alias": "dim_str0",
1892+
"expr": {
1893+
"type": "SqlFunction",
1894+
"cubeParams": ["MultiTypeCube"],
1895+
"sql": "${MultiTypeCube.dim_str0}",
1896+
},
1897+
"groupingSet": null,
1898+
}).to_string(),]),
1899+
segments: Some(vec![]),
1900+
order: Some(vec![]),
1901+
..Default::default()
1902+
}
1903+
);
1904+
}

rust/cubesql/cubesql/src/transport/ext.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub trait V1CubeMetaMeasureExt {
1212

1313
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool;
1414

15+
fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool;
16+
1517
fn get_sql_type(&self) -> ColumnType;
1618
}
1719

@@ -72,6 +74,30 @@ impl V1CubeMetaMeasureExt for CubeMetaMeasure {
7274
}
7375
}
7476

77+
// This should be aligned with BaseMeasure.preparePatchedMeasure
78+
// See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16
79+
fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool {
80+
let Some(agg_type) = &self.agg_type else {
81+
return false;
82+
};
83+
84+
let agg_type = match query_agg_type {
85+
Some(query_agg_type) => query_agg_type,
86+
None => agg_type,
87+
};
88+
89+
match agg_type {
90+
"sum"
91+
| "avg"
92+
| "min"
93+
| "max"
94+
| "count"
95+
| "count_distinct"
96+
| "count_distinct_approx" => true,
97+
_ => false,
98+
}
99+
}
100+
75101
fn get_sql_type(&self) -> ColumnType {
76102
let from_type = match &self.r#type.to_lowercase().as_str() {
77103
&"number" => ColumnType::Double,

0 commit comments

Comments
 (0)