From 6dbf52bc6489ce3c068e3c353c3a1732d0337656 Mon Sep 17 00:00:00 2001 From: Mikhail Cheshkov Date: Thu, 27 Feb 2025 14:28:41 +0200 Subject: [PATCH] feat(cubesql): Support multiple columns on each side in ungrouped-grouped join condition --- .../src/compile/rewrite/rules/wrapper/join.rs | 69 ++++++++++------- .../compile/test/test_cube_join_grouped.rs | 77 +++++++++++++++++++ 2 files changed, 120 insertions(+), 26 deletions(-) diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs index eab7a0d1a4065..1d518814964ae 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs @@ -22,6 +22,7 @@ use datafusion::{ prelude::JoinType, }; use egg::{Id, Subst}; +use itertools::Itertools; impl WrapperRules { pub fn join_rules(&self, rules: &mut Vec) { @@ -831,50 +832,66 @@ impl WrapperRules { // * Both inputs depend on a single data source // * SQL generator for that data source have `expressions/subquery` template // It could be checked later, in WrappedSelect as well - - let left_columns = egraph[subst[left_expr_var]].data.referenced_expr.as_ref(); - let Some(left_columns) = left_columns else { - return false; + // TODO For views: check that each member is coming from same data source (or even cube?) + + let prepare_columns = |var| { + let columns = egraph[subst[var]].data.referenced_expr.as_ref(); + let Some(columns) = columns else { + return Err("Missing referenced_expr"); + }; + let columns = columns + .iter() + .map(|column| { + let column = match column { + Expr::Column(column) => column.clone(), + _ => return Err("Unexpected expression in referenced_expr"), + }; + Ok(column) + }) + .collect::, _>>()?; + Ok(columns) }; - if left_columns.len() != 1 { - return false; + + fn prepare_relation(columns: &[Column]) -> Result<&str, &'static str> { + let relation = columns + .iter() + .map(|column| &column.relation) + .all_equal_value(); + let Ok(Some(relation)) = relation else { + // Outer Err means there's either no values at all, or more than one different value + // Inner Err means that all referenced_expr are not columns + // Inner None means that all columns are without relation, don't support that ATM + return Err("Relation mismatch"); + }; + Ok(relation) } - let left_column = &left_columns[0]; - let right_columns = egraph[subst[right_expr_var]].data.referenced_expr.as_ref(); - let Some(right_columns) = right_columns else { + let Ok(left_columns) = prepare_columns(left_expr_var) else { return false; }; - if right_columns.len() != 1 { + let Ok(left_relation) = prepare_relation(&left_columns) else { return false; - } - let right_column = &right_columns[0]; - - let left_column = match left_column { - Expr::Column(column) => column, - _ => return false, - }; - let right_column = match right_column { - Expr::Column(column) => column, - _ => return false, }; - // Simple check that column expressions reference different join sides - let Some(left_relation) = left_column.relation.as_ref() else { + let Ok(right_columns) = prepare_columns(right_expr_var) else { return false; }; - let Some(right_relation) = right_column.relation.as_ref() else { + let Ok(right_relation) = prepare_relation(&right_columns) else { return false; }; + + // Simple check that column expressions reference different join sides if left_relation == right_relation { return false; } - let left_column = left_column.clone(); - // Don't check right, as it is already grouped - if !Self::are_join_members_supported(egraph, subst[left_members_var], [&left_column]) { + if !Self::are_join_members_supported( + egraph, + subst[left_members_var], + left_columns.iter(), + ) { return false; } diff --git a/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs b/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs index 8595774977059..4f04f3f9483bb 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs @@ -766,3 +766,80 @@ GROUP BY .sql .contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.sumPrice}\""#)); } + +#[tokio::test] +async fn test_join_on_multiple_columns() { + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + // language=PostgreSQL + r#" +SELECT + CAST(dim_str0 AS TEXT) || ' - ' || CAST(dim_str1 AS TEXT) AS "concat_dims" +FROM MultiTypeCube +INNER JOIN ( + SELECT + CAST(dim_str0 AS TEXT) || ' - ' || CAST(dim_str1 AS TEXT) AS "concat_dims", + AVG(avgPrice) AS "avg_price" + FROM MultiTypeCube + GROUP BY + 1 + ORDER BY + 2 DESC NULLS LAST, + 1 ASC NULLS FIRST + LIMIT 10 +) "grouped" +ON + CAST(MultiTypeCube.dim_str0 AS TEXT) || ' - ' || CAST(MultiTypeCube.dim_str1 AS TEXT) + = + "grouped"."concat_dims" +GROUP BY + 1 +; + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + + let request = query_plan + .as_logical_plan() + .find_cube_scan_wrapped_sql() + .request; + + assert_eq!(request.ungrouped, None); + + assert_eq!(request.subquery_joins.as_ref().unwrap().len(), 1); + + let subquery = &request.subquery_joins.unwrap()[0]; + + assert!(!subquery.sql.contains("ungrouped")); + assert_eq!(subquery.join_type, "INNER"); + assert!(subquery + .on + .contains(r#"CAST(${MultiTypeCube.dim_str0} AS STRING)"#)); + assert!(subquery + .on + .contains(r#"CAST(${MultiTypeCube.dim_str1} AS STRING)"#)); + assert!(subquery.on.contains(r#" = \"grouped\".\"concat_dims\""#)); + + // Dimension from ungrouped side + assert!(query_plan + .as_logical_plan() + .find_cube_scan_wrapped_sql() + .wrapped_sql + .sql + .contains(r#"CAST(${MultiTypeCube.dim_str0} AS STRING)"#)); + assert!(query_plan + .as_logical_plan() + .find_cube_scan_wrapped_sql() + .wrapped_sql + .sql + .contains(r#"CAST(${MultiTypeCube.dim_str1} AS STRING)"#)); +}