Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 43 additions & 26 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
prelude::JoinType,
};
use egg::{Id, Subst};
use itertools::Itertools;

impl WrapperRules {
pub fn join_rules(&self, rules: &mut Vec<CubeRewrite>) {
Expand Down Expand Up @@ -831,50 +832,66 @@
// * Both inputs depend on a single data source
// * SQL generator for that data source has an `expressions/subquery` template
// It could be checked later, in WrappedSelect as well

let left_columns = egraph[subst[left_expr_var]].data.referenced_expr.as_ref();
let Some(left_columns) = left_columns else {
return false;
// TODO For views: check that each member is coming from same data source (or even cube?)

let prepare_columns = |var| {
let columns = egraph[subst[var]].data.referenced_expr.as_ref();
let Some(columns) = columns else {
return Err("Missing referenced_expr");

Check warning on line 840 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs#L840

Added line #L840 was not covered by tests
};
let columns = columns
.iter()
.map(|column| {
let column = match column {
Expr::Column(column) => column.clone(),
_ => return Err("Unexpected expression in referenced_expr"),

Check warning on line 847 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs#L847

Added line #L847 was not covered by tests
};
Ok(column)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(columns)
};
if left_columns.len() != 1 {
return false;

fn prepare_relation(columns: &[Column]) -> Result<&str, &'static str> {
let relation = columns
.iter()
.map(|column| &column.relation)
.all_equal_value();
let Ok(Some(relation)) = relation else {
// Err means there are either no values at all, or more than one distinct value
// Non-column expressions cannot reach this point: the input is already a slice of `Column`
// Ok(None) means that all columns are without relation, which is not supported at the moment
return Err("Relation mismatch");

Check warning on line 864 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/join.rs#L864

Added line #L864 was not covered by tests
};
Ok(relation)
}
let left_column = &left_columns[0];

let right_columns = egraph[subst[right_expr_var]].data.referenced_expr.as_ref();
let Some(right_columns) = right_columns else {
let Ok(left_columns) = prepare_columns(left_expr_var) else {
return false;
};
if right_columns.len() != 1 {
let Ok(left_relation) = prepare_relation(&left_columns) else {
return false;
}
let right_column = &right_columns[0];

let left_column = match left_column {
Expr::Column(column) => column,
_ => return false,
};
let right_column = match right_column {
Expr::Column(column) => column,
_ => return false,
};

// Simple check that column expressions reference different join sides
let Some(left_relation) = left_column.relation.as_ref() else {
let Ok(right_columns) = prepare_columns(right_expr_var) else {
return false;
};
let Some(right_relation) = right_column.relation.as_ref() else {
let Ok(right_relation) = prepare_relation(&right_columns) else {
return false;
};

// Simple check that column expressions reference different join sides
if left_relation == right_relation {
return false;
}

let left_column = left_column.clone();

// Don't check right, as it is already grouped

if !Self::are_join_members_supported(egraph, subst[left_members_var], [&left_column]) {
if !Self::are_join_members_supported(
egraph,
subst[left_members_var],
left_columns.iter(),
) {
return false;
}

Expand Down
77 changes: 77 additions & 0 deletions rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,3 +766,80 @@ GROUP BY
.sql
.contains(r#"\"expr\":\"${KibanaSampleDataEcommerce.sumPrice}\""#));
}

/// Joins `MultiTypeCube` with a grouped subquery using an `ON` expression that is built
/// from two columns (`dim_str0` and `dim_str1`), and checks that the resulting wrapped SQL
/// request carries exactly one `INNER` subquery join referencing both dimensions.
#[tokio::test]
async fn test_join_on_multiple_columns() {
    init_testing_logger();

    // language=PostgreSQL
    let sql = r#"
SELECT
CAST(dim_str0 AS TEXT) || ' - ' || CAST(dim_str1 AS TEXT) AS "concat_dims"
FROM MultiTypeCube
INNER JOIN (
SELECT
CAST(dim_str0 AS TEXT) || ' - ' || CAST(dim_str1 AS TEXT) AS "concat_dims",
AVG(avgPrice) AS "avg_price"
FROM MultiTypeCube
GROUP BY
1
ORDER BY
2 DESC NULLS LAST,
1 ASC NULLS FIRST
LIMIT 10
) "grouped"
ON
CAST(MultiTypeCube.dim_str0 AS TEXT) || ' - ' || CAST(MultiTypeCube.dim_str1 AS TEXT)
=
"grouped"."concat_dims"
GROUP BY
1
;
"#
    .to_string();

    let query_plan = convert_select_to_query_plan(sql, DatabaseProtocol::PostgreSQL).await;

    // Building the physical plan must succeed; it is printed to ease debugging of failures.
    let physical_plan = query_plan.as_physical_plan().await.unwrap();
    let rendered_plan = displayable(physical_plan.as_ref()).indent();
    println!("Physical plan: {}", rendered_plan);

    // Look the wrapped SQL node up once and reuse it for every assertion below.
    let wrapped_node = query_plan.as_logical_plan().find_cube_scan_wrapped_sql();
    let request = &wrapped_node.request;

    assert_eq!(request.ungrouped, None);

    // The grouped side has to be represented as exactly one subquery join.
    let subquery_joins = request.subquery_joins.as_ref().unwrap();
    assert_eq!(subquery_joins.len(), 1);

    let subquery = &subquery_joins[0];
    assert!(!subquery.sql.contains("ungrouped"));
    assert_eq!(subquery.join_type, "INNER");

    // Both columns of the ungrouped side must show up in the join condition.
    let join_condition = &subquery.on;
    assert!(join_condition.contains(r#"CAST(${MultiTypeCube.dim_str0} AS STRING)"#));
    assert!(join_condition.contains(r#"CAST(${MultiTypeCube.dim_str1} AS STRING)"#));
    assert!(join_condition.contains(r#" = \"grouped\".\"concat_dims\""#));

    // Dimension from ungrouped side
    let outer_sql = &wrapped_node.wrapped_sql.sql;
    assert!(outer_sql.contains(r#"CAST(${MultiTypeCube.dim_str0} AS STRING)"#));
    assert!(outer_sql.contains(r#"CAST(${MultiTypeCube.dim_str1} AS STRING)"#));
}
Loading