Skip to content

Commit 86a58a5

Browse files
authored
feat(cubesql): Rewrites for pushdown of subqueries with empty source (#8188)
1 parent e366406 commit 86a58a5

File tree

8 files changed

+402
-96
lines changed

8 files changed

+402
-96
lines changed

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2918,10 +2918,10 @@ export class BaseQuery {
29182918
},
29192919
statements: {
29202920
select: 'SELECT {% if distinct %}DISTINCT {% endif %}' +
2921-
'{{ select_concat | map(attribute=\'aliased\') | join(\', \') }} \n' +
2921+
'{{ select_concat | map(attribute=\'aliased\') | join(\', \') }} {% if from %}\n' +
29222922
'FROM (\n' +
29232923
'{{ from | indent(2, true) }}\n' +
2924-
') AS {{ from_alias }}' +
2924+
') AS {{ from_alias }}{% endif %}' +
29252925
'{% if filter %}\nWHERE {{ filter }}{% endif %}' +
29262926
'{% if group_by %}\nGROUP BY {{ group_by | map(attribute=\'index\') | join(\', \') }}{% endif %}' +
29272927
'{% if order_by %}\nORDER BY {{ order_by | map(attribute=\'expr\') | join(\', \') }}{% endif %}' +

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ impl SqlQuery {
5050
index
5151
}
5252

53+
pub fn extend_values(&mut self, values: &Vec<Option<String>>) {
54+
self.values.extend(values.iter().cloned());
55+
}
56+
5357
pub fn replace_sql(&mut self, sql: String) {
5458
self.sql = sql;
5559
}
@@ -243,6 +247,7 @@ impl CubeScanWrapperNode {
243247
self.clone().set_max_limit_for_node(wrapped_plan),
244248
true,
245249
Vec::new(),
250+
None,
246251
)
247252
.await
248253
.and_then(|SqlGenerationResult { data_source, mut sql, request, column_remapping, .. }| -> result::Result<_, CubeError> {
@@ -324,7 +329,8 @@ impl CubeScanWrapperNode {
324329
load_request_meta: Arc<LoadRequestMeta>,
325330
node: Arc<LogicalPlan>,
326331
can_rename_columns: bool,
327-
mut values: Vec<Option<String>>,
332+
values: Vec<Option<String>>,
333+
parent_data_source: Option<String>,
328334
) -> Pin<Box<dyn Future<Output = result::Result<SqlGenerationResult, CubeError>> + Send>> {
329335
Box::pin(async move {
330336
match node.as_ref() {
@@ -435,7 +441,7 @@ impl CubeScanWrapperNode {
435441
Some(Arc::new(cube_scan_node.clone()))
436442
} else {
437443
return Err(CubeError::internal(format!(
438-
"Expected CubeScan node but found: {:?}",
444+
"Expected ubeScan node but found: {:?}",
439445
plan
440446
)));
441447
}
@@ -448,37 +454,12 @@ impl CubeScanWrapperNode {
448454
} else {
449455
None
450456
};
451-
let mut subqueries_sql = HashMap::new();
452-
for subquery in subqueries.iter() {
453-
let SqlGenerationResult {
454-
data_source: _,
455-
from_alias: _,
456-
column_remapping: _,
457-
sql,
458-
request: _,
459-
} = Self::generate_sql_for_node(
460-
plan.clone(),
461-
transport.clone(),
462-
load_request_meta.clone(),
463-
subquery.clone(),
464-
true,
465-
values,
466-
)
467-
.await?;
468-
469-
let (sql_string, new_values) = sql.unpack();
470-
values = new_values;
471-
472-
let field = subquery.schema().field(0);
473-
subqueries_sql.insert(field.qualified_name(), sql_string);
474-
}
475457

476-
let subqueries_sql = Arc::new(subqueries_sql);
477458
let SqlGenerationResult {
478459
data_source,
479460
from_alias,
480461
column_remapping,
481-
sql,
462+
mut sql,
482463
request,
483464
} = if let Some(ungrouped_scan_node) = ungrouped_scan_node.clone() {
484465
let data_sources = ungrouped_scan_node
@@ -499,7 +480,7 @@ impl CubeScanWrapperNode {
499480
ungrouped_scan_node
500481
)));
501482
}
502-
let sql = SqlQuery::new("".to_string(), values);
483+
let sql = SqlQuery::new("".to_string(), values.clone());
503484
SqlGenerationResult {
504485
data_source: Some(data_sources[0].clone()),
505486
from_alias: ungrouped_scan_node
@@ -519,10 +500,37 @@ impl CubeScanWrapperNode {
519500
load_request_meta.clone(),
520501
from.clone(),
521502
true,
522-
values,
503+
values.clone(),
504+
parent_data_source.clone(),
523505
)
524506
.await?
525507
};
508+
509+
let mut subqueries_sql = HashMap::new();
510+
for subquery in subqueries.iter() {
511+
let SqlGenerationResult {
512+
data_source: _,
513+
from_alias: _,
514+
column_remapping: _,
515+
sql: subquery_sql,
516+
request: _,
517+
} = Self::generate_sql_for_node(
518+
plan.clone(),
519+
transport.clone(),
520+
load_request_meta.clone(),
521+
subquery.clone(),
522+
true,
523+
sql.values.clone(),
524+
data_source.clone(),
525+
)
526+
.await?;
527+
528+
let (sql_string, new_values) = subquery_sql.unpack();
529+
sql.extend_values(&new_values);
530+
let field = subquery.schema().field(0);
531+
subqueries_sql.insert(field.qualified_name(), sql_string);
532+
}
533+
let subqueries_sql = Arc::new(subqueries_sql);
526534
let mut next_remapping = HashMap::new();
527535
let alias = alias.or(from_alias.clone());
528536
if let Some(data_source) = data_source {
@@ -825,6 +833,13 @@ impl CubeScanWrapperNode {
825833
)));
826834
}
827835
}
836+
LogicalPlan::EmptyRelation(_) => Ok(SqlGenerationResult {
837+
data_source: parent_data_source,
838+
from_alias: None,
839+
sql: SqlQuery::new("".to_string(), values.clone()),
840+
column_remapping: None,
841+
request: V1LoadRequestQuery::new(),
842+
}),
828843
// LogicalPlan::Distinct(_) => {}
829844
x => {
830845
return Err(CubeError::internal(format!(

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 143 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use self::{
4545
},
4646
parser::parse_sql_to_statement,
4747
qtrace::Qtrace,
48-
rewrite::converter::LogicalPlanToLanguageConverter,
48+
rewrite::converter::{LogicalPlanToLanguageContext, LogicalPlanToLanguageConverter},
4949
};
5050
use crate::{
5151
sql::{
@@ -1313,7 +1313,11 @@ WHERE `TABLE_SCHEMA` = '{}'",
13131313
let mut converter = LogicalPlanToLanguageConverter::new(cube_ctx.clone());
13141314
let mut query_params = Some(HashMap::new());
13151315
let root = converter
1316-
.add_logical_plan_replace_params(&optimized_plan, &mut query_params)
1316+
.add_logical_plan_replace_params(
1317+
&optimized_plan,
1318+
&mut query_params,
1319+
&mut LogicalPlanToLanguageContext::default(),
1320+
)
13171321
.map_err(|e| CompilationError::internal(e.to_string()))?;
13181322

13191323
let mut finalized_graph = self
@@ -20115,14 +20119,149 @@ ORDER BY "source"."str0" ASC
2011520119
}
2011620120

2011720121
#[tokio::test]
20118-
async fn test_simple_subquery_wrapper_projection() {
20122+
async fn test_simple_subquery_wrapper_projection_empty_source() {
20123+
if !Rewriter::sql_push_down_enabled() {
20124+
return;
20125+
}
20126+
init_logger();
20127+
20128+
let query_plan = convert_select_to_query_plan(
20129+
"SELECT (SELECT 'male' where 1 group by 'male' having 1 order by 'male' limit 1) as gender, avgPrice FROM KibanaSampleDataEcommerce a"
20130+
.to_string(),
20131+
DatabaseProtocol::PostgreSQL,
20132+
)
20133+
.await;
20134+
20135+
let logical_plan = query_plan.as_logical_plan();
20136+
let sql = logical_plan
20137+
.find_cube_scan_wrapper()
20138+
.wrapped_sql
20139+
.unwrap()
20140+
.sql;
20141+
assert!(sql.contains("(SELECT"));
20142+
assert!(sql.contains("utf8__male__"));
20143+
20144+
let _physical_plan = query_plan.as_physical_plan().await.unwrap();
20145+
//println!("phys plan {:?}", physical_plan);
20146+
}
20147+
20148+
#[tokio::test]
20149+
async fn test_simple_subquery_wrapper_filter_empty_source() {
2011920150
if !Rewriter::sql_push_down_enabled() {
2012020151
return;
2012120152
}
2012220153
init_logger();
2012320154

2012420155
let query_plan = convert_select_to_query_plan(
20125-
"SELECT (SELECT customer_gender FROM KibanaSampleDataEcommerce WHERE customer_gender = 'male' LIMIT 1) as gender, avgPrice FROM KibanaSampleDataEcommerce a"
20156+
"SELECT avgPrice FROM KibanaSampleDataEcommerce a where customer_gender = (SELECT 'male' )"
20157+
.to_string(),
20158+
DatabaseProtocol::PostgreSQL,
20159+
)
20160+
.await;
20161+
20162+
let logical_plan = query_plan.as_logical_plan();
20163+
let sql = logical_plan
20164+
.find_cube_scan_wrapper()
20165+
.wrapped_sql
20166+
.unwrap()
20167+
.sql;
20168+
assert!(sql.contains("(SELECT"));
20169+
assert!(sql.contains("utf8__male__"));
20170+
20171+
let _physical_plan = query_plan.as_physical_plan().await.unwrap();
20172+
//println!("phys plan {:?}", physical_plan);
20173+
}
20174+
20175+
#[tokio::test]
20176+
async fn test_simple_subquery_wrapper_projection_aggregate_empty_source() {
20177+
if !Rewriter::sql_push_down_enabled() {
20178+
return;
20179+
}
20180+
init_logger();
20181+
20182+
let query_plan = convert_select_to_query_plan(
20183+
"SELECT (SELECT 'male'), avg(avgPrice) FROM KibanaSampleDataEcommerce a GROUP BY 1"
20184+
.to_string(),
20185+
DatabaseProtocol::PostgreSQL,
20186+
)
20187+
.await;
20188+
20189+
let logical_plan = query_plan.as_logical_plan();
20190+
let sql = logical_plan
20191+
.find_cube_scan_wrapper()
20192+
.wrapped_sql
20193+
.unwrap()
20194+
.sql;
20195+
assert!(sql.contains("(SELECT"));
20196+
assert!(sql.contains("utf8__male__"));
20197+
20198+
let _physical_plan = query_plan.as_physical_plan().await.unwrap();
20199+
}
20200+
20201+
#[tokio::test]
20202+
async fn test_simple_subquery_wrapper_filter_in_empty_source() {
20203+
if !Rewriter::sql_push_down_enabled() {
20204+
return;
20205+
}
20206+
init_logger();
20207+
20208+
let query_plan = convert_select_to_query_plan(
20209+
"SELECT customer_gender, avgPrice FROM KibanaSampleDataEcommerce a where customer_gender in (select 'male')"
20210+
.to_string(),
20211+
DatabaseProtocol::PostgreSQL,
20212+
)
20213+
.await;
20214+
20215+
let logical_plan = query_plan.as_logical_plan();
20216+
let sql = logical_plan
20217+
.find_cube_scan_wrapper()
20218+
.wrapped_sql
20219+
.unwrap()
20220+
.sql;
20221+
assert!(sql.contains("IN (SELECT"));
20222+
assert!(sql.contains("utf8__male__"));
20223+
20224+
let _physical_plan = query_plan.as_physical_plan().await.unwrap();
20225+
}
20226+
20227+
#[tokio::test]
20228+
async fn test_simple_subquery_wrapper_filter_and_projection_empty_source() {
20229+
if !Rewriter::sql_push_down_enabled() {
20230+
return;
20231+
}
20232+
init_logger();
20233+
20234+
let query_plan = convert_select_to_query_plan(
20235+
"SELECT (select 'male'), avgPrice FROM KibanaSampleDataEcommerce a where customer_gender in (select 'female')"
20236+
.to_string(),
20237+
DatabaseProtocol::PostgreSQL,
20238+
)
20239+
.await;
20240+
20241+
let logical_plan = query_plan.as_logical_plan();
20242+
20243+
let sql = logical_plan
20244+
.find_cube_scan_wrapper()
20245+
.wrapped_sql
20246+
.unwrap()
20247+
.sql;
20248+
assert!(sql.contains("IN (SELECT"));
20249+
assert!(sql.contains("(SELECT"));
20250+
assert!(sql.contains("utf8__male__"));
20251+
assert!(sql.contains("utf8__female__"));
20252+
20253+
let _physical_plan = query_plan.as_physical_plan().await.unwrap();
20254+
}
20255+
20256+
#[tokio::test]
20257+
async fn test_simple_subquery_wrapper_projection_1() {
20258+
if !Rewriter::sql_push_down_enabled() {
20259+
return;
20260+
}
20261+
init_logger();
20262+
20263+
let query_plan = convert_select_to_query_plan(
20264+
"SELECT (SELECT customer_gender FROM KibanaSampleDataEcommerce LIMIT 1) as gender, avgPrice FROM KibanaSampleDataEcommerce a"
2012620265
.to_string(),
2012720266
DatabaseProtocol::PostgreSQL,
2012820267
)
@@ -20166,12 +20305,6 @@ ORDER BY "source"."str0" ASC
2016620305
.unwrap()
2016720306
.sql
2016820307
.contains("(SELECT"));
20169-
assert!(logical_plan
20170-
.find_cube_scan_wrapper()
20171-
.wrapped_sql
20172-
.unwrap()
20173-
.sql
20174-
.contains("LIMIT 1"));
2017520308

2017620309
let _physical_plan = query_plan.as_physical_plan().await.unwrap();
2017720310
}

0 commit comments

Comments
 (0)