Skip to content

Commit b1da6c0

Browse files
authored
feat(cubesql): SQL push down support for window functions (#7403)
1 parent b3ea6ab commit b1da6c0

File tree

18 files changed

+607
-37
lines changed

18 files changed

+607
-37
lines changed

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2479,6 +2479,7 @@ class BaseQuery {
24792479
binary: '({{ left }} {{ op }} {{ right }})',
24802480
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
24812481
cast: 'CAST({{ expr }} AS {{ data_type }})',
2482+
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})'
24822483
},
24832484
quotes: {
24842485
identifiers: '"',

rust/cubesql/cubesql/src/compile/engine/df/scan.rs

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ pub struct WrappedSelectNode {
140140
pub projection_expr: Vec<Expr>,
141141
pub group_expr: Vec<Expr>,
142142
pub aggr_expr: Vec<Expr>,
143+
pub window_expr: Vec<Expr>,
143144
pub from: Arc<LogicalPlan>,
144145
pub joins: Vec<(Arc<LogicalPlan>, Expr, JoinType)>,
145146
pub filter_expr: Vec<Expr>,
@@ -158,6 +159,7 @@ impl WrappedSelectNode {
158159
projection_expr: Vec<Expr>,
159160
group_expr: Vec<Expr>,
160161
aggr_expr: Vec<Expr>,
162+
window_expr: Vec<Expr>,
161163
from: Arc<LogicalPlan>,
162164
joins: Vec<(Arc<LogicalPlan>, Expr, JoinType)>,
163165
filter_expr: Vec<Expr>,
@@ -174,6 +176,7 @@ impl WrappedSelectNode {
174176
projection_expr,
175177
group_expr,
176178
aggr_expr,
179+
window_expr,
177180
from,
178181
joins,
179182
filter_expr,
@@ -207,6 +210,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
207210
exprs.extend(self.projection_expr.clone());
208211
exprs.extend(self.group_expr.clone());
209212
exprs.extend(self.aggr_expr.clone());
213+
exprs.extend(self.window_expr.clone());
210214
exprs.extend(self.joins.iter().map(|(_, expr, _)| expr.clone()));
211215
exprs.extend(self.filter_expr.clone());
212216
exprs.extend(self.having_expr.clone());
@@ -217,11 +221,12 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
217221
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
218222
write!(
219223
f,
220-
"WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}",
224+
"WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, window_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}",
221225
self.select_type,
222226
self.projection_expr,
223227
self.group_expr,
224228
self.aggr_expr,
229+
self.window_expr,
225230
self.from,
226231
self.joins,
227232
self.filter_expr,
@@ -261,6 +266,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
261266
let mut projection_expr = vec![];
262267
let mut group_expr = vec![];
263268
let mut aggregate_expr = vec![];
269+
let mut window_expr = vec![];
264270
let limit = None;
265271
let offset = None;
266272
let alias = None;
@@ -278,6 +284,10 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
278284
aggregate_expr.push(exprs_iter.next().unwrap().clone());
279285
}
280286

287+
for _ in self.window_expr.iter() {
288+
window_expr.push(exprs_iter.next().unwrap().clone());
289+
}
290+
281291
for _ in self.joins.iter() {
282292
joins_expr.push(exprs_iter.next().unwrap().clone());
283293
}
@@ -300,6 +310,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
300310
projection_expr,
301311
group_expr,
302312
aggregate_expr,
313+
window_expr,
303314
from,
304315
joins
305316
.into_iter()

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 95 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -297,6 +297,7 @@ impl CubeScanWrapperNode {
297297
projection_expr,
298298
group_expr,
299299
aggr_expr,
300+
window_expr,
300301
from,
301302
joins: _joins,
302303
filter_expr: _filter_expr,
@@ -431,6 +432,20 @@ impl CubeScanWrapperNode {
431432
ungrouped_scan_node.clone(),
432433
)
433434
.await?;
435+
436+
let (window, sql) = Self::generate_column_expr(
437+
plan.clone(),
438+
schema.clone(),
439+
window_expr.clone(),
440+
sql,
441+
generator.clone(),
442+
&column_remapping,
443+
&mut next_remapping,
444+
alias.clone(),
445+
can_rename_columns,
446+
ungrouped_scan_node.clone(),
447+
)
448+
.await?;
434449
// Sort node always comes on top and pushed down to select so we need to replace columns here by appropriate column definitions
435450
let order_replace_map = projection_expr
436451
.iter()
@@ -504,6 +519,12 @@ impl CubeScanWrapperNode {
504519
)
505520
}),
506521
)
522+
.chain(window.iter().map(|m| {
523+
Self::ungrouped_member_def(
524+
m,
525+
&ungrouped_scan_node.used_cubes,
526+
)
527+
}))
507528
.collect::<Result<_>>()?,
508529
);
509530
load_request.dimensions = Some(
@@ -1333,7 +1354,80 @@ impl CubeScanWrapperNode {
13331354
sql_query,
13341355
))
13351356
}
1336-
// Expr::WindowFunction { .. } => {}
1357+
Expr::WindowFunction {
1358+
fun,
1359+
args,
1360+
partition_by,
1361+
order_by,
1362+
window_frame,
1363+
} => {
1364+
let mut sql_args = Vec::new();
1365+
for arg in args {
1366+
let (sql, query) = Self::generate_sql_for_expr(
1367+
plan.clone(),
1368+
sql_query,
1369+
sql_generator.clone(),
1370+
arg,
1371+
ungrouped_scan_node.clone(),
1372+
)
1373+
.await?;
1374+
sql_query = query;
1375+
sql_args.push(sql);
1376+
}
1377+
let mut sql_partition_by = Vec::new();
1378+
for arg in partition_by {
1379+
let (sql, query) = Self::generate_sql_for_expr(
1380+
plan.clone(),
1381+
sql_query,
1382+
sql_generator.clone(),
1383+
arg,
1384+
ungrouped_scan_node.clone(),
1385+
)
1386+
.await?;
1387+
sql_query = query;
1388+
sql_partition_by.push(sql);
1389+
}
1390+
let mut sql_order_by = Vec::new();
1391+
for arg in order_by {
1392+
let (sql, query) = Self::generate_sql_for_expr(
1393+
plan.clone(),
1394+
sql_query,
1395+
sql_generator.clone(),
1396+
arg,
1397+
ungrouped_scan_node.clone(),
1398+
)
1399+
.await?;
1400+
sql_query = query;
1401+
sql_order_by.push(
1402+
sql_generator
1403+
.get_sql_templates()
1404+
// TODO asc/desc
1405+
.sort_expr(sql, true, false)
1406+
.map_err(|e| {
1407+
DataFusionError::Internal(format!(
1408+
"Can't generate SQL for sort expr: {}",
1409+
e
1410+
))
1411+
})?,
1412+
);
1413+
}
1414+
let resulting_sql = sql_generator
1415+
.get_sql_templates()
1416+
.window_function_expr(
1417+
fun,
1418+
sql_args,
1419+
sql_partition_by,
1420+
sql_order_by,
1421+
window_frame,
1422+
)
1423+
.map_err(|e| {
1424+
DataFusionError::Internal(format!(
1425+
"Can't generate SQL for window function: {}",
1426+
e
1427+
))
1428+
})?;
1429+
Ok((resulting_sql, sql_query))
1430+
}
13371431
// Expr::AggregateUDF { .. } => {}
13381432
// Expr::InList { .. } => {}
13391433
// Expr::Wildcard => {}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 65 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -18847,12 +18847,20 @@ ORDER BY \"COUNT(count)\" DESC"
1884718847
.sql
1884818848
.contains("CASE WHEN"));
1884918849

18850-
assert!(logical_plan
18851-
.find_cube_scan_wrapper()
18852-
.wrapped_sql
18853-
.unwrap()
18854-
.sql
18855-
.contains("1123"));
18850+
assert!(
18851+
logical_plan
18852+
.find_cube_scan_wrapper()
18853+
.wrapped_sql
18854+
.unwrap()
18855+
.sql
18856+
.contains("1123"),
18857+
"SQL contains 1123: {}",
18858+
logical_plan
18859+
.find_cube_scan_wrapper()
18860+
.wrapped_sql
18861+
.unwrap()
18862+
.sql
18863+
);
1885618864

1885718865
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1885818866
println!(
@@ -18883,12 +18891,20 @@ ORDER BY \"COUNT(count)\" DESC"
1888318891
.sql
1888418892
.contains("CASE WHEN"));
1888518893

18886-
assert!(logical_plan
18887-
.find_cube_scan_wrapper()
18888-
.wrapped_sql
18889-
.unwrap()
18890-
.sql
18891-
.contains("LIMIT 1123"));
18894+
assert!(
18895+
logical_plan
18896+
.find_cube_scan_wrapper()
18897+
.wrapped_sql
18898+
.unwrap()
18899+
.sql
18900+
.contains("1123"),
18901+
"SQL contains 1123: {}",
18902+
logical_plan
18903+
.find_cube_scan_wrapper()
18904+
.wrapped_sql
18905+
.unwrap()
18906+
.sql
18907+
);
1889218908

1889318909
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1889418910
println!(
@@ -19063,6 +19079,43 @@ ORDER BY \"COUNT(count)\" DESC"
1906319079
.contains("EXTRACT"));
1906419080
}
1906519081

19082+
#[tokio::test]
19083+
async fn test_wrapper_window_function() {
19084+
if !Rewriter::sql_push_down_enabled() {
19085+
return;
19086+
}
19087+
init_logger();
19088+
19089+
let query_plan = convert_select_to_query_plan(
19090+
"SELECT customer_gender, AVG(avgPrice) mp, SUM(COUNT(count)) OVER() FROM KibanaSampleDataEcommerce a GROUP BY 1 LIMIT 100"
19091+
.to_string(),
19092+
DatabaseProtocol::PostgreSQL,
19093+
)
19094+
.await;
19095+
19096+
let logical_plan = query_plan.as_logical_plan();
19097+
assert!(
19098+
logical_plan
19099+
.find_cube_scan_wrapper()
19100+
.wrapped_sql
19101+
.unwrap()
19102+
.sql
19103+
.contains("OVER"),
19104+
"SQL should contain 'OVER': {}",
19105+
logical_plan
19106+
.find_cube_scan_wrapper()
19107+
.wrapped_sql
19108+
.unwrap()
19109+
.sql
19110+
);
19111+
19112+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
19113+
println!(
19114+
"Physical plan: {}",
19115+
displayable(physical_plan.as_ref()).indent()
19116+
);
19117+
}
19118+
1906619119
#[tokio::test]
1906719120
async fn test_thoughtspot_pg_date_trunc_year() {
1906819121
init_logger();

rust/cubesql/cubesql/src/compile/rewrite/converter.rs

Lines changed: 37 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -41,8 +41,8 @@ use datafusion::{
4141
logical_plan::{
4242
build_join_schema, build_table_udf_schema, exprlist_to_fields, normalize_cols,
4343
plan::{Aggregate, Extension, Filter, Join, Projection, Sort, TableUDFs, Window},
44-
CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation, Expr, Like, Limit,
45-
LogicalPlan, LogicalPlanBuilder, TableScan, Union,
44+
replace_col_to_expr, CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation,
45+
Expr, Like, Limit, LogicalPlan, LogicalPlanBuilder, TableScan, Union,
4646
},
4747
physical_plan::planner::DefaultPhysicalPlanner,
4848
scalar::ScalarValue,
@@ -1671,8 +1671,10 @@ impl LanguageToLogicalPlanConverter {
16711671
match_expr_list_node!(node_by_id, to_expr, params[2], WrappedSelectGroupExpr);
16721672
let aggr_expr =
16731673
match_expr_list_node!(node_by_id, to_expr, params[3], WrappedSelectAggrExpr);
1674-
let from = Arc::new(self.to_logical_plan(params[4])?);
1675-
let joins = match_list_node!(node_by_id, params[5], WrappedSelectJoins)
1674+
let window_expr =
1675+
match_expr_list_node!(node_by_id, to_expr, params[4], WrappedSelectWindowExpr);
1676+
let from = Arc::new(self.to_logical_plan(params[5])?);
1677+
let joins = match_list_node!(node_by_id, params[6], WrappedSelectJoins)
16761678
.into_iter()
16771679
.map(|j| {
16781680
if let LogicalPlanLanguage::WrappedSelectJoin(params) = j {
@@ -1688,28 +1690,49 @@ impl LanguageToLogicalPlanConverter {
16881690
.collect::<Result<Vec<_>, _>>()?;
16891691

16901692
let filter_expr =
1691-
match_expr_list_node!(node_by_id, to_expr, params[6], WrappedSelectFilterExpr);
1693+
match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectFilterExpr);
16921694
let having_expr =
1693-
match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectHavingExpr);
1694-
let limit = match_data_node!(node_by_id, params[8], WrappedSelectLimit);
1695-
let offset = match_data_node!(node_by_id, params[9], WrappedSelectOffset);
1695+
match_expr_list_node!(node_by_id, to_expr, params[8], WrappedSelectHavingExpr);
1696+
let limit = match_data_node!(node_by_id, params[9], WrappedSelectLimit);
1697+
let offset = match_data_node!(node_by_id, params[10], WrappedSelectOffset);
16961698
let order_expr =
1697-
match_expr_list_node!(node_by_id, to_expr, params[10], WrappedSelectOrderExpr);
1698-
let alias = match_data_node!(node_by_id, params[11], WrappedSelectAlias);
1699-
let ungrouped = match_data_node!(node_by_id, params[12], WrappedSelectUngrouped);
1699+
match_expr_list_node!(node_by_id, to_expr, params[11], WrappedSelectOrderExpr);
1700+
let alias = match_data_node!(node_by_id, params[12], WrappedSelectAlias);
1701+
let ungrouped = match_data_node!(node_by_id, params[13], WrappedSelectUngrouped);
17001702

17011703
let group_expr = normalize_cols(group_expr, &from)?;
17021704
let aggr_expr = normalize_cols(aggr_expr, &from)?;
17031705
let projection_expr = normalize_cols(projection_expr, &from)?;
1704-
let all_expr = match select_type {
1706+
let all_expr_without_window = match select_type {
17051707
WrappedSelectType::Projection => projection_expr.clone(),
17061708
WrappedSelectType::Aggregate => {
17071709
group_expr.iter().chain(aggr_expr.iter()).cloned().collect()
17081710
}
17091711
};
1712+
let without_window_fields =
1713+
exprlist_to_fields(all_expr_without_window.iter(), from.schema())?;
1714+
let replace_map = all_expr_without_window
1715+
.iter()
1716+
.zip(without_window_fields.iter())
1717+
.map(|(e, f)| (f.qualified_column(), e.clone()))
1718+
.collect::<Vec<_>>();
1719+
let replace_map = replace_map
1720+
.iter()
1721+
.map(|(c, e)| (c, e))
1722+
.collect::<HashMap<_, _>>();
1723+
let window_expr_rebased = window_expr
1724+
.iter()
1725+
.map(|e| replace_col_to_expr(e.clone(), &replace_map))
1726+
.collect::<Result<Vec<_>, _>>()?;
17101727
let schema = DFSchema::new_with_metadata(
17111728
// TODO support joins schema
1712-
exprlist_to_fields(all_expr.iter(), from.schema())?,
1729+
without_window_fields
1730+
.into_iter()
1731+
.chain(
1732+
exprlist_to_fields(window_expr_rebased.iter(), from.schema())?
1733+
.into_iter(),
1734+
)
1735+
.collect(),
17131736
HashMap::new(),
17141737
)?;
17151738

@@ -1725,6 +1748,7 @@ impl LanguageToLogicalPlanConverter {
17251748
projection_expr,
17261749
group_expr,
17271750
aggr_expr,
1751+
window_expr_rebased,
17281752
from,
17291753
joins,
17301754
filter_expr,

0 commit comments

Comments
 (0)