Skip to content

Commit 28c1e3b

Browse files
authored
feat: Support complex join conditions for grouped joins (#9157)
Add support for complex join conditions for grouped joins. DataFusion plans non-trivial joins (ones that are not `l.column = r.column`) as `Filter(CrossJoin(...))`. To support ungrouped-grouped joins with queries like this, SQL API needs to rewrite such a logical plan into a `WrappedSelect` with the join inside. To do that, it needs to distinguish between a plan coming from a regular `JOIN` and an actual `CROSS JOIN` with `WHERE` on top. This is done with the new `JoinCheckStage`: it starts on `Filter(CrossJoin(wrapper, wrapper))`, traverses all `AND`s in the filter condition, checks that the "leaves" of the condition compare the two join sides, and pulls up that fact. After that, the regular join rewrite can start on the checked condition. Supporting changes: * Allow grouped join sides to have different in_projection flags * Allow non-push_to_cube WrappedSelect in grouped subquery position in join * Make zero members wrapper more expensive than filter member * Replace alias to cube during wrapper pull up * Wrap is_null expressions in parens to avoid operator precedence issues: an expression like `(foo IS NOT NULL = bar IS NOT NULL)` would try to compare `foo IS NOT NULL` with `bar`, not with `bar IS NOT NULL`
1 parent 4e9ed0e commit 28c1e3b

File tree

13 files changed

+1016
-40
lines changed

13 files changed

+1016
-40
lines changed

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3327,7 +3327,7 @@ export class BaseQuery {
33273327
column_aliased: '{{expr}} {{quoted_alias}}',
33283328
query_aliased: '{{ query }} AS {{ quoted_alias }}',
33293329
case: 'CASE{% if expr %} {{ expr }}{% endif %}{% for when, then in when_then %} WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END',
3330-
is_null: '{{ expr }} IS {% if negate %}NOT {% endif %}NULL',
3330+
is_null: '({{ expr }} IS {% if negate %}NOT {% endif %}NULL)',
33313331
binary: '({{ left }} {{ op }} {{ right }})',
33323332
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %} NULLS {% if nulls_first %}FIRST{% else %}LAST{% endif %}',
33333333
order_by: '{% if index %} {{ index }} {% else %} {{ expr }} {% endif %} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',

packages/cubejs-testing/test/__snapshots__/smoke-cubesql.test.ts.snap

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ Array [
7070
]
7171
`;
7272

73+
exports[`SQL API Postgres (Data) join with grouped query on coalesce: join grouped on coalesce 1`] = `
74+
Array [
75+
Object {
76+
"count": "2",
77+
"status": "processed",
78+
},
79+
Object {
80+
"count": "1",
81+
"status": "shipped",
82+
},
83+
]
84+
`;
85+
7386
exports[`SQL API Postgres (Data) join with grouped query: join grouped 1`] = `
7487
Array [
7588
Object {

packages/cubejs-testing/test/smoke-cubesql.test.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,36 @@ filter_subq AS (
535535
expect(res.rows).toMatchSnapshot('join grouped with filter');
536536
});
537537

538+
test('join with grouped query on coalesce', async () => {
539+
const query = `
540+
SELECT
541+
"Orders".status AS status,
542+
COUNT(*) AS count
543+
FROM
544+
"Orders"
545+
INNER JOIN
546+
(
547+
SELECT
548+
status,
549+
SUM(totalAmount)
550+
FROM
551+
"Orders"
552+
GROUP BY 1
553+
ORDER BY 2 DESC
554+
LIMIT 2
555+
) top_orders
556+
ON
557+
(COALESCE("Orders".status, '') = COALESCE(top_orders.status, '')) AND
558+
(("Orders".status IS NOT NULL) = (top_orders.status IS NOT NULL))
559+
GROUP BY 1
560+
ORDER BY 1
561+
`;
562+
563+
const res = await connection.query(query);
564+
// Expect only the top 2 statuses by total amount: processed and shipped
565+
expect(res.rows).toMatchSnapshot('join grouped on coalesce');
566+
});
567+
538568
test('where segment is false', async () => {
539569
const query =
540570
'SELECT value AS val, * FROM "SegmentTest" WHERE segment_eq_1 IS FALSE ORDER BY value;';

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,14 @@ impl CubeScanWrapperNode {
994994
"Unsupported ungrouped CubeScan as join subquery: {join_cube_scan:?}"
995995
)));
996996
}
997+
} else if let Some(wrapped_select) =
998+
node.as_any().downcast_ref::<WrappedSelectNode>()
999+
{
1000+
if wrapped_select.push_to_cube {
1001+
return Err(CubeError::internal(format!(
1002+
"Unsupported push_to_cube WrappedSelect as join subquery: {wrapped_select:?}"
1003+
)));
1004+
}
9971005
} else {
9981006
// TODO support more grouped cases here
9991007
return Err(CubeError::internal(format!(

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12188,7 +12188,7 @@ ORDER BY "source"."str0" ASC
1218812188
}
1218912189
init_testing_logger();
1219012190

12191-
let logical_plan = convert_select_to_query_plan(
12191+
let query_plan = convert_select_to_query_plan(
1219212192
r#"
1219312193
WITH "qt_0" AS (
1219412194
SELECT "ta_1"."customer_gender" "ca_1"
@@ -12205,13 +12205,20 @@ ORDER BY "source"."str0" ASC
1220512205
.to_string(),
1220612206
DatabaseProtocol::PostgreSQL,
1220712207
)
12208-
.await
12209-
.as_logical_plan();
12208+
.await;
12209+
12210+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
12211+
println!(
12212+
"Physical plan: {}",
12213+
displayable(physical_plan.as_ref()).indent()
12214+
);
12215+
12216+
let logical_plan = query_plan.as_logical_plan();
1221012217

1221112218
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
1221212219

12213-
// check wrapping for `NOT(.. IS NULL OR LOWER(..) IN)`
12214-
let re = Regex::new(r"NOT \(.+ IS NULL OR .*LOWER\(.+ IN ").unwrap();
12220+
// check wrapping for `NOT((.. IS NULL) OR LOWER(..) IN)`
12221+
let re = Regex::new(r"NOT \(\(.+ IS NULL\) OR .*LOWER\(.+ IN ").unwrap();
1221512222
assert!(re.is_match(&sql));
1221612223
}
1221712224

rust/cubesql/cubesql/src/compile/rewrite/cost.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ impl BestCubePlan {
122122
LogicalPlanLanguage::ProjectionSplitPushDownReplacer(_) => 1,
123123
LogicalPlanLanguage::ProjectionSplitPullUpReplacer(_) => 1,
124124
LogicalPlanLanguage::QueryParam(_) => 1,
125+
LogicalPlanLanguage::JoinCheckStage(_) => 1,
126+
LogicalPlanLanguage::JoinCheckPushDown(_) => 1,
127+
LogicalPlanLanguage::JoinCheckPullUp(_) => 1,
125128
// Not really replacers but those should be deemed as mandatory rewrites and as soon as
126129
// there's always rewrite rule it's fine to have replacer cost.
127130
// Needs to be added as alias rewrite always more expensive than original function.
@@ -234,6 +237,7 @@ impl BestCubePlan {
234237
/// - `empty_wrappers` > `non_detected_cube_scans` - we don't want empty wrapper to hide non detected cube scan errors
235238
/// - `non_detected_cube_scans` > other nodes - minimize cube scans without members
236239
/// - `filters` > `filter_members` - optimize for push down of filters
240+
/// - `zero_members_wrapper` > `filter_members` - prefer CubeScan(filters) to WrappedSelect(CubeScan(*), filters)
237241
/// - `filter_members` > `cube_members` - optimize for `inDateRange` filter push down to time dimension
238242
/// - `member_errors` > `cube_members` - extra cube members may be required (e.g. CASE)
239243
/// - `member_errors` > `wrapper_nodes` - use SQL push down where possible if cube scan can't be detected
@@ -259,12 +263,12 @@ pub struct CubePlanCost {
259263
wrapped_select_ungrouped_scan: usize,
260264
filters: i64,
261265
structure_points: i64,
262-
filter_members: i64,
263266
// This is separate from both non_detected_cube_scans and cube_members
264267
// Because it's ok to use all members inside wrapper (so non_detected_cube_scans would be zero)
265268
// And we want to select representation with less members
266269
// But only when members are present!
267270
zero_members_wrapper: i64,
271+
filter_members: i64,
268272
cube_members: i64,
269273
errors: i64,
270274
time_dimensions_used_as_dimensions: i64,

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,19 @@ crate::plan_to_language! {
514514
QueryParam {
515515
index: usize,
516516
},
517+
JoinCheckStage {
518+
expr: Arc<Expr>,
519+
},
520+
JoinCheckPushDown {
521+
expr: Arc<Expr>,
522+
left_input: Arc<LogicalPlan>,
523+
right_input: Arc<LogicalPlan>,
524+
},
525+
JoinCheckPullUp {
526+
expr: Arc<Expr>,
527+
left_input: Arc<LogicalPlan>,
528+
right_input: Arc<LogicalPlan>,
529+
},
517530
}
518531
}
519532

@@ -2152,6 +2165,18 @@ fn distinct(input: impl Display) -> String {
21522165
format!("(Distinct {})", input)
21532166
}
21542167

2168+
fn join_check_stage(expr: impl Display) -> String {
2169+
format!("(JoinCheckStage {expr})")
2170+
}
2171+
2172+
fn join_check_push_down(expr: impl Display, left: impl Display, right: impl Display) -> String {
2173+
format!("(JoinCheckPushDown {expr} {left} {right})")
2174+
}
2175+
2176+
fn join_check_pull_up(expr: impl Display, left: impl Display, right: impl Display) -> String {
2177+
format!("(JoinCheckPullUp {expr} {left} {right})")
2178+
}
2179+
21552180
pub fn original_expr_name(egraph: &CubeEGraph, id: Id) -> Option<String> {
21562181
egraph[id]
21572182
.data

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1437,7 +1437,7 @@ impl MemberRules {
14371437
}
14381438
}
14391439

1440-
fn replace_alias(
1440+
pub fn replace_alias(
14411441
alias_to_cube: &Vec<(String, String)>,
14421442
projection_alias: &Option<String>,
14431443
) -> Vec<(String, String)> {

0 commit comments

Comments
 (0)