Skip to content

Commit c64994a

Browse files
committed
feat(cubesql): Support [NOT] IN SQL push down
1 parent 9555af6 commit c64994a

File tree

7 files changed

+199
-2
lines changed

7 files changed

+199
-2
lines changed

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2479,7 +2479,8 @@ class BaseQuery {
24792479
binary: '({{ left }} {{ op }} {{ right }})',
24802480
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
24812481
cast: 'CAST({{ expr }} AS {{ data_type }})',
2482-
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})'
2482+
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})',
2483+
in_list: '{{ expr }} {% if negated %}NOT {% endif %}IN ({{ in_exprs_concat }})',
24832484
},
24842485
quotes: {
24852486
identifiers: '"',

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 41 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1429,7 +1429,47 @@ impl CubeScanWrapperNode {
14291429
Ok((resulting_sql, sql_query))
14301430
}
14311431
// Expr::AggregateUDF { .. } => {}
1432-
// Expr::InList { .. } => {}
1432+
Expr::InList {
1433+
expr,
1434+
list,
1435+
negated,
1436+
} => {
1437+
let mut sql_query = sql_query;
1438+
let (sql_expr, query) = Self::generate_sql_for_expr(
1439+
plan.clone(),
1440+
sql_query,
1441+
sql_generator.clone(),
1442+
*expr,
1443+
ungrouped_scan_node.clone(),
1444+
)
1445+
.await?;
1446+
sql_query = query;
1447+
let mut sql_in_exprs = Vec::new();
1448+
for expr in list {
1449+
let (sql, query) = Self::generate_sql_for_expr(
1450+
plan.clone(),
1451+
sql_query,
1452+
sql_generator.clone(),
1453+
expr,
1454+
ungrouped_scan_node.clone(),
1455+
)
1456+
.await?;
1457+
sql_query = query;
1458+
sql_in_exprs.push(sql);
1459+
}
1460+
Ok((
1461+
sql_generator
1462+
.get_sql_templates()
1463+
.in_list_expr(sql_expr, sql_in_exprs, negated)
1464+
.map_err(|e| {
1465+
DataFusionError::Internal(format!(
1466+
"Can't generate SQL for in list expr: {}",
1467+
e
1468+
))
1469+
})?,
1470+
sql_query,
1471+
))
1472+
}
14331473
// Expr::Wildcard => {}
14341474
// Expr::QualifiedWildcard { .. } => {}
14351475
x => {

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19591,4 +19591,42 @@ ORDER BY \"COUNT(count)\" DESC"
1959119591

1959219592
Ok(())
1959319593
}
19594+
19595+
#[tokio::test]
19596+
async fn test_inlist_expr() {
19597+
if !Rewriter::sql_push_down_enabled() {
19598+
return;
19599+
}
19600+
init_logger();
19601+
19602+
let query_plan = convert_select_to_query_plan(
19603+
"
19604+
SELECT
19605+
CASE
19606+
WHEN (customer_gender NOT IN ('1', '2', '3')) THEN customer_gender
19607+
ELSE '0'
19608+
END AS customer_gender
19609+
FROM KibanaSampleDataEcommerce AS k
19610+
GROUP BY 1
19611+
ORDER BY 1 DESC
19612+
"
19613+
.to_string(),
19614+
DatabaseProtocol::PostgreSQL,
19615+
)
19616+
.await;
19617+
19618+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
19619+
println!(
19620+
"Physical plan: {}",
19621+
displayable(physical_plan.as_ref()).indent()
19622+
);
19623+
19624+
let logical_plan = query_plan.as_logical_plan();
19625+
assert!(logical_plan
19626+
.find_cube_scan_wrapper()
19627+
.wrapped_sql
19628+
.unwrap()
19629+
.sql
19630+
.contains("NOT IN ("));
19631+
}
1959419632
}

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs

Lines changed: 97 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,97 @@
1+
use crate::{
2+
compile::rewrite::{
3+
analysis::LogicalPlanAnalysis, inlist_expr, rewrite, rules::wrapper::WrapperRules,
4+
transforming_rewrite, wrapper_pullup_replacer, wrapper_pushdown_replacer,
5+
LogicalPlanLanguage, WrapperPullupReplacerAliasToCube,
6+
},
7+
var, var_iter,
8+
};
9+
use egg::{EGraph, Rewrite, Subst};
10+
11+
impl WrapperRules {
12+
pub fn in_list_expr_rules(
13+
&self,
14+
rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
15+
) {
16+
rules.extend(vec![
17+
rewrite(
18+
"wrapper-in-list-push-down",
19+
wrapper_pushdown_replacer(
20+
inlist_expr("?expr", "?list", "?negated"),
21+
"?alias_to_cube",
22+
"?ungrouped",
23+
"?cube_members",
24+
),
25+
inlist_expr(
26+
wrapper_pushdown_replacer(
27+
"?expr",
28+
"?alias_to_cube",
29+
"?ungrouped",
30+
"?cube_members",
31+
),
32+
wrapper_pushdown_replacer(
33+
"?list",
34+
"?alias_to_cube",
35+
"?ungrouped",
36+
"?cube_members",
37+
),
38+
"?negated",
39+
),
40+
),
41+
transforming_rewrite(
42+
"wrapper-in-list-pull-up",
43+
inlist_expr(
44+
wrapper_pullup_replacer(
45+
"?expr",
46+
"?alias_to_cube",
47+
"?ungrouped",
48+
"?cube_members",
49+
),
50+
wrapper_pullup_replacer(
51+
"?list",
52+
"?alias_to_cube",
53+
"?ungrouped",
54+
"?cube_members",
55+
),
56+
"?negated",
57+
),
58+
wrapper_pullup_replacer(
59+
inlist_expr("?expr", "?list", "?negated"),
60+
"?alias_to_cube",
61+
"?ungrouped",
62+
"?cube_members",
63+
),
64+
self.transform_in_list_expr("?alias_to_cube"),
65+
),
66+
]);
67+
68+
Self::expr_list_pushdown_pullup_rules(rules, "wrapper-in-list-exprs", "InListExprList");
69+
}
70+
71+
fn transform_in_list_expr(
72+
&self,
73+
alias_to_cube_var: &'static str,
74+
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
75+
let alias_to_cube_var = var!(alias_to_cube_var);
76+
let meta = self.cube_context.meta.clone();
77+
move |egraph, subst| {
78+
for alias_to_cube in var_iter!(
79+
egraph[subst[alias_to_cube_var]],
80+
WrapperPullupReplacerAliasToCube
81+
)
82+
.cloned()
83+
{
84+
if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) {
85+
if sql_generator
86+
.get_sql_templates()
87+
.templates
88+
.contains_key("expressions/in_list")
89+
{
90+
return true;
91+
}
92+
}
93+
}
94+
false
95+
}
96+
}
97+
}

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ mod cast;
77
mod column;
88
mod cube_scan_wrapper;
99
mod extract;
10+
mod in_list_expr;
1011
mod is_null_expr;
1112
mod limit;
1213
mod literal;
@@ -60,6 +61,7 @@ impl RewriteRules for WrapperRules {
6061
self.cast_rules(&mut rules);
6162
self.column_rules(&mut rules);
6263
self.literal_rules(&mut rules);
64+
self.in_list_expr_rules(&mut rules);
6365

6466
rules
6567
}

rust/cubesql/cubesql/src/compile/test/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ pub fn get_test_tenant_ctx() -> Arc<MetaContext> {
234234
("expressions/cast".to_string(), "CAST({{ expr }} AS {{ data_type }})".to_string()),
235235
("expressions/interval".to_string(), "INTERVAL '{{ interval }}'".to_string()),
236236
("expressions/window_function".to_string(), "{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})".to_string()),
237+
("expressions/in_list".to_string(), "{{ expr }} {% if negated %}NOT {% endif %}IN ({{ in_exprs_concat }})".to_string()),
237238
("quotes/identifiers".to_string(), "\"".to_string()),
238239
("quotes/escape".to_string(), "\"\"".to_string()),
239240
("params/param".to_string(), "${{ param_index + 1 }}".to_string())

rust/cubesql/cubesql/src/transport/service.rs

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -536,6 +536,24 @@ impl SqlTemplates {
536536
)
537537
}
538538

539+
pub fn in_list_expr(
540+
&self,
541+
expr: String,
542+
in_exprs: Vec<String>,
543+
negated: bool,
544+
) -> Result<String, CubeError> {
545+
let in_exprs_concat = in_exprs.join(", ");
546+
self.render_template(
547+
"expressions/in_list",
548+
context! {
549+
expr => expr,
550+
in_exprs_concat => in_exprs_concat,
551+
in_exprs => in_exprs,
552+
negated => negated
553+
},
554+
)
555+
}
556+
539557
pub fn param(&self, param_index: usize) -> Result<String, CubeError> {
540558
self.render_template("params/param", context! { param_index => param_index })
541559
}

0 commit comments

Comments
 (0)