Skip to content

Commit 3297d5d

Browse files
committed
[WIP] fix(cubesql): Make cube join check stricter
Now it should catch plans like Join(CubeScan, Projection(CubeScan))
1 parent fd4ccff commit 3297d5d

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

rust/cubesql/cubesql/src/compile/rewrite/converter.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,10 +1350,10 @@ impl LanguageToLogicalPlanConverter {
13501350
LogicalPlanLanguage::Join(params) => {
13511351
let left_on = match_data_node!(node_by_id, params[2], JoinLeftOn);
13521352
let right_on = match_data_node!(node_by_id, params[3], JoinRightOn);
1353-
let left = self.to_logical_plan(params[0]);
1354-
let right = self.to_logical_plan(params[1]);
1353+
let left = self.to_logical_plan(params[0])?;
1354+
let right = self.to_logical_plan(params[1])?;
13551355

1356-
if self.is_cube_scan_node(params[0]) && self.is_cube_scan_node(params[1]) {
1356+
if Self::have_cube_scan_inside(&left) && Self::have_cube_scan_inside(&right) {
13571357
if left_on.iter().any(|c| c.name == "__cubeJoinField")
13581358
|| right_on.iter().any(|c| c.name == "__cubeJoinField")
13591359
{
@@ -1370,8 +1370,8 @@ impl LanguageToLogicalPlanConverter {
13701370
}
13711371
}
13721372

1373-
let left = Arc::new(left?);
1374-
let right = Arc::new(right?);
1373+
let left = Arc::new(left);
1374+
let right = Arc::new(right);
13751375

13761376
let join_type = match_data_node!(node_by_id, params[4], JoinJoinType);
13771377
let join_constraint = match_data_node!(node_by_id, params[5], JoinJoinConstraint);
@@ -1394,7 +1394,10 @@ impl LanguageToLogicalPlanConverter {
13941394
})
13951395
}
13961396
LogicalPlanLanguage::CrossJoin(params) => {
1397-
if self.is_cube_scan_node(params[0]) && self.is_cube_scan_node(params[1]) {
1397+
let left = self.to_logical_plan(params[0])?;
1398+
let right = self.to_logical_plan(params[1])?;
1399+
1400+
if Self::have_cube_scan_inside(&left) && Self::have_cube_scan_inside(&right) {
13981401
return Err(CubeError::internal(
13991402
"Can not join Cubes. This is most likely due to one of the following reasons:\n\
14001403
• one of the cubes contains a group by\n\
@@ -1403,8 +1406,8 @@ impl LanguageToLogicalPlanConverter {
14031406
));
14041407
}
14051408

1406-
let left = Arc::new(self.to_logical_plan(params[0])?);
1407-
let right = Arc::new(self.to_logical_plan(params[1])?);
1409+
let left = Arc::new(left);
1410+
let right = Arc::new(right);
14081411
let schema = Arc::new(left.schema().join(right.schema())?);
14091412

14101413
LogicalPlan::CrossJoin(CrossJoin {
@@ -2287,16 +2290,18 @@ impl LanguageToLogicalPlanConverter {
22872290
})
22882291
}
22892292

2290-
fn is_cube_scan_node(&self, node_id: Id) -> bool {
2291-
let node_by_id = &self.best_expr;
2292-
match node_by_id.index(node_id) {
2293-
LogicalPlanLanguage::CubeScan(_) | LogicalPlanLanguage::CubeScanWrapper(_) => {
2294-
return true
2295-
}
2296-
_ => (),
2293+
fn have_cube_scan_inside(node: &LogicalPlan) -> bool {
2294+
match node {
2295+
LogicalPlan::Projection(Projection { input, .. })
2296+
| LogicalPlan::Aggregate(Aggregate { input, .. })
2297+
| LogicalPlan::Filter(Filter { input, .. })
2298+
| LogicalPlan::Sort(Sort { input, .. })
2299+
| LogicalPlan::Limit(Limit { input, .. }) => Self::have_cube_scan_inside(input),
2300+
LogicalPlan::Extension(Extension { node }) => {
2301+
node.as_any().is::<CubeScanNode>() || node.as_any().is::<CubeScanWrapperNode>()
2302+
}
2303+
_ => false,
22972304
}
2298-
2299-
return false;
23002305
}
23012306
}
23022307

rust/cubesql/cubesql/src/compile/test/test_cube_join.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,8 @@ async fn test_join_cubes_on_wrong_field_error() {
497497
let query = convert_sql_to_cube_query(
498498
&r#"
499499
SELECT *
500-
FROM KibanaSampleDataEcommerce
501-
LEFT JOIN Logs ON (KibanaSampleDataEcommerce.has_subscription = Logs.read)
500+
FROM (SELECT customer_gender, has_subscription FROM KibanaSampleDataEcommerce) kibana
501+
LEFT JOIN (SELECT read, content FROM Logs) logs ON (kibana.has_subscription = logs.read)
502502
"#
503503
.to_string(),
504504
meta.clone(),
@@ -567,6 +567,7 @@ async fn test_join_cubes_with_aggr_error() {
567567
)
568568
}
569569

570+
// TODO it seems this query should not execute: it has join of grouped CubeScan with ungrouped CubeScan by __cubeJoinField
570571
#[tokio::test]
571572
async fn test_join_cubes_with_postprocessing() {
572573
if !Rewriter::sql_push_down_enabled() {
@@ -621,6 +622,7 @@ async fn test_join_cubes_with_postprocessing() {
621622
)
622623
}
623624

625+
// TODO it seems this query should not execute: it has join of grouped CubeScan with ungrouped CubeScan, and we explicitly try to forbid that
624626
#[tokio::test]
625627
async fn test_join_cubes_with_postprocessing_and_no_cubejoinfield() {
626628
if !Rewriter::sql_push_down_enabled() {

0 commit comments

Comments
 (0)