Skip to content

Commit 9ee4ef2

Browse files
committed
[WIP] feat(cubesql): Rewrite measure filters in query as PatchMeasure
1 parent a8e9696 commit 9ee4ef2

File tree

3 files changed

+212
-3
lines changed

3 files changed

+212
-3
lines changed

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1778,7 +1778,6 @@ fn literal_bool(literal_bool: bool) -> String {
17781778
format!("(LiteralExpr LiteralExprValue:b:{})", literal_bool)
17791779
}
17801780

1781-
#[allow(dead_code)]
17821781
fn literal_null() -> String {
17831782
format!("(LiteralExpr LiteralExprValue:null)")
17841783
}

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use crate::{
22
compile::{
33
engine::udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME},
44
rewrite::{
5-
aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
6-
rewrite,
5+
aggregate, alias_expr,
6+
analysis::ConstantFolding,
7+
cube_scan_wrapper, grouping_set_expr, original_expr_name, rewrite,
78
rewriter::{CubeEGraph, CubeRewrite},
89
rules::{members::MemberRules, wrapper::WrapperRules},
910
subquery, transforming_chain_rewrite, transforming_rewrite, wrapped_select,
@@ -319,6 +320,96 @@ impl WrapperRules {
319320
"GroupingSetExprMembers",
320321
);
321322
}
323+
324+
use crate::compile::rewrite::{
325+
agg_fun_expr, binary_expr, case_expr, column_expr, literal_null, udaf_expr,
326+
};
327+
328+
// incoming structure: agg_fun(?name, case(?cond, (?when_value, measure_column)))
329+
// optional "else null" is fine
330+
// only single when-then
331+
rules.extend(vec![
332+
transforming_chain_rewrite(
333+
"wrapper-push-down-aggregation-over-filtered-measure",
334+
wrapper_pushdown_replacer("?aggr_expr", "?context"),
335+
vec![
336+
(
337+
"?aggr_expr",
338+
agg_fun_expr(
339+
"?fun",
340+
vec![case_expr(
341+
Some("?case_expr".to_string()),
342+
vec![("?literal".to_string(), column_expr("?measure_column"))],
343+
// TODO make `ELSE NULL` optional and/or add generic rewrite to normalize it
344+
Some(literal_null()),
345+
)],
346+
"?distinct",
347+
),
348+
),
349+
(
350+
"?context",
351+
wrapper_replacer_context(
352+
"?alias_to_cube",
353+
"WrapperReplacerContextPushToCube:true",
354+
"?in_projection",
355+
"?cube_members",
356+
"?grouped_subqueries",
357+
"?ungrouped_scan",
358+
),
359+
),
360+
],
361+
alias_expr(
362+
udaf_expr(
363+
PATCH_MEASURE_UDAF_NAME,
364+
vec![
365+
column_expr("?measure_column"),
366+
// TODO support doing both: changing agg type and adding filters
367+
literal_null(),
368+
wrapper_pushdown_replacer(
369+
// = is a proper way to filter here:
370+
// CASE NULL WHEN ... will return null
371+
// So NULL in ?case_expr is equivalent to hitting ELSE branch
372+
// TODO add "is not null" to cond? just to make is always boolean
373+
binary_expr("?case_expr", "=", "?literal"),
374+
"?context",
375+
),
376+
],
377+
),
378+
"?out_measure_alias",
379+
),
380+
self.transform_filtered_measure("?aggr_expr", "?literal", "?out_measure_alias"),
381+
),
382+
transforming_rewrite(
383+
"wrapper-pull-up-aggregation-over-filtered-measure",
384+
udaf_expr(
385+
PATCH_MEASURE_UDAF_NAME,
386+
vec![
387+
column_expr("?measure_column"),
388+
"?new_agg_type".to_string(),
389+
wrapper_pullup_replacer("?filter_expr", "?context"),
390+
],
391+
),
392+
wrapper_pullup_replacer(
393+
udaf_expr(
394+
PATCH_MEASURE_UDAF_NAME,
395+
vec![
396+
column_expr("?measure_column"),
397+
"?new_agg_type".to_string(),
398+
"?filter_expr".to_string(),
399+
],
400+
),
401+
"?context",
402+
),
403+
|_egraph, _subst| {
404+
// dbg!("wrapper-pull-up-aggregation-over-filtered-measure call");
405+
// dbg!(&egraph[subst[var!("?measure_column")]]);
406+
// dbg!(&egraph[subst[var!("?filter_expr")]]);
407+
408+
// TODO do we need to check something here? like a SQL generator?
409+
true
410+
},
411+
),
412+
]);
322413
}
323414

324415
pub fn aggregate_rules_subquery(&self, rules: &mut Vec<CubeRewrite>) {
@@ -1044,4 +1135,38 @@ impl WrapperRules {
10441135
)
10451136
}
10461137
}
1138+
1139+
fn transform_filtered_measure(
1140+
&self,
1141+
aggr_expr_var: &'static str,
1142+
literal_var: &'static str,
1143+
out_measure_alias_var: &'static str,
1144+
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
1145+
let aggr_expr_var = var!(aggr_expr_var);
1146+
let literal_var = var!(literal_var);
1147+
let out_measure_alias_var = var!(out_measure_alias_var);
1148+
1149+
move |egraph, subst| {
1150+
match &egraph[subst[literal_var]].data.constant {
1151+
Some(ConstantFolding::Scalar(_)) => {
1152+
// Do nothing
1153+
}
1154+
_ => {
1155+
return false;
1156+
}
1157+
}
1158+
1159+
// TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1160+
// TODO support both changing agg fun and add filter
1161+
1162+
let Some(alias) = original_expr_name(egraph, subst[aggr_expr_var]) else {
1163+
return false;
1164+
};
1165+
let alias_expr_alias =
1166+
egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(alias)));
1167+
subst.insert(out_measure_alias_var, alias_expr_alias);
1168+
1169+
true
1170+
}
1171+
}
10471172
}

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,3 +1518,88 @@ async fn wrapper_min_from_avg_measure() {
15181518
}
15191519
);
15201520
}
1521+
1522+
#[tokio::test]
1523+
async fn test_ad_hoc_measure_filter() {
1524+
if !Rewriter::sql_push_down_enabled() {
1525+
return;
1526+
}
1527+
init_testing_logger();
1528+
1529+
let query_plan = convert_select_to_query_plan(
1530+
// language=PostgreSQL
1531+
r#"SELECT
1532+
dim_str0,
1533+
AVG(
1534+
CASE (
1535+
(
1536+
CAST(TRUNC(EXTRACT(YEAR FROM dim_date0)) AS INTEGER) = 2024
1537+
)
1538+
AND
1539+
(
1540+
CAST(TRUNC(EXTRACT(MONTH FROM dim_date0)) AS INTEGER) <= 11
1541+
)
1542+
)
1543+
WHEN TRUE
1544+
THEN avgPrice
1545+
ELSE NULL
1546+
END
1547+
),
1548+
SUM(
1549+
CASE (dim_str1 = 'foo')
1550+
WHEN TRUE
1551+
THEN maxPrice
1552+
ELSE NULL
1553+
END
1554+
)
1555+
FROM MultiTypeCube
1556+
GROUP BY
1557+
1
1558+
;"#
1559+
.to_string(),
1560+
DatabaseProtocol::PostgreSQL,
1561+
)
1562+
.await;
1563+
1564+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1565+
println!(
1566+
"Physical plan: {}",
1567+
displayable(physical_plan.as_ref()).indent()
1568+
);
1569+
1570+
assert_eq!(
1571+
query_plan
1572+
.as_logical_plan()
1573+
.find_cube_scan_wrapped_sql()
1574+
.request,
1575+
TransportLoadRequestQuery {
1576+
measures: Some(vec![json!({
1577+
"cubeName": "MultiTypeCube",
1578+
"alias": "avg_case_cast_tr",
1579+
"expr": {
1580+
"type": "PatchMeasure",
1581+
"sourceMeasure": "MultiTypeCube.avgPrice",
1582+
"replaceAggregationType": null,
1583+
"addFilters": [{
1584+
"cubeParams": ["MultiTypeCube"],
1585+
"sql": "(((CAST(TRUNC(EXTRACT(YEAR FROM ${MultiTypeCube.dim_date0})) AS INTEGER) = 2024) AND (CAST(TRUNC(EXTRACT(MONTH FROM ${MultiTypeCube.dim_date0})) AS INTEGER) <= 11)) = TRUE)"
1586+
}],
1587+
},
1588+
"groupingSet": null,
1589+
}).to_string(),]),
1590+
dimensions: Some(vec![json!({
1591+
"cubeName": "MultiTypeCube",
1592+
"alias": "dim_str0",
1593+
"expr": {
1594+
"type": "SqlFunction",
1595+
"cubeParams": ["MultiTypeCube"],
1596+
"sql": "${MultiTypeCube.dim_str0}",
1597+
},
1598+
"groupingSet": null,
1599+
}).to_string(),]),
1600+
segments: Some(vec![]),
1601+
order: Some(vec![]),
1602+
..Default::default()
1603+
}
1604+
);
1605+
}

0 commit comments

Comments
 (0)