Skip to content

Commit 6dbf52b

Browse files
committed
feat(cubesql): Support multiple columns on each side in ungrouped-grouped join condition
1 parent 9ec4240 commit 6dbf52b

File tree

2 files changed

+120
-26
lines changed

2 files changed

+120
-26
lines changed

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use datafusion::{
2222
prelude::JoinType,
2323
};
2424
use egg::{Id, Subst};
25+
use itertools::Itertools;
2526

2627
impl WrapperRules {
2728
pub fn join_rules(&self, rules: &mut Vec<CubeRewrite>) {
@@ -831,50 +832,66 @@ impl WrapperRules {
831832
// * Both inputs depend on a single data source
832833
// * SQL generator for that data source have `expressions/subquery` template
833834
// It could be checked later, in WrappedSelect as well
834-
835-
let left_columns = egraph[subst[left_expr_var]].data.referenced_expr.as_ref();
836-
let Some(left_columns) = left_columns else {
837-
return false;
835+
// TODO For views: check that each member is coming from same data source (or even cube?)
836+
837+
let prepare_columns = |var| {
838+
let columns = egraph[subst[var]].data.referenced_expr.as_ref();
839+
let Some(columns) = columns else {
840+
return Err("Missing referenced_expr");
841+
};
842+
let columns = columns
843+
.iter()
844+
.map(|column| {
845+
let column = match column {
846+
Expr::Column(column) => column.clone(),
847+
_ => return Err("Unexpected expression in referenced_expr"),
848+
};
849+
Ok(column)
850+
})
851+
.collect::<Result<Vec<_>, _>>()?;
852+
Ok(columns)
838853
};
839-
if left_columns.len() != 1 {
840-
return false;
854+
855+
fn prepare_relation(columns: &[Column]) -> Result<&str, &'static str> {
856+
let relation = columns
857+
.iter()
858+
.map(|column| &column.relation)
859+
.all_equal_value();
860+
let Ok(Some(relation)) = relation else {
861+
// Outer Err means there's either no values at all, or more than one different value
862+
// Inner Err means that all referenced_expr are not columns
863+
// Inner None means that all columns are without relation, don't support that ATM
864+
return Err("Relation mismatch");
865+
};
866+
Ok(relation)
841867
}
842-
let left_column = &left_columns[0];
843868

844-
let right_columns = egraph[subst[right_expr_var]].data.referenced_expr.as_ref();
845-
let Some(right_columns) = right_columns else {
869+
let Ok(left_columns) = prepare_columns(left_expr_var) else {
846870
return false;
847871
};
848-
if right_columns.len() != 1 {
872+
let Ok(left_relation) = prepare_relation(&left_columns) else {
849873
return false;
850-
}
851-
let right_column = &right_columns[0];
852-
853-
let left_column = match left_column {
854-
Expr::Column(column) => column,
855-
_ => return false,
856-
};
857-
let right_column = match right_column {
858-
Expr::Column(column) => column,
859-
_ => return false,
860874
};
861875

862-
// Simple check that column expressions reference different join sides
863-
let Some(left_relation) = left_column.relation.as_ref() else {
876+
let Ok(right_columns) = prepare_columns(right_expr_var) else {
864877
return false;
865878
};
866-
let Some(right_relation) = right_column.relation.as_ref() else {
879+
let Ok(right_relation) = prepare_relation(&right_columns) else {
867880
return false;
868881
};
882+
883+
// Simple check that column expressions reference different join sides
869884
if left_relation == right_relation {
870885
return false;
871886
}
872887

873-
let left_column = left_column.clone();
874-
875888
// Don't check right, as it is already grouped
876889

877-
if !Self::are_join_members_supported(egraph, subst[left_members_var], [&left_column]) {
890+
if !Self::are_join_members_supported(
891+
egraph,
892+
subst[left_members_var],
893+
left_columns.iter(),
894+
) {
878895
return false;
879896
}
880897

rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,3 +766,80 @@ GROUP BY
766766
.sql
767767
.contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.sumPrice}\""#));
768768
}
769+
770+
#[tokio::test]
771+
async fn test_join_on_multiple_columns() {
772+
init_testing_logger();
773+
774+
let query_plan = convert_select_to_query_plan(
775+
// language=PostgreSQL
776+
r#"
777+
SELECT
778+
CAST(dim_str0 AS TEXT) || ' - ' || CAST(dim_str1 AS TEXT) AS "concat_dims"
779+
FROM MultiTypeCube
780+
INNER JOIN (
781+
SELECT
782+
CAST(dim_str0 AS TEXT) || ' - ' || CAST(dim_str1 AS TEXT) AS "concat_dims",
783+
AVG(avgPrice) AS "avg_price"
784+
FROM MultiTypeCube
785+
GROUP BY
786+
1
787+
ORDER BY
788+
2 DESC NULLS LAST,
789+
1 ASC NULLS FIRST
790+
LIMIT 10
791+
) "grouped"
792+
ON
793+
CAST(MultiTypeCube.dim_str0 AS TEXT) || ' - ' || CAST(MultiTypeCube.dim_str1 AS TEXT)
794+
=
795+
"grouped"."concat_dims"
796+
GROUP BY
797+
1
798+
;
799+
"#
800+
.to_string(),
801+
DatabaseProtocol::PostgreSQL,
802+
)
803+
.await;
804+
805+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
806+
println!(
807+
"Physical plan: {}",
808+
displayable(physical_plan.as_ref()).indent()
809+
);
810+
811+
let request = query_plan
812+
.as_logical_plan()
813+
.find_cube_scan_wrapped_sql()
814+
.request;
815+
816+
assert_eq!(request.ungrouped, None);
817+
818+
assert_eq!(request.subquery_joins.as_ref().unwrap().len(), 1);
819+
820+
let subquery = &request.subquery_joins.unwrap()[0];
821+
822+
assert!(!subquery.sql.contains("ungrouped"));
823+
assert_eq!(subquery.join_type, "INNER");
824+
assert!(subquery
825+
.on
826+
.contains(r#"CAST(${MultiTypeCube.dim_str0} AS STRING)"#));
827+
assert!(subquery
828+
.on
829+
.contains(r#"CAST(${MultiTypeCube.dim_str1} AS STRING)"#));
830+
assert!(subquery.on.contains(r#" = \"grouped\".\"concat_dims\""#));
831+
832+
// Dimension from ungrouped side
833+
assert!(query_plan
834+
.as_logical_plan()
835+
.find_cube_scan_wrapped_sql()
836+
.wrapped_sql
837+
.sql
838+
.contains(r#"CAST(${MultiTypeCube.dim_str0} AS STRING)"#));
839+
assert!(query_plan
840+
.as_logical_plan()
841+
.find_cube_scan_wrapped_sql()
842+
.wrapped_sql
843+
.sql
844+
.contains(r#"CAST(${MultiTypeCube.dim_str1} AS STRING)"#));
845+
}

0 commit comments

Comments
 (0)