Skip to content

Commit 0f79ca3

Browse files
committed
feat(cubesql): Rewrite incorrect aggregation function on measures under wrapper
1 parent 5c5915c commit 0f79ca3

File tree

4 files changed

+206
-18
lines changed

4 files changed

+206
-18
lines changed

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2783,17 +2783,14 @@ limit
27832783

27842784
#[tokio::test]
27852785
async fn test_select_error() {
2786-
let variants = [
2787-
(
2788-
"SELECT AVG(maxPrice) FROM KibanaSampleDataEcommerce".to_string(),
2789-
CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'maxPrice' is 'MAX()' but 'AVG()' was provided. Please check logs for additional information.".to_string()),
2790-
),
2786+
let variants: &[(&str, _)] = &[
2787+
// TODO are there any errors that we could test for?
27912788
];
27922789

2793-
for (input_query, expected_error) in variants.iter() {
2790+
for (input_query, expected_error) in variants {
27942791
let meta = get_test_tenant_ctx();
27952792
let query = convert_sql_to_cube_query(
2796-
&input_query,
2793+
input_query,
27972794
meta.clone(),
27982795
get_test_session(DatabaseProtocol::PostgreSQL, meta).await,
27992796
)

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
compile::{
3-
engine::udf::MEASURE_UDAF_NAME,
3+
engine::udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME},
44
rewrite::{
55
aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
66
rewrite,
@@ -14,15 +14,16 @@ use crate::{
1414
wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer,
1515
wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct,
1616
AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn,
17-
ListType, LogicalPlanData, LogicalPlanLanguage, WrappedSelectPushToCube,
18-
WrapperReplacerContextAliasToCube, WrapperReplacerContextPushToCube,
17+
ListType, LiteralExprValue, LogicalPlanData, LogicalPlanLanguage,
18+
WrappedSelectPushToCube, WrapperReplacerContextAliasToCube,
19+
WrapperReplacerContextPushToCube,
1920
},
2021
},
2122
copy_flag,
2223
transport::V1CubeMetaMeasureExt,
2324
var, var_iter,
2425
};
25-
use datafusion::logical_plan::Column;
26+
use datafusion::{logical_plan::Column, scalar::ScalarValue};
2627
use egg::{Subst, Var};
2728
use std::ops::IndexMut;
2829

@@ -885,15 +886,17 @@ impl WrapperRules {
885886
)
886887
{
887888
if let Some(measure) = meta.find_measure_with_name(member) {
888-
if call_agg_type.is_none()
889-
|| measure.is_same_agg_type(
890-
call_agg_type.as_ref().unwrap(),
891-
disable_strict_agg_type_match,
892-
)
893-
{
889+
fn insert_regular_measure(
890+
egraph: &mut CubeEGraph,
891+
subst: &mut Subst,
892+
column: Column,
893+
alias: String,
894+
out_expr_var: Var,
895+
out_alias_var: Var,
896+
) {
894897
let column_expr_column =
895898
egraph.add(LogicalPlanLanguage::ColumnExprColumn(
896-
ColumnExprColumn(column.clone()),
899+
ColumnExprColumn(column),
897900
));
898901
let column_expr =
899902
egraph.add(LogicalPlanLanguage::ColumnExpr([
@@ -918,11 +921,119 @@ impl WrapperRules {
918921

919922
subst.insert(out_expr_var, udaf_expr);
920923

924+
let alias_expr_alias =
925+
egraph.add(LogicalPlanLanguage::AliasExprAlias(
926+
AliasExprAlias(alias),
927+
));
928+
subst.insert(out_alias_var, alias_expr_alias);
929+
}
930+
931+
fn insert_patch_measure(
932+
egraph: &mut CubeEGraph,
933+
subst: &mut Subst,
934+
column: Column,
935+
call_agg_type: String,
936+
alias: String,
937+
out_expr_var: Var,
938+
out_alias_var: Var,
939+
) {
940+
let column_expr_column =
941+
egraph.add(LogicalPlanLanguage::ColumnExprColumn(
942+
ColumnExprColumn(column.clone()),
943+
));
944+
let column_expr =
945+
egraph.add(LogicalPlanLanguage::ColumnExpr([
946+
column_expr_column,
947+
]));
948+
let new_aggregation_value =
949+
egraph.add(LogicalPlanLanguage::LiteralExprValue(
950+
LiteralExprValue(ScalarValue::Utf8(Some(
951+
call_agg_type,
952+
))),
953+
));
954+
let new_aggregation_expr =
955+
egraph.add(LogicalPlanLanguage::LiteralExpr([
956+
new_aggregation_value,
957+
]));
958+
let add_filters_value =
959+
egraph.add(LogicalPlanLanguage::LiteralExprValue(
960+
LiteralExprValue(ScalarValue::Null),
961+
));
962+
let add_filters_expr =
963+
egraph.add(LogicalPlanLanguage::LiteralExpr([
964+
add_filters_value,
965+
]));
966+
let udaf_name_expr = egraph.add(
967+
LogicalPlanLanguage::AggregateUDFExprFun(
968+
AggregateUDFExprFun(
969+
PATCH_MEASURE_UDAF_NAME.to_string(),
970+
),
971+
),
972+
);
973+
let udaf_args_expr = egraph.add(
974+
LogicalPlanLanguage::AggregateUDFExprArgs(vec![
975+
column_expr,
976+
new_aggregation_expr,
977+
add_filters_expr,
978+
]),
979+
);
980+
let udaf_expr =
981+
egraph.add(LogicalPlanLanguage::AggregateUDFExpr(
982+
[udaf_name_expr, udaf_args_expr],
983+
));
984+
985+
subst.insert(out_expr_var, udaf_expr);
986+
921987
let alias_expr_alias =
922988
egraph.add(LogicalPlanLanguage::AliasExprAlias(
923989
AliasExprAlias(alias.clone()),
924990
));
925991
subst.insert(out_alias_var, alias_expr_alias);
992+
}
993+
994+
let Some(call_agg_type) = &call_agg_type else {
995+
// call_agg_type is None, rewrite as is
996+
insert_regular_measure(
997+
egraph,
998+
subst,
999+
column,
1000+
alias,
1001+
out_expr_var,
1002+
out_alias_var,
1003+
);
1004+
1005+
return true;
1006+
};
1007+
1008+
if measure.is_same_agg_type(
1009+
call_agg_type,
1010+
disable_strict_agg_type_match,
1011+
) {
1012+
insert_regular_measure(
1013+
egraph,
1014+
subst,
1015+
column,
1016+
alias,
1017+
out_expr_var,
1018+
out_alias_var,
1019+
);
1020+
1021+
return true;
1022+
}
1023+
1024+
if measure.allow_replace_agg_type(
1025+
call_agg_type,
1026+
disable_strict_agg_type_match,
1027+
) {
1028+
insert_patch_measure(
1029+
egraph,
1030+
subst,
1031+
column,
1032+
call_agg_type.clone(),
1033+
alias,
1034+
out_expr_var,
1035+
out_alias_var,
1036+
);
9261037

9271038
return true;
9281039
}

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,3 +1748,56 @@ async fn wrapper_dimension_agg_where_false() {
17481748
assert!(!sql.contains(r#""limit""#));
17491749
assert!(sql.contains("LIMIT 50000"));
17501750
}
1751+
1752+
/// MIN(avg_measure) should get pushed to Cube with replaced measure
1753+
#[tokio::test]
1754+
async fn wrapper_min_from_avg_measure() {
1755+
if !Rewriter::sql_push_down_enabled() {
1756+
return;
1757+
}
1758+
init_testing_logger();
1759+
1760+
let query_plan = convert_select_to_query_plan(
1761+
// language=PostgreSQL
1762+
r#"
1763+
SELECT
1764+
MIN(avgPrice)
1765+
FROM
1766+
MultiTypeCube
1767+
"#
1768+
.to_string(),
1769+
DatabaseProtocol::PostgreSQL,
1770+
)
1771+
.await;
1772+
1773+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1774+
println!(
1775+
"Physical plan: {}",
1776+
displayable(physical_plan.as_ref()).indent()
1777+
);
1778+
1779+
assert_eq!(
1780+
query_plan
1781+
.as_logical_plan()
1782+
.find_cube_scan_wrapped_sql()
1783+
.request,
1784+
TransportLoadRequestQuery {
1785+
measures: Some(vec![json!({
1786+
"cubeName": "MultiTypeCube",
1787+
"alias": "min_multitypecub",
1788+
"expr": {
1789+
"type": "PatchMeasure",
1790+
"sourceMeasure": "MultiTypeCube.avgPrice",
1791+
"replaceAggregationType": "min",
1792+
"addFilters": [],
1793+
},
1794+
"groupingSet": null,
1795+
})
1796+
.to_string(),]),
1797+
dimensions: Some(vec![]),
1798+
segments: Some(vec![]),
1799+
order: Some(vec![]),
1800+
..Default::default()
1801+
}
1802+
);
1803+
}

rust/cubesql/cubesql/src/transport/ext.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ pub trait V1CubeMetaMeasureExt {
1010

1111
fn is_same_agg_type(&self, expect_agg_type: &str, disable_strict_match: bool) -> bool;
1212

13+
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool;
14+
1315
fn get_sql_type(&self) -> ColumnType;
1416
}
1517

@@ -45,6 +47,31 @@ impl V1CubeMetaMeasureExt for CubeMetaMeasure {
4547
}
4648
}
4749

50+
// This should be aligned with BaseMeasure.preparePatchedMeasure
51+
// See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16
52+
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool {
53+
if disable_strict_match {
54+
return true;
55+
}
56+
let Some(agg_type) = &self.agg_type else {
57+
return false;
58+
};
59+
60+
match (agg_type.as_str(), query_agg_type) {
61+
(
62+
"sum" | "avg" | "min" | "max",
63+
"sum" | "avg" | "min" | "max" | "count_distinct" | "count_distinct_approx",
64+
) => true,
65+
66+
(
67+
"count_distinct" | "count_distinct_approx",
68+
"count_distinct" | "count_distinct_approx",
69+
) => true,
70+
71+
_ => false,
72+
}
73+
}
74+
4875
fn get_sql_type(&self) -> ColumnType {
4976
let from_type = match &self.r#type.to_lowercase().as_str() {
5077
&"number" => ColumnType::Double,

0 commit comments

Comments
 (0)