Skip to content

Commit 1e7a241

Browse files
committed
feat(cubesql): Support filter push down for date_part('year', ?col) = ?literal
1 parent 4f5e5dc commit 1e7a241

File tree

2 files changed

+152
-33
lines changed

2 files changed

+152
-33
lines changed

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9482,7 +9482,85 @@ ORDER BY "source"."str0" ASC
94829482
}
94839483

94849484
// Checks filter push down for `date_part('YEAR', <column>) = <literal>`:
// the query below both groups by and filters on the year of `order_date`,
// and the resulting cube scan request is compared against the expected
// `V1LoadRequestQuery`.
#[tokio::test]
9485-
async fn test_tableau_filter_by_year() {
9485+
async fn test_filter_date_part_by_year() {
9486+
init_testing_logger();
9487+
9488+
// Plan the SQL through the PostgreSQL protocol and take the logical plan.
let logical_plan = convert_select_to_query_plan(
9489+
r#"
9490+
SELECT
9491+
COUNT(*) AS "count",
9492+
date_part('YEAR', "KibanaSampleDataEcommerce"."order_date") AS "yr:completedAt:ok"
9493+
FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce"
9494+
WHERE date_part('YEAR', "KibanaSampleDataEcommerce"."order_date") = 2019
9495+
GROUP BY 2
9496+
;"#
9497+
.to_string(),
9498+
DatabaseProtocol::PostgreSQL,
9499+
)
9500+
.await
9501+
.as_logical_plan();
9502+
9503+
// Expected request: the count measure plus a single time dimension on
// `order_date` with "year" granularity and a date range spanning
// 2019-01-01 through 2019-12-31; dimensions, segments and order are empty.
assert_eq!(
9504+
logical_plan.find_cube_scan().request,
9505+
V1LoadRequestQuery {
9506+
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
9507+
dimensions: Some(vec![]),
9508+
segments: Some(vec![]),
9509+
time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension {
9510+
dimension: "KibanaSampleDataEcommerce.order_date".to_string(),
9511+
granularity: Some("year".to_string()),
9512+
date_range: Some(json!(vec![
9513+
"2019-01-01".to_string(),
9514+
"2019-12-31".to_string(),
9515+
])),
9516+
},]),
9517+
order: Some(vec![]),
9518+
..Default::default()
9519+
}
9520+
)
9521+
}
9522+
9523+
// Checks filter push down for `EXTRACT(YEAR FROM <column>) = <literal>`:
// the query below both groups by and filters on the year of `order_date`,
// and the resulting cube scan request is compared against the expected
// `V1LoadRequestQuery`.
#[tokio::test]
9524+
async fn test_filter_extract_by_year() {
9525+
init_testing_logger();
9526+
9527+
// Plan the SQL through the PostgreSQL protocol and take the logical plan.
let logical_plan = convert_select_to_query_plan(
9528+
r#"
9529+
SELECT
9530+
COUNT(*) AS "count",
9531+
EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date") AS "yr:completedAt:ok"
9532+
FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce"
9533+
WHERE EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date") = 2019
9534+
GROUP BY 2
9535+
;"#
9536+
.to_string(),
9537+
DatabaseProtocol::PostgreSQL,
9538+
)
9539+
.await
9540+
.as_logical_plan();
9541+
9542+
// Expected request: the count measure plus a single time dimension on
// `order_date` with "year" granularity and a date range spanning
// 2019-01-01 through 2019-12-31; dimensions, segments and order are empty.
assert_eq!(
9543+
logical_plan.find_cube_scan().request,
9544+
V1LoadRequestQuery {
9545+
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
9546+
dimensions: Some(vec![]),
9547+
segments: Some(vec![]),
9548+
time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension {
9549+
dimension: "KibanaSampleDataEcommerce.order_date".to_string(),
9550+
granularity: Some("year".to_string()),
9551+
date_range: Some(json!(vec![
9552+
"2019-01-01".to_string(),
9553+
"2019-12-31".to_string(),
9554+
])),
9555+
},]),
9556+
order: Some(vec![]),
9557+
..Default::default()
9558+
}
9559+
)
9560+
}
9561+
9562+
#[tokio::test]
9563+
async fn test_tableau_filter_extract_by_year() {
94869564
init_testing_logger();
94879565

94889566
let logical_plan = convert_select_to_query_plan(

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,8 +1668,38 @@ impl RewriteRules for FilterRules {
16681668
"?filter_aliases",
16691669
),
16701670
),
1671+
// DATE_PART('year', "KibanaSampleDataEcommerce"."order_date") = 2019
16711672
transforming_rewrite(
16721673
"extract-year-equals",
1674+
filter_replacer(
1675+
binary_expr(
1676+
self.fun_expr(
1677+
"DatePart",
1678+
// NOTE: DataFusion (DF) plans the date_part granularity literal in lowercase, whereas for EXTRACT it is uppercase
1679+
vec![literal_string("year"), column_expr("?column")],
1680+
),
1681+
"=",
1682+
literal_expr("?year"),
1683+
),
1684+
"?alias_to_cube",
1685+
"?members",
1686+
"?filter_aliases",
1687+
),
1688+
filter_member("?member", "FilterMemberOp:inDateRange", "?values"),
1689+
self.transform_filter_extract_year_equals(
1690+
"?year",
1691+
"?column",
1692+
"?alias_to_cube",
1693+
"?members",
1694+
"?member",
1695+
"?values",
1696+
"?filter_aliases",
1697+
),
1698+
),
1699+
// Same as the rule above, but wrapped with TRUNC
1700+
// TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) = 2019
1701+
transforming_rewrite(
1702+
"extract-trunc-year-equals",
16731703
filter_replacer(
16741704
binary_expr(
16751705
self.fun_expr(
@@ -3579,43 +3609,54 @@ impl FilterRules {
35793609
.collect();
35803610
for year in years {
35813611
for aliases in aliases_es.iter() {
3582-
if let ScalarValue::Int64(Some(year)) = year {
3583-
if !(1000..=9999).contains(&year) {
3584-
continue;
3585-
}
3586-
3587-
if let Some((member_name, cube)) = Self::filter_member_name(
3588-
egraph,
3589-
subst,
3590-
&meta_context,
3591-
alias_to_cube_var,
3592-
column_var,
3593-
members_var,
3594-
&aliases,
3595-
) {
3596-
if !cube.contains_member(&member_name) {
3597-
continue;
3612+
// Normalize the matched literal to an `i64` year. Literals of any other
// kind (or strings that do not parse) skip this candidate via `continue`.
let year = match year {
3613+
ScalarValue::Int64(Some(year)) => year,
3614+
ScalarValue::Int32(Some(year)) => i64::from(year),
3615+
// Bind the string by reference: `year` is matched again on every pass of
// the enclosing `aliases` loop, so it must not be moved out of here.
ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => {
3616+
// Yield the parsed value as the result of this arm. A `return` here
// would exit the whole enclosing routine (which returns `true` on
// success) with an integer instead of feeding the `let` binding.
if let Ok(year) = year_str.parse::<i64>() {
3617+
year
} else {
continue;
35983618
}
35993619

3600-
subst.insert(
3601-
member_var,
3602-
egraph.add(LogicalPlanLanguage::FilterMemberMember(
3603-
FilterMemberMember(member_name.to_string()),
3604-
)),
3605-
);
3621+
}
3622+
_ => continue,
3623+
};
36063624

3607-
subst.insert(
3608-
values_var,
3609-
egraph.add(LogicalPlanLanguage::FilterMemberValues(
3610-
FilterMemberValues(vec![
3611-
format!("{}-01-01", year),
3612-
format!("{}-12-31", year),
3613-
]),
3614-
)),
3615-
);
3625+
if !(1000..=9999).contains(&year) {
3626+
continue;
3627+
}
36163628

3617-
return true;
3629+
if let Some((member_name, cube)) = Self::filter_member_name(
3630+
egraph,
3631+
subst,
3632+
&meta_context,
3633+
alias_to_cube_var,
3634+
column_var,
3635+
members_var,
3636+
&aliases,
3637+
) {
3638+
if !cube.contains_member(&member_name) {
3639+
continue;
36183640
}
3641+
3642+
subst.insert(
3643+
member_var,
3644+
egraph.add(LogicalPlanLanguage::FilterMemberMember(
3645+
FilterMemberMember(member_name.to_string()),
3646+
)),
3647+
);
3648+
3649+
subst.insert(
3650+
values_var,
3651+
egraph.add(LogicalPlanLanguage::FilterMemberValues(
3652+
FilterMemberValues(vec![
3653+
format!("{}-01-01", year),
3654+
format!("{}-12-31", year),
3655+
]),
3656+
)),
3657+
);
3658+
3659+
return true;
36193660
}
36203661
}
36213662
}

0 commit comments

Comments
 (0)