Skip to content

Commit c36d156

Browse files
committed
[WIP] feat(cubesql): Support PatchMeasure in SQL pushdown
1 parent 1e6264f commit c36d156

File tree

5 files changed

+330
-1
lines changed

5 files changed

+330
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
---
2+
source: cubesql/src/compile/engine/df/wrapper.rs
3+
expression: "UngroupedMemberDef\n{\n cube_name: \"cube\".to_string(), alias: \"alias\".to_string(), expr:\n UngroupedMemberExpr::PatchMeasure(PatchMeasureDef\n {\n source_measure: \"cube.measure\".to_string(), replace_aggregation_type:\n None, add_filters:\n vec![SqlFunctionExpr\n {\n cube_params: vec![\"cube\".to_string()], sql:\n \"1 + 2 = 3\".to_string(),\n }],\n }), grouping_set: None,\n}"
4+
---
5+
{
6+
"cubeName": "cube",
7+
"alias": "alias",
8+
"expr": {
9+
"type": "PatchMeasure",
10+
"sourceMeasure": "cube.measure",
11+
"replaceAggregationType": null,
12+
"addFilters": [
13+
{
14+
"cubeParams": [
15+
"cube"
16+
],
17+
"sql": "1 + 2 = 3"
18+
}
19+
]
20+
},
21+
"groupingSet": null
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
---
2+
source: cubesql/src/compile/engine/df/wrapper.rs
3+
expression: "UngroupedMemberDef\n{\n cube_name: \"cube\".to_string(), alias: \"alias\".to_string(), expr:\n UngroupedMemberExpr::SqlFunction(SqlFunctionExpr\n {\n cube_params: vec![\"cube\".to_string(), \"other\".to_string()], sql:\n \"1 + 2\".to_string(),\n }), grouping_set: None,\n}"
4+
---
5+
{
6+
"cubeName": "cube",
7+
"alias": "alias",
8+
"expr": {
9+
"type": "SqlFunction",
10+
"cubeParams": [
11+
"cube",
12+
"other"
13+
],
14+
"sql": "1 + 2"
15+
},
16+
"groupingSet": null
17+
}

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 257 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
compile::{
33
engine::{
44
df::scan::{CubeScanNode, DataType, MemberField, WrappedSelectNode},
5-
udf::MEASURE_UDAF_NAME,
5+
udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME},
66
},
77
rewrite::{
88
extract_exprlist_from_groupping_set,
@@ -75,10 +75,21 @@ struct SqlFunctionExpr {
7575
sql: String,
7676
}
7777

78+
#[derive(Debug, Clone, Serialize)]
79+
struct PatchMeasureDef {
80+
#[serde(rename = "sourceMeasure")]
81+
source_measure: String,
82+
#[serde(rename = "replaceAggregationType")]
83+
replace_aggregation_type: Option<String>,
84+
#[serde(rename = "addFilters")]
85+
add_filters: Vec<SqlFunctionExpr>,
86+
}
87+
7888
/// Expression payload of an ungrouped member definition.
///
/// Serialized as an internally tagged object: the variant name is written to the `type`
/// field, next to the fields of the variant's own payload.
#[derive(Debug, Clone, Serialize)]
7989
#[serde(tag = "type")]
8090
enum UngroupedMemberExpr {
8191
/// Member described by a SQL snippet together with the cubes it takes as parameters.
SqlFunction(SqlFunctionExpr),
92+
/// Member described as a patch (aggregation replacement and/or added filters) of an
/// existing measure.
PatchMeasure(PatchMeasureDef),
8293
}
8394

8495
#[derive(Debug, Clone, Serialize)]
@@ -1151,6 +1162,20 @@ impl CubeScanWrapperNode {
11511162
)
11521163
.await?;
11531164
let group_descs = extract_group_type_from_groupping_set(&group_expr)?;
1165+
1166+
let (patch_measures, aggr_expr, sql) = Self::extract_patch_measures(
1167+
schema.as_ref(),
1168+
aggr_expr,
1169+
sql,
1170+
generator.clone(),
1171+
column_remapping,
1172+
&mut next_remapper,
1173+
can_rename_columns,
1174+
push_to_cube_context,
1175+
subqueries_sql.clone(),
1176+
)
1177+
.await?;
1178+
11541179
let (aggregate, sql) = Self::generate_column_expr(
11551180
schema.clone(),
11561181
aggr_expr.clone(),
@@ -1296,6 +1321,11 @@ impl CubeScanWrapperNode {
12961321
&ungrouped_scan_node.used_cubes,
12971322
)
12981323
}))
1324+
.chain(patch_measures.into_iter().map(
1325+
|(def, cube, alias)| {
1326+
Self::patch_measure_expr(def, cube, alias)
1327+
},
1328+
))
12991329
.collect::<Result<_>>()?,
13001330
),
13011331
dimensions: Some(
@@ -1340,6 +1370,7 @@ impl CubeScanWrapperNode {
13401370
.map(|(_e, c)| c.clone())
13411371
};
13421372

1373+
// TODO handle patch measures collection here
13431374
let aliased_column = find_column(&aggr_expr, &aggregate)
13441375
.or_else(|| find_column(&projection_expr, &projection))
13451376
.or_else(|| find_column(&flat_group_expr, &group_by))
@@ -1420,6 +1451,12 @@ impl CubeScanWrapperNode {
14201451
request: load_request.clone(),
14211452
})
14221453
} else {
1454+
if !patch_measures.is_empty() {
1455+
return Err(CubeError::internal(format!(
1456+
"Unexpected patch measures for non-push-to-Cube wrapped select: {patch_measures:?}",
1457+
)));
1458+
}
1459+
14231460
let resulting_sql = generator
14241461
.get_sql_templates()
14251462
.select(
@@ -1492,6 +1529,174 @@ impl CubeScanWrapperNode {
14921529
})
14931530
}
14941531

1532+
fn get_patch_measure<'l>(
1533+
sql_query: SqlQuery,
1534+
sql_generator: Arc<dyn SqlGenerator>,
1535+
expr: &'l Expr,
1536+
push_to_cube_context: Option<&'l PushToCubeContext<'_>>,
1537+
subqueries: Arc<HashMap<String, String>>,
1538+
) -> Pin<
1539+
Box<
1540+
dyn Future<
1541+
Output = result::Result<
1542+
(Option<(PatchMeasureDef, String)>, SqlQuery),
1543+
CubeError,
1544+
>,
1545+
> + Send
1546+
+ 'l,
1547+
>,
1548+
> {
1549+
Box::pin(async move {
1550+
match expr {
1551+
Expr::Alias(inner, _alias) => {
1552+
Self::get_patch_measure(
1553+
sql_query,
1554+
sql_generator,
1555+
inner,
1556+
push_to_cube_context,
1557+
subqueries,
1558+
)
1559+
.await
1560+
}
1561+
Expr::AggregateUDF { fun, args } => {
1562+
if fun.name != PATCH_MEASURE_UDAF_NAME {
1563+
return Ok((None, sql_query));
1564+
}
1565+
1566+
let Some(PushToCubeContext {
1567+
ungrouped_scan_node,
1568+
..
1569+
}) = push_to_cube_context
1570+
else {
1571+
return Err(CubeError::internal(format!(
1572+
"Unexpected UDAF expression without push-to-Cube context: {}",
1573+
fun.name
1574+
)));
1575+
};
1576+
1577+
let (measure, aggregation, filter) = match args.as_slice() {
1578+
[measure, aggregation, filter] => (measure, aggregation, filter),
1579+
_ => {
1580+
return Err(CubeError::internal(format!(
1581+
"Unexpected number arguments for UDAF: {}, {args:?}",
1582+
fun.name
1583+
)))
1584+
}
1585+
};
1586+
1587+
let Expr::Column(measure_column) = measure else {
1588+
return Err(CubeError::internal(format!(
1589+
"First argument should be column expression: {}",
1590+
fun.name
1591+
)));
1592+
};
1593+
1594+
let aggregation = match aggregation {
1595+
Expr::Literal(ScalarValue::Utf8(Some(aggregation))) => Some(aggregation),
1596+
Expr::Literal(ScalarValue::Null) => None,
1597+
_ => {
1598+
return Err(CubeError::internal(format!(
1599+
"Second argument should be Utf8 literal expression: {}",
1600+
fun.name
1601+
)));
1602+
}
1603+
};
1604+
1605+
let (filters, sql_query) = match filter {
1606+
Expr::Literal(ScalarValue::Null) => (vec![], sql_query),
1607+
_ => {
1608+
let (filter, sql_query) = Self::generate_sql_for_expr(
1609+
sql_query,
1610+
sql_generator.clone(),
1611+
filter.clone(),
1612+
push_to_cube_context,
1613+
subqueries.clone(),
1614+
)
1615+
.await?;
1616+
(
1617+
vec![SqlFunctionExpr {
1618+
cube_params: ungrouped_scan_node.used_cubes.clone(),
1619+
sql: filter,
1620+
}],
1621+
sql_query,
1622+
)
1623+
}
1624+
};
1625+
1626+
let member =
1627+
Self::find_member_in_ungrouped_scan(ungrouped_scan_node, measure_column)?;
1628+
1629+
let MemberField::Member(member) = member else {
1630+
return Err(CubeError::internal(format!(
1631+
"First argument should reference member, not literal: {}",
1632+
fun.name
1633+
)));
1634+
};
1635+
1636+
let (cube, _member) = member.split_once('.').ok_or_else(|| {
1637+
CubeError::internal(format!("Can't parse cube name from member {member}",))
1638+
})?;
1639+
1640+
Ok((
1641+
Some((
1642+
PatchMeasureDef {
1643+
source_measure: member.clone(),
1644+
replace_aggregation_type: aggregation.cloned(),
1645+
add_filters: filters,
1646+
},
1647+
cube.to_string(),
1648+
)),
1649+
sql_query,
1650+
))
1651+
}
1652+
_ => Ok((None, sql_query)),
1653+
}
1654+
})
1655+
}
1656+
1657+
async fn extract_patch_measures(
1658+
schema: &DFSchema,
1659+
exprs: impl IntoIterator<Item = Expr>,
1660+
mut sql_query: SqlQuery,
1661+
sql_generator: Arc<dyn SqlGenerator>,
1662+
column_remapping: Option<&ColumnRemapping>,
1663+
next_remapper: &mut Remapper,
1664+
can_rename_columns: bool,
1665+
push_to_cube_context: Option<&PushToCubeContext<'_>>,
1666+
subqueries: Arc<HashMap<String, String>>,
1667+
) -> result::Result<(Vec<(PatchMeasureDef, String, String)>, Vec<Expr>, SqlQuery), CubeError>
1668+
{
1669+
let mut patches = vec![];
1670+
let mut other = vec![];
1671+
1672+
for original_expr in exprs {
1673+
let (expr, alias) = Self::remap_column_expression(
1674+
schema,
1675+
&original_expr,
1676+
column_remapping,
1677+
next_remapper,
1678+
can_rename_columns,
1679+
)?;
1680+
1681+
let (patch_def, sql_query_next) = Self::get_patch_measure(
1682+
sql_query,
1683+
sql_generator.clone(),
1684+
&expr,
1685+
push_to_cube_context,
1686+
subqueries.clone(),
1687+
)
1688+
.await?;
1689+
sql_query = sql_query_next;
1690+
if let Some((patch_def, cube)) = patch_def {
1691+
patches.push((patch_def, cube, alias));
1692+
} else {
1693+
other.push(expr);
1694+
}
1695+
}
1696+
1697+
Ok((patches, other, sql_query))
1698+
}
1699+
14951700
fn remap_column_expression(
14961701
schema: &DFSchema,
14971702
original_expr: &Expr,
@@ -1597,6 +1802,21 @@ impl CubeScanWrapperNode {
15971802
Ok(serde_json::json!(res).to_string())
15981803
}
15991804

1805+
fn patch_measure_expr(
1806+
def: PatchMeasureDef,
1807+
cube_name: String,
1808+
alias: String,
1809+
) -> Result<String> {
1810+
let res = UngroupedMemberDef {
1811+
cube_name,
1812+
alias,
1813+
expr: UngroupedMemberExpr::PatchMeasure(def),
1814+
grouping_set: None,
1815+
};
1816+
1817+
Ok(serde_json::json!(res).to_string())
1818+
}
1819+
16001820
fn generate_sql_cast_expr(
16011821
sql_generator: Arc<dyn SqlGenerator>,
16021822
inner_expr: String,
@@ -2590,6 +2810,7 @@ impl CubeScanWrapperNode {
25902810

25912811
Ok((format!("${{{member}}}"), sql_query))
25922812
}
2813+
// There's no branch for PatchMeasure here, because it is generated via a different path
25932814
_ => Err(DataFusionError::Internal(format!(
25942815
"Can't generate SQL for UDAF: {}",
25952816
fun.name
@@ -2760,3 +2981,38 @@ impl UserDefinedLogicalNode for CubeScanWrapperNode {
27602981
})
27612982
}
27622983
}
2984+
2985+
// Snapshot tests pinning the JSON shape produced when serializing `UngroupedMemberDef`
// for each `UngroupedMemberExpr` variant.
#[cfg(test)]
2986+
mod tests {
2987+
use super::*;
2988+
2989+
// `SqlFunction` variant: expects `"type": "SqlFunction"` next to `cubeParams` and `sql`.
#[test]
2990+
fn test_member_expression_sql() {
2991+
insta::assert_json_snapshot!(UngroupedMemberDef {
2992+
cube_name: "cube".to_string(),
2993+
alias: "alias".to_string(),
2994+
expr: UngroupedMemberExpr::SqlFunction(SqlFunctionExpr {
2995+
cube_params: vec!["cube".to_string(), "other".to_string()],
2996+
sql: "1 + 2".to_string(),
2997+
}),
2998+
grouping_set: None,
2999+
});
3000+
}
3001+
3002+
// `PatchMeasure` variant: no aggregation replacement (serialized as `null`) and a single
// added filter.
#[test]
3003+
fn test_member_expression_patch_measure() {
3004+
insta::assert_json_snapshot!(UngroupedMemberDef {
3005+
cube_name: "cube".to_string(),
3006+
alias: "alias".to_string(),
3007+
expr: UngroupedMemberExpr::PatchMeasure(PatchMeasureDef {
3008+
source_measure: "cube.measure".to_string(),
3009+
replace_aggregation_type: None,
3010+
add_filters: vec![SqlFunctionExpr {
3011+
cube_params: vec!["cube".to_string()],
3012+
sql: "1 + 2 = 3".to_string(),
3013+
}],
3014+
}),
3015+
grouping_set: None,
3016+
});
3017+
}
3018+
}

0 commit comments

Comments
 (0)