diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs index 444697b004283..acd20572df886 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs @@ -349,7 +349,14 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), change_user_member("?user"), - self.transform_change_user_eq("?column", "?literal", "?user"), + self.transform_change_user_eq( + "?column", + "?literal", + "?alias_to_cube", + "?members", + "?filter_aliases", + "?user", + ), ), transforming_rewrite( "change-user-equal-filter", @@ -360,7 +367,35 @@ impl RewriteRules for FilterRules { "?filter_aliases", ), change_user_member("?user"), - self.transform_change_user_eq("?column", "?literal", "?user"), + self.transform_change_user_eq( + "?column", + "?literal", + "?alias_to_cube", + "?members", + "?filter_aliases", + "?user", + ), + ), + transforming_rewrite( + "user-is-not-null-filter", + filter_replacer( + is_not_null_expr(column_expr("?column")), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + filter_replacer( + literal_bool(true), + "?alias_to_cube", + "?members", + "?filter_aliases", + ), + self.transform_user_is_not_null( + "?column", + "?alias_to_cube", + "?members", + "?filter_aliases", + ), ), transforming_rewrite( "join-field-filter-eq", @@ -3308,29 +3343,56 @@ impl FilterRules { &self, column_var: &'static str, literal_var: &'static str, + alias_to_cube_var: &'static str, + members_var: &'static str, + filter_aliases_var: &'static str, change_user_member_var: &'static str, ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { - let column_var = column_var.parse().unwrap(); - let literal_var = literal_var.parse().unwrap(); - let change_user_member_var = change_user_member_var.parse().unwrap(); - + let column_var = var!(column_var); + let literal_var = var!(literal_var); + let alias_to_cube_var = var!(alias_to_cube_var); + let members_var = var!(members_var); + let filter_aliases_var = var!(filter_aliases_var); + let change_user_member_var = var!(change_user_member_var); + let meta_context = self.meta_context.clone(); move |egraph, subst| { - for literal in var_iter!(egraph[subst[literal_var]], LiteralExprValue) { - if let ScalarValue::Utf8(Some(change_user)) = literal { - let specified_user = change_user.clone(); + let literals = var_iter!(egraph[subst[literal_var]], LiteralExprValue) + .cloned() + .collect::>(); + for literal in literals { + let ScalarValue::Utf8(Some(user_name)) = literal else { + continue; + }; - for column in var_iter!(egraph[subst[column_var]], ColumnExprColumn).cloned() { - if column.name.eq_ignore_ascii_case("__user") { - subst.insert( - change_user_member_var, - egraph.add(LogicalPlanLanguage::ChangeUserMemberValue( - ChangeUserMemberValue(specified_user), - )), - ); + let aliases_es = + var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) + .cloned() + .collect::>(); + for aliases in aliases_es { + let Some((member_name, cube)) = Self::filter_member_name( + egraph, + subst, + &meta_context, + alias_to_cube_var, + column_var, + members_var, + &aliases, + ) else { + continue; + }; - return true; - } + let user_member_name = format!("{}.__user", cube.name); + if !member_name.eq_ignore_ascii_case(&user_member_name) { + continue; } + + subst.insert( + change_user_member_var, + egraph.add(LogicalPlanLanguage::ChangeUserMemberValue( + ChangeUserMemberValue(user_name.clone()), + )), + ); + return true; } } @@ -3338,6 +3400,46 @@ impl FilterRules { } } + fn transform_user_is_not_null( + &self, + column_var: &'static str, + alias_to_cube_var: &'static str, + members_var: &'static str, + filter_aliases_var: &'static str, + ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { + let column_var = var!(column_var); + let alias_to_cube_var = var!(alias_to_cube_var); + let members_var = var!(members_var); + let filter_aliases_var = var!(filter_aliases_var); + let meta_context = self.meta_context.clone(); + move |egraph, subst| { + let aliases_es = var_iter!(egraph[subst[filter_aliases_var]], FilterReplacerAliases) + .cloned() + .collect::>(); + for aliases in aliases_es { + let Some((member_name, cube)) = Self::filter_member_name( + egraph, + subst, + &meta_context, + alias_to_cube_var, + column_var, + members_var, + &aliases, + ) else { + continue; + }; + + let user_member_name = format!("{}.__user", cube.name); + if !member_name.eq_ignore_ascii_case(&user_member_name) { + continue; + } + + return true; + } + false + } + } + // Transform ?expr IN (?literal) to ?expr = ?literal fn transform_filter_in_to_equal( &self, diff --git a/rust/cubesql/cubesql/src/compile/test/test_user_change.rs b/rust/cubesql/cubesql/src/compile/test/test_user_change.rs index f59c8b2ff4cb1..9c475ec4961f9 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_user_change.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_user_change.rs @@ -110,6 +110,36 @@ async fn test_change_user_via_in_filter_thoughtspot() { assert_eq!(cube_scan.request, expected_request); } +#[tokio::test] +async fn test_change_user_via_filter_powerbi() { + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + "SELECT COUNT(*) as cnt FROM KibanaSampleDataEcommerce WHERE NOT __user IS NULL AND __user = 'gopher'".to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let cube_scan = query_plan.as_logical_plan().find_cube_scan(); + + assert_eq!(cube_scan.options.change_user, Some("gopher".to_string())); + + assert_eq!( + cube_scan.request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]), + segments: Some(vec![]), + dimensions: Some(vec![]), + time_dimensions: None, + order: Some(vec![]), + limit: None, + offset: None, + filters: None, + ungrouped: None, + } + ) +} + #[tokio::test] async fn test_change_user_via_filter_and() { let query_plan = convert_select_to_query_plan( @@ -192,6 +222,47 @@ async fn test_user_with_join() { assert_eq!(cube_scan.options.change_user, Some("foo".to_string())) } +#[tokio::test] +async fn test_change_user_via_filter_with_alias() { + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + r#" + SELECT "k"."cnt" AS "cnt" + FROM ( + SELECT + COUNT(*) AS "cnt", + "__user" AS "user" + FROM "KibanaSampleDataEcommerce" + GROUP BY 2 + ) AS "k" + WHERE "k"."user" = 'gopher' + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let cube_scan = query_plan.as_logical_plan().find_cube_scan(); + + assert_eq!(cube_scan.options.change_user, Some("gopher".to_string())); + + assert_eq!( + cube_scan.request, + V1LoadRequestQuery { + measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string(),]), + segments: Some(vec![]), + dimensions: Some(vec![]), + time_dimensions: None, + order: Some(vec![]), + limit: None, + offset: None, + filters: None, + ungrouped: None, + } + ) +} + /// This should test that query with CubeScanWrapper uses proper change_user for both SQL generation and execution calls #[tokio::test] async fn test_user_change_sql_generation() {