diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 87abdf0b772f9..fd86bf2f21b47 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -17815,4 +17815,72 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), } ) } + + #[tokio::test] + async fn test_tableau_trunc_extract_year_and_month() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let logical_plan = convert_select_to_query_plan( + r#" + SELECT SUM("KibanaSampleDataEcommerce"."sumPrice") AS "sum:sumPrice:ok" + FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce" + WHERE ( + "KibanaSampleDataEcommerce"."id" != 0 + AND CAST(TRUNC(EXTRACT(MONTH FROM "KibanaSampleDataEcommerce"."order_date")) AS INTEGER) = 2 + AND CAST(TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) AS INTEGER) = 2024 + AND "KibanaSampleDataEcommerce"."customer_gender" IS NOT NULL + ) + HAVING COUNT(1) > 0 + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await + .as_logical_plan(); + + assert_eq!( + logical_plan.find_cube_scan().request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.sumPrice".to_string(),]), + dimensions: Some(vec![]), + segments: Some(vec![]), + time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension { + dimension: "KibanaSampleDataEcommerce.order_date".to_string(), + granularity: None, + date_range: Some(json!(vec![ + "2024-02-01".to_string(), + "2024-02-29".to_string(), + ])), + }]), + order: Some(vec![]), + filters: Some(vec![ + V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.id".to_string()), + operator: Some("notEquals".to_string()), + values: Some(vec!["0".to_string()]), + or: None, + and: None, + }, + V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.customer_gender".to_string()), + operator: Some("set".to_string()), + values: None, + or: None, + and: None, + }, + V1LoadRequestQueryFilterItem { + member: Some("KibanaSampleDataEcommerce.count".to_string()), + operator: Some("gt".to_string()), + values: Some(vec!["0".to_string()]), + or: None, + and: None, + }, + ]), + ..Default::default() + } + ) + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs index 99cf86cfb7cfc..2174ad74da60c 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs @@ -37,7 +37,7 @@ use chrono::{ Numeric::{Day, Hour, Minute, Month, Second, Year}, Pad::Zero, }, - DateTime, Datelike, Days, Duration, Months, NaiveDateTime, Timelike, Weekday, + DateTime, Datelike, Days, Duration, Months, NaiveDate, NaiveDateTime, Timelike, Weekday, }; use cubeclient::models::V1CubeMeta; use datafusion::{ @@ -1685,8 +1685,9 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), filter_member("?member", "FilterMemberOp:inDateRange", "?values"), - self.transform_filter_extract_year_equals( + self.transform_filter_extract_year_month_equals( "?year", + None, "?column", "?alias_to_cube", "?members", @@ -1715,8 +1716,9 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), filter_member("?member", "FilterMemberOp:inDateRange", "?values"), - self.transform_filter_extract_year_equals( + self.transform_filter_extract_year_month_equals( "?year", + None, "?column", "?alias_to_cube", "?members", @@ -1725,6 +1727,125 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), ), + // TRUNC(EXTRACT(MONTH FROM "KibanaSampleDataEcommerce"."order_date")) = 3 + // AND TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) = 2019 + transforming_rewrite( + "extract-trunc-year-and-month-equals", + filter_replacer( + binary_expr( + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_string("month"), column_expr("?column")], + )], + ), + "=", + literal_expr("?month"), + ), + "AND", + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_string("year"), column_expr("?column")], + )], + ), + "=", + literal_expr("?year"), + ), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + self.transform_filter_extract_year_month_equals( + "?year", + Some("?month"), + "?column", + "?alias_to_cube", + "?members", + "?member", + "?values", + "?filter_aliases", + ), + ), + // When the filter set above is paired with other filters, it needs to be + // regrouped for the above rewrite rule to match + rewrite( + "extract-trunc-year-and-month-equals-regroup-binary", + filter_replacer( + binary_expr( + binary_expr( + "?expr", + "AND", + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_string("month"), column_expr("?column")], + )], + ), + "=", + literal_expr("?month"), + ), + ), + "AND", + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_string("year"), column_expr("?column")], + )], + ), + "=", + literal_expr("?year"), + ), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + filter_replacer( + binary_expr( + "?expr", + "AND", + binary_expr( + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_string("month"), column_expr("?column")], + )], + ), + "=", + literal_expr("?month"), + ), + "AND", + binary_expr( + self.fun_expr( + "Trunc", + vec![self.fun_expr( + "DatePart", + vec![literal_string("year"), column_expr("?column")], + )], + ), + "=", + literal_expr("?year"), + ), + ), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + ), transforming_rewrite( "filter-date-trunc-sub-leeq", filter_replacer( @@ -3576,9 +3697,10 @@ impl FilterRules { } } - fn transform_filter_extract_year_equals( + fn transform_filter_extract_year_month_equals( &self, year_var: &'static str, + month_var: Option<&'static str>, column_var: &'static str, alias_to_cube_var: &'static str, members_var: &'static str, @@ -3587,6 +3709,7 @@ impl FilterRules { filter_aliases_var: &'static str, ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { let year_var = var!(year_var); + let month_var = month_var.map(|var| var!(var)); let column_var = var!(column_var); let alias_to_cube_var = var!(alias_to_cube_var); let members_var = var!(members_var); @@ -3595,70 +3718,117 @@ impl FilterRules { let filter_aliases_var = var!(filter_aliases_var); let meta_context = self.meta_context.clone(); move |egraph, subst| { - let years: Vec = var_iter!(egraph[subst[year_var]], LiteralExprValue) - .cloned() - .collect(); - if years.is_empty() { - return false; - } - let aliases_es: Vec> = - var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) - .cloned() - .collect(); - for year in years { - for aliases in aliases_es.iter() { + let Some(year) = + var_iter!(egraph[subst[year_var]], LiteralExprValue).find_map(|year| { let year = match year { - ScalarValue::Int64(Some(year)) => year, - ScalarValue::Int32(Some(year)) => year as i64, - ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(&year) => { + ScalarValue::Int64(Some(year)) => *year, + ScalarValue::Int32(Some(year)) => *year as i64, + ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(year) => { year.round() as i64 } ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => { if let Ok(year) = year_str.parse::() { year } else { - continue; + return None; } } - _ => continue, + _ => return None, }; - if !(1000..=9999).contains(&year) { - continue; + return None; } + Some(year as i32) + }) + else { + return false; + }; - if let Some((member_name, cube)) = Self::filter_member_name( - egraph, - subst, - &meta_context, - alias_to_cube_var, - column_var, - members_var, - &aliases, - ) { - if !cube.contains_member(&member_name) { - continue; + let month = if let Some(month_var) = month_var { + let month = + var_iter!(egraph[subst[month_var]], LiteralExprValue).find_map(|month| { + let month = match month { + ScalarValue::Int64(Some(month)) => *month, + ScalarValue::Int32(Some(month)) => *month as i64, + ScalarValue::Float64(Some(month)) if (1.0..=12.0).contains(month) => { + month.round() as i64 + } + ScalarValue::Utf8(Some(ref month_str)) + if (1..=2).contains(&month_str.len()) => + { + if let Ok(month) = month_str.parse::() { + month + } else { + return None; + } + } + _ => return None, + }; + if !(1..=12).contains(&month) { + return None; } + Some(month as u32) + }); + if month.is_none() { + return false; + } + month + } else { + None + }; - subst.insert( - member_var, - egraph.add(LogicalPlanLanguage::FilterMemberMember( - FilterMemberMember(member_name.to_string()), - )), - ); - - subst.insert( - values_var, - egraph.add(LogicalPlanLanguage::FilterMemberValues( - FilterMemberValues(vec![ - format!("{}-01-01", year), - format!("{}-12-31", year), - ]), - )), - ); + let last_day = { + let month = month.unwrap_or(12); + let next_month = if month == 12 { 1 } else { month + 1 }; + let next_month_year = if month == 12 { year + 1 } else { year }; + let Some(next_month_first_date) = + NaiveDate::from_ymd_opt(next_month_year, next_month, 1) + else { + return false; + }; + let Some(last_day_date) = next_month_first_date.checked_sub_days(Days::new(1)) + else { + return false; + }; + last_day_date.day() + }; - return true; + let aliases_es: Vec> = + var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) + .cloned() + .collect(); + for aliases in aliases_es.iter() { + if let Some((member_name, cube)) = Self::filter_member_name( + egraph, + subst, + &meta_context, + alias_to_cube_var, + column_var, + members_var, + &aliases, + ) { + if !cube.contains_member(&member_name) { + continue; } + + subst.insert( + member_var, + egraph.add(LogicalPlanLanguage::FilterMemberMember(FilterMemberMember( + member_name.to_string(), + ))), + ); + + let date_range_start = format!("{}-{:0>2}-01", year, month.unwrap_or(1)); + let date_range_end = + format!("{}-{:0>2}-{}", year, month.unwrap_or(12), last_day); + subst.insert( + values_var, + egraph.add(LogicalPlanLanguage::FilterMemberValues(FilterMemberValues( + vec![date_range_start, date_range_end], + ))), + ); + + return true; } }