Skip to content

Commit 80e4a60

Browse files
authored
feat(cubesql): CASE WHEN SQL push down (#7029)
* feat(cubesql): CASE WHEN SQL push down * Enable SQL push down for legacy tests
1 parent 0e8a76a commit 80e4a60

File tree

7 files changed

+312
-21
lines changed

7 files changed

+312
-21
lines changed

.github/workflows/rust-cubesql.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ jobs:
118118
env:
119119
CUBESQL_TESTING_CUBE_TOKEN: ${{ secrets.CUBESQL_TESTING_CUBE_TOKEN }}
120120
CUBESQL_TESTING_CUBE_URL: ${{ secrets.CUBESQL_TESTING_CUBE_URL }}
121+
CUBESQL_SQL_PUSH_DOWN: true
121122
run: cd rust/cubesql && cargo test
122123

123124
native_linux:

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,7 +2402,11 @@ class BaseQuery {
24022402
select: 'SELECT {{ select_concat | map(attribute=\'aliased\') | join(\', \') }} \n' +
24032403
'FROM (\n {{ from }}\n) AS {{ from_alias }} \n' +
24042404
'{% if group_by %} GROUP BY {{ group_by | map(attribute=\'index\') | join(\', \') }}{% endif %}',
2405+
},
2406+
expressions: {
24052407
column_aliased: '{{expr}} {{quoted_alias}}',
2408+
case: 'CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END',
2409+
binary: '{{ left }} {{ op }} {{ right }}'
24062410
},
24072411
quotes: {
24082412
identifiers: '"',

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ impl CubeScanWrapperNode {
389389
plan.clone(),
390390
sql_query,
391391
sql_generator.clone(),
392-
(*expr).clone(),
392+
*expr,
393393
)
394394
.await?;
395395
Ok((expr, sql_query))
@@ -423,7 +423,32 @@ impl CubeScanWrapperNode {
423423
sql_query,
424424
)),
425425
// Expr::ScalarVariable(_, _) => {}
426-
// Expr::BinaryExpr { .. } => {}
426+
Expr::BinaryExpr { left, op, right } => {
427+
let (left, sql_query) = Self::generate_sql_for_expr(
428+
plan.clone(),
429+
sql_query,
430+
sql_generator.clone(),
431+
*left,
432+
)
433+
.await?;
434+
let (right, sql_query) = Self::generate_sql_for_expr(
435+
plan.clone(),
436+
sql_query,
437+
sql_generator.clone(),
438+
*right,
439+
)
440+
.await?;
441+
let resulting_sql = sql_generator
442+
.get_sql_templates()
443+
.binary_expr(left, op.to_string(), right)
444+
.map_err(|e| {
445+
DataFusionError::Internal(format!(
446+
"Can't generate SQL for binary expr: {}",
447+
e
448+
))
449+
})?;
450+
Ok((resulting_sql, sql_query))
451+
}
427452
// Expr::AnyExpr { .. } => {}
428453
// Expr::Like(_) => {}
429454
// Expr::ILike(_) => {}
@@ -434,7 +459,64 @@ impl CubeScanWrapperNode {
434459
// Expr::Negative(_) => {}
435460
// Expr::GetIndexedField { .. } => {}
436461
// Expr::Between { .. } => {}
437-
// Expr::Case { .. } => {}
462+
Expr::Case {
463+
expr,
464+
when_then_expr,
465+
else_expr,
466+
} => {
467+
let expr = if let Some(expr) = expr {
468+
let (expr, sql_query_next) = Self::generate_sql_for_expr(
469+
plan.clone(),
470+
sql_query,
471+
sql_generator.clone(),
472+
*expr,
473+
)
474+
.await?;
475+
sql_query = sql_query_next;
476+
Some(expr)
477+
} else {
478+
None
479+
};
480+
let mut when_then_expr_sql = Vec::new();
481+
for (when, then) in when_then_expr {
482+
let (when, sql_query_next) = Self::generate_sql_for_expr(
483+
plan.clone(),
484+
sql_query,
485+
sql_generator.clone(),
486+
*when,
487+
)
488+
.await?;
489+
let (then, sql_query_next) = Self::generate_sql_for_expr(
490+
plan.clone(),
491+
sql_query_next,
492+
sql_generator.clone(),
493+
*then,
494+
)
495+
.await?;
496+
sql_query = sql_query_next;
497+
when_then_expr_sql.push((when, then));
498+
}
499+
let else_expr = if let Some(else_expr) = else_expr {
500+
let (else_expr, sql_query_next) = Self::generate_sql_for_expr(
501+
plan.clone(),
502+
sql_query,
503+
sql_generator.clone(),
504+
*else_expr,
505+
)
506+
.await?;
507+
sql_query = sql_query_next;
508+
Some(else_expr)
509+
} else {
510+
None
511+
};
512+
let resulting_sql = sql_generator
513+
.get_sql_templates()
514+
.case(expr, when_then_expr_sql, else_expr)
515+
.map_err(|e| {
516+
DataFusionError::Internal(format!("Can't generate SQL for case: {}", e))
517+
})?;
518+
Ok((resulting_sql, sql_query))
519+
}
438520
// Expr::Cast { .. } => {}
439521
// Expr::TryCast { .. } => {}
440522
// Expr::Sort { .. } => {}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,8 @@ mod tests {
16951695

16961696
fn find_cube_scan(&self) -> CubeScanNode;
16971697

1698+
/// Returns the `CubeScanWrapperNode` of this plan.
fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode;
1699+
16981700
fn find_cube_scans(&self) -> Vec<CubeScanNode>;
16991701

17001702
fn find_filter(&self) -> Option<Filter>;
@@ -1736,6 +1738,20 @@ mod tests {
17361738
cube_scans[0].clone()
17371739
}
17381740

1741+
fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode {
1742+
match self {
1743+
LogicalPlan::Extension(Extension { node }) => {
1744+
if let Some(wrapper_node) = node.as_any().downcast_ref::<CubeScanWrapperNode>()
1745+
{
1746+
wrapper_node.clone()
1747+
} else {
1748+
panic!("Root plan node is not cube_scan_wrapper!");
1749+
}
1750+
}
1751+
_ => panic!("Root plan node is not extension!"),
1752+
}
1753+
}
1754+
17391755
fn find_cube_scans(&self) -> Vec<CubeScanNode> {
17401756
find_cube_scans_deep_search(Arc::new(self.clone()), true)
17411757
}
@@ -17929,20 +17945,39 @@ ORDER BY \"COUNT(count)\" DESC"
1792917945
)
1793017946
.await;
1793117947

17932-
// let logical_plan = query_plan.as_logical_plan();
17933-
// assert_eq!(
17934-
// logical_plan.find_cube_scan().request,
17935-
// V1LoadRequestQuery {
17936-
// measures: Some(vec!["KibanaSampleDataEcommerce.avgPrice".to_string(),]),
17937-
// segments: Some(vec![]),
17938-
// dimensions: Some(vec![]),
17939-
// time_dimensions: None,
17940-
// order: None,
17941-
// limit: None,
17942-
// offset: None,
17943-
// filters: None
17944-
// }
17945-
// );
17948+
let logical_plan = query_plan.as_logical_plan();
17949+
assert!(logical_plan
17950+
.find_cube_scan_wrapper()
17951+
.wrapped_sql
17952+
.unwrap()
17953+
.sql
17954+
.contains("COALESCE"));
17955+
17956+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
17957+
println!(
17958+
"Physical plan: {}",
17959+
displayable(physical_plan.as_ref()).indent()
17960+
);
17961+
}
17962+
17963+
#[tokio::test]
17964+
async fn test_case_wrapper() {
17965+
init_logger();
17966+
17967+
let query_plan = convert_select_to_query_plan(
17968+
"SELECT CASE WHEN customer_gender = 'female' THEN 'f' ELSE 'm' END, MIN(avgPrice) mp FROM (SELECT avgPrice, customer_gender FROM KibanaSampleDataEcommerce LIMIT 1) a GROUP BY 1"
17969+
.to_string(),
17970+
DatabaseProtocol::PostgreSQL,
17971+
)
17972+
.await;
17973+
17974+
let logical_plan = query_plan.as_logical_plan();
17975+
assert!(logical_plan
17976+
.find_cube_scan_wrapper()
17977+
.wrapped_sql
17978+
.unwrap()
17979+
.sql
17980+
.contains("CASE WHEN"));
1794617981

1794717982
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1794817983
println!(

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper.rs

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use crate::{
44
rewrite::{
55
agg_fun_expr, aggregate, alias_expr,
66
analysis::LogicalPlanAnalysis,
7-
column_expr, cube_scan, cube_scan_wrapper, fun_expr_var_arg, literal_expr, projection,
8-
rewrite,
7+
binary_expr, case_expr_var_arg, column_expr, cube_scan, cube_scan_wrapper,
8+
fun_expr_var_arg, literal_expr, projection, rewrite,
99
rewriter::RewriteRules,
1010
rules::{replacer_pull_up_node, replacer_push_down_node},
1111
scalar_fun_expr_args, scalar_fun_expr_args_empty_tail, transforming_rewrite,
@@ -339,6 +339,52 @@ impl RewriteRules for WrapperRules {
339339
alias_expr(wrapper_pullup_replacer("?expr", "?alias_to_cube"), "?alias"),
340340
wrapper_pullup_replacer(alias_expr("?expr", "?alias"), "?alias_to_cube"),
341341
),
342+
// Case
343+
rewrite(
344+
"wrapper-push-down-case",
345+
wrapper_pushdown_replacer(
346+
case_expr_var_arg("?when", "?then", "?else"),
347+
"?alias_to_cube",
348+
),
349+
case_expr_var_arg(
350+
wrapper_pushdown_replacer("?when", "?alias_to_cube"),
351+
wrapper_pushdown_replacer("?then", "?alias_to_cube"),
352+
wrapper_pushdown_replacer("?else", "?alias_to_cube"),
353+
),
354+
),
355+
transforming_rewrite(
356+
"wrapper-pull-up-case",
357+
case_expr_var_arg(
358+
wrapper_pullup_replacer("?when", "?alias_to_cube"),
359+
wrapper_pullup_replacer("?then", "?alias_to_cube"),
360+
wrapper_pullup_replacer("?else", "?alias_to_cube"),
361+
),
362+
wrapper_pullup_replacer(
363+
case_expr_var_arg("?when", "?then", "?else"),
364+
"?alias_to_cube",
365+
),
366+
self.transform_case_expr("?alias_to_cube"),
367+
),
368+
// Binary Expr
369+
rewrite(
370+
"wrapper-push-down-binary-expr",
371+
wrapper_pushdown_replacer(binary_expr("?left", "?op", "?right"), "?alias_to_cube"),
372+
binary_expr(
373+
wrapper_pushdown_replacer("?left", "?alias_to_cube"),
374+
"?op",
375+
wrapper_pushdown_replacer("?right", "?alias_to_cube"),
376+
),
377+
),
378+
transforming_rewrite(
379+
"wrapper-pull-up-binary-expr",
380+
binary_expr(
381+
wrapper_pullup_replacer("?left", "?alias_to_cube"),
382+
"?op",
383+
wrapper_pullup_replacer("?right", "?alias_to_cube"),
384+
),
385+
wrapper_pullup_replacer(binary_expr("?left", "?op", "?right"), "?alias_to_cube"),
386+
self.transform_binary_expr("?op", "?alias_to_cube"),
387+
),
342388
// Column
343389
rewrite(
344390
"wrapper-push-down-column",
@@ -353,6 +399,20 @@ impl RewriteRules for WrapperRules {
353399
),
354400
];
355401

402+
Self::expr_list_pushdown_pullup_rules(&mut rules, "wrapper-case-expr", "CaseExprExpr");
403+
404+
Self::expr_list_pushdown_pullup_rules(
405+
&mut rules,
406+
"wrapper-case-when-expr",
407+
"CaseExprWhenThenExpr",
408+
);
409+
410+
Self::expr_list_pushdown_pullup_rules(
411+
&mut rules,
412+
"wrapper-case-else-expr",
413+
"CaseExprElseExpr",
414+
);
415+
356416
Self::list_pushdown_pullup_rules(
357417
&mut rules,
358418
"wrapper-aggregate-aggr-expr",
@@ -562,6 +622,63 @@ impl WrapperRules {
562622
}
563623
}
564624

625+
fn transform_case_expr(
626+
&self,
627+
alias_to_cube_var: &'static str,
628+
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
629+
let alias_to_cube_var = var!(alias_to_cube_var);
630+
let meta = self.cube_context.meta.clone();
631+
move |egraph, subst| {
632+
for alias_to_cube in var_iter!(
633+
egraph[subst[alias_to_cube_var]],
634+
WrapperPullupReplacerAliasToCube
635+
)
636+
.cloned()
637+
{
638+
if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) {
639+
if sql_generator
640+
.get_sql_templates()
641+
.templates
642+
.contains_key("expressions/case")
643+
{
644+
return true;
645+
}
646+
}
647+
}
648+
false
649+
}
650+
}
651+
652+
fn transform_binary_expr(
653+
&self,
654+
_operator_var: &'static str,
655+
alias_to_cube_var: &'static str,
656+
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
657+
let alias_to_cube_var = var!(alias_to_cube_var);
658+
// let operator_var = var!(operator_var);
659+
let meta = self.cube_context.meta.clone();
660+
move |egraph, subst| {
661+
for alias_to_cube in var_iter!(
662+
egraph[subst[alias_to_cube_var]],
663+
WrapperPullupReplacerAliasToCube
664+
)
665+
.cloned()
666+
{
667+
if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) {
668+
if sql_generator
669+
.get_sql_templates()
670+
.templates
671+
.contains_key("expressions/binary")
672+
{
673+
// TODO check supported operators
674+
return true;
675+
}
676+
}
677+
}
678+
false
679+
}
680+
}
681+
565682
fn list_pushdown_pullup_rules(
566683
rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
567684
rule_name: &str,
@@ -588,4 +705,30 @@ impl WrapperRules {
588705
wrapper_pullup_replacer(substitute_list_node, "?alias_to_cube"),
589706
)]);
590707
}
708+
709+
fn expr_list_pushdown_pullup_rules(
710+
rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
711+
rule_name: &str,
712+
list_node: &str,
713+
) {
714+
rules.extend(replacer_push_down_node(
715+
rule_name,
716+
list_node,
717+
|node| wrapper_pushdown_replacer(node, "?alias_to_cube"),
718+
false,
719+
));
720+
721+
rules.extend(replacer_pull_up_node(
722+
rule_name,
723+
list_node,
724+
list_node,
725+
|node| wrapper_pullup_replacer(node, "?alias_to_cube"),
726+
));
727+
728+
rules.extend(vec![rewrite(
729+
rule_name,
730+
wrapper_pushdown_replacer(list_node, "?alias_to_cube"),
731+
wrapper_pullup_replacer(list_node, "?alias_to_cube"),
732+
)]);
733+
}
591734
}

0 commit comments

Comments
 (0)