diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 3f7e0e037cdf3..7aa24fbdc107e 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -9482,7 +9482,105 @@ ORDER BY "source"."str0" ASC } #[tokio::test] - async fn test_tableau_filter_by_year() { + async fn test_filter_date_part_by_year() { + init_testing_logger(); + + fn assert_expected_result(query_plan: QueryPlan) { + assert_eq!( + query_plan.as_logical_plan().find_cube_scan().request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]), + dimensions: Some(vec![]), + segments: Some(vec![]), + time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension { + dimension: "KibanaSampleDataEcommerce.order_date".to_string(), + granularity: Some("year".to_string()), + date_range: Some(json!(vec![ + "2019-01-01".to_string(), + "2019-12-31".to_string(), + ])), + },]), + order: Some(vec![]), + ..Default::default() + } + ) + } + + assert_expected_result( + convert_select_to_query_plan( + r#" + SELECT + COUNT(*) AS "count", + date_part('YEAR', "KibanaSampleDataEcommerce"."order_date") AS "yr:completedAt:ok" + FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce" + WHERE date_part('YEAR', "KibanaSampleDataEcommerce"."order_date") = 2019 + GROUP BY 2 + ;"# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await, + ); + + // Same as above, but with string literal. + assert_expected_result( + convert_select_to_query_plan( + r#" + SELECT + COUNT(*) AS "count", + date_part('YEAR', "KibanaSampleDataEcommerce"."order_date") AS "yr:completedAt:ok" + FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce" + WHERE date_part('YEAR', "KibanaSampleDataEcommerce"."order_date") = '2019' + GROUP BY 2 + ;"# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await, + ) + } + + #[tokio::test] + async fn test_filter_extract_by_year() { + init_testing_logger(); + + let logical_plan = convert_select_to_query_plan( + r#" + SELECT + COUNT(*) AS "count", + EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date") AS "yr:completedAt:ok" + FROM "public"."KibanaSampleDataEcommerce" "KibanaSampleDataEcommerce" + WHERE EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date") = 2019 + GROUP BY 2 + ;"# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await + .as_logical_plan(); + + assert_eq!( + logical_plan.find_cube_scan().request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]), + dimensions: Some(vec![]), + segments: Some(vec![]), + time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension { + dimension: "KibanaSampleDataEcommerce.order_date".to_string(), + granularity: Some("year".to_string()), + date_range: Some(json!(vec![ + "2019-01-01".to_string(), + "2019-12-31".to_string(), + ])), + },]), + order: Some(vec![]), + ..Default::default() + } + ) + } + + #[tokio::test] + async fn test_tableau_filter_extract_by_year() { init_testing_logger(); let logical_plan = convert_select_to_query_plan( diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs index f1525c66b62de..1932938040e61 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs @@ -1668,8 +1668,64 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), ), + // DATE_PART('year', "KibanaSampleDataEcommerce"."order_date") = 2019 transforming_rewrite( "extract-year-equals", + filter_replacer( + binary_expr( + self.fun_expr( + "DatePart", + vec![literal_string("YEAR"), column_expr("?column")], + ), + "=", + literal_expr("?year"), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + self.transform_filter_extract_year_equals( + "?year", + "?column", + "?alias_to_cube", + "?members", + "?member", + "?values", + "?filter_aliases", + ), + ), + // Same as the rule above, but it uses different case for granularity. + // TODO: Remove, whenever we will fix bug with granularity cases. CORE-1761 + transforming_rewrite( + "extract-year-equals-lower-case", + filter_replacer( + binary_expr( + self.fun_expr( + "DatePart", + vec![literal_string("year"), column_expr("?column")], + ), + "=", + literal_expr("?year"), + ), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + filter_member("?member", "FilterMemberOp:inDateRange", "?values"), + self.transform_filter_extract_year_equals( + "?year", + "?column", + "?alias_to_cube", + "?members", + "?member", + "?values", + "?filter_aliases", + ), + ), + // TRUNC(EXTRACT(YEAR FROM "KibanaSampleDataEcommerce"."order_date")) = 2019 + transforming_rewrite( + "extract-trunc-year-equals", filter_replacer( binary_expr( self.fun_expr( @@ -3579,43 +3635,54 @@ impl FilterRules { .collect(); for year in years { for aliases in aliases_es.iter() { - if let ScalarValue::Int64(Some(year)) = year { - if !(1000..=9999).contains(&year) { - continue; - } - - if let Some((member_name, cube)) = Self::filter_member_name( - egraph, - subst, - &meta_context, - alias_to_cube_var, - column_var, - members_var, - &aliases, - ) { - if !cube.contains_member(&member_name) { + let year = match year { + ScalarValue::Int64(Some(year)) => year, + ScalarValue::Int32(Some(year)) => year as i64, + ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => { + if let Ok(year) = year_str.parse::() { + year + } else { continue; } + } + _ => continue, + }; - subst.insert( - member_var, - egraph.add(LogicalPlanLanguage::FilterMemberMember( - FilterMemberMember(member_name.to_string()), - )), - ); - - subst.insert( - values_var, - egraph.add(LogicalPlanLanguage::FilterMemberValues( - FilterMemberValues(vec![ - format!("{}-01-01", year), - format!("{}-12-31", year), - ]), - )), - ); + if !(1000..=9999).contains(&year) { + continue; + } - return true; + if let Some((member_name, cube)) = Self::filter_member_name( + egraph, + subst, + &meta_context, + alias_to_cube_var, + column_var, + members_var, + &aliases, + ) { + if !cube.contains_member(&member_name) { + continue; } + + subst.insert( + member_var, + egraph.add(LogicalPlanLanguage::FilterMemberMember( + FilterMemberMember(member_name.to_string()), + )), + ); + + subst.insert( + values_var, + egraph.add(LogicalPlanLanguage::FilterMemberValues( + FilterMemberValues(vec![ + format!("{}-01-01", year), + format!("{}-12-31", year), + ]), + )), + ); + + return true; } } }