Skip to content

Commit 1bd57fd

Browse files
committed
[WIP rename insert_replaced_measure] feat(cubesql): Rewrite incorrect aggregation function on measures under wrapper
1 parent c36d156 commit 1bd57fd

File tree

4 files changed

+203
-15
lines changed

4 files changed

+203
-15
lines changed

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2781,10 +2781,7 @@ limit
27812781
#[tokio::test]
27822782
async fn test_select_error() {
27832783
let variants = vec![
2784-
(
2785-
"SELECT AVG(maxPrice) FROM KibanaSampleDataEcommerce".to_string(),
2786-
CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'maxPrice' is 'MAX()' but 'AVG()' was provided. Please check logs for additional information.".to_string()),
2787-
),
2784+
// TODO are there any errors that we could test for?
27882785
];
27892786

27902787
for (input_query, expected_error) in variants.iter() {

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 122 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
compile::{
3-
engine::udf::MEASURE_UDAF_NAME,
3+
engine::udf::{MEASURE_UDAF_NAME, PATCH_MEASURE_UDAF_NAME},
44
rewrite::{
55
aggregate, alias_expr, cube_scan_wrapper, grouping_set_expr, original_expr_name,
66
rewrite,
@@ -14,15 +14,16 @@ use crate::{
1414
wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer,
1515
wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct,
1616
AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn,
17-
ListType, LogicalPlanData, LogicalPlanLanguage, WrappedSelectPushToCube,
18-
WrapperReplacerContextAliasToCube, WrapperReplacerContextPushToCube,
17+
ListType, LiteralExprValue, LogicalPlanData, LogicalPlanLanguage,
18+
WrappedSelectPushToCube, WrapperReplacerContextAliasToCube,
19+
WrapperReplacerContextPushToCube,
1920
},
2021
},
2122
copy_flag,
2223
transport::V1CubeMetaMeasureExt,
2324
var, var_iter,
2425
};
25-
use datafusion::logical_plan::Column;
26+
use datafusion::{logical_plan::Column, scalar::ScalarValue};
2627
use egg::{Subst, Var};
2728
use std::ops::IndexMut;
2829

@@ -887,15 +888,17 @@ impl WrapperRules {
887888
if let Some(measure) =
888889
meta.find_measure_with_name(member.to_string())
889890
{
890-
if call_agg_type.is_none()
891-
|| measure.is_same_agg_type(
892-
call_agg_type.as_ref().unwrap(),
893-
disable_strict_agg_type_match,
894-
)
895-
{
891+
fn insert_regular_measure(
892+
egraph: &mut CubeEGraph,
893+
subst: &mut Subst,
894+
column: Column,
895+
alias: String,
896+
out_expr_var: Var,
897+
out_alias_var: Var,
898+
) {
896899
let column_expr_column =
897900
egraph.add(LogicalPlanLanguage::ColumnExprColumn(
898-
ColumnExprColumn(column.clone()),
901+
ColumnExprColumn(column),
899902
));
900903
let column_expr =
901904
egraph.add(LogicalPlanLanguage::ColumnExpr([
@@ -920,11 +923,119 @@ impl WrapperRules {
920923

921924
subst.insert(out_expr_var, udaf_expr);
922925

926+
let alias_expr_alias =
927+
egraph.add(LogicalPlanLanguage::AliasExprAlias(
928+
AliasExprAlias(alias),
929+
));
930+
subst.insert(out_alias_var, alias_expr_alias);
931+
}
932+
933+
fn insert_replaced_measure(
934+
egraph: &mut CubeEGraph,
935+
subst: &mut Subst,
936+
column: Column,
937+
call_agg_type: String,
938+
alias: String,
939+
out_expr_var: Var,
940+
out_alias_var: Var,
941+
) {
942+
let column_expr_column =
943+
egraph.add(LogicalPlanLanguage::ColumnExprColumn(
944+
ColumnExprColumn(column.clone()),
945+
));
946+
let column_expr =
947+
egraph.add(LogicalPlanLanguage::ColumnExpr([
948+
column_expr_column,
949+
]));
950+
let new_aggregation_value =
951+
egraph.add(LogicalPlanLanguage::LiteralExprValue(
952+
LiteralExprValue(ScalarValue::Utf8(Some(
953+
call_agg_type,
954+
))),
955+
));
956+
let new_aggregation_expr =
957+
egraph.add(LogicalPlanLanguage::LiteralExpr([
958+
new_aggregation_value,
959+
]));
960+
let add_filters_value =
961+
egraph.add(LogicalPlanLanguage::LiteralExprValue(
962+
LiteralExprValue(ScalarValue::Null),
963+
));
964+
let add_filters_expr =
965+
egraph.add(LogicalPlanLanguage::LiteralExpr([
966+
add_filters_value,
967+
]));
968+
let udaf_name_expr = egraph.add(
969+
LogicalPlanLanguage::AggregateUDFExprFun(
970+
AggregateUDFExprFun(
971+
PATCH_MEASURE_UDAF_NAME.to_string(),
972+
),
973+
),
974+
);
975+
let udaf_args_expr = egraph.add(
976+
LogicalPlanLanguage::AggregateUDFExprArgs(vec![
977+
column_expr,
978+
new_aggregation_expr,
979+
add_filters_expr,
980+
]),
981+
);
982+
let udaf_expr =
983+
egraph.add(LogicalPlanLanguage::AggregateUDFExpr(
984+
[udaf_name_expr, udaf_args_expr],
985+
));
986+
987+
subst.insert(out_expr_var, udaf_expr);
988+
923989
let alias_expr_alias =
924990
egraph.add(LogicalPlanLanguage::AliasExprAlias(
925991
AliasExprAlias(alias.clone()),
926992
));
927993
subst.insert(out_alias_var, alias_expr_alias);
994+
}
995+
996+
let Some(call_agg_type) = &call_agg_type else {
997+
// call_agg_type is None, rewrite as is
998+
insert_regular_measure(
999+
egraph,
1000+
subst,
1001+
column,
1002+
alias,
1003+
out_expr_var,
1004+
out_alias_var,
1005+
);
1006+
1007+
return true;
1008+
};
1009+
1010+
if measure.is_same_agg_type(
1011+
call_agg_type,
1012+
disable_strict_agg_type_match,
1013+
) {
1014+
insert_regular_measure(
1015+
egraph,
1016+
subst,
1017+
column,
1018+
alias,
1019+
out_expr_var,
1020+
out_alias_var,
1021+
);
1022+
1023+
return true;
1024+
}
1025+
1026+
if measure.allow_replace_agg_type(
1027+
call_agg_type,
1028+
disable_strict_agg_type_match,
1029+
) {
1030+
insert_replaced_measure(
1031+
egraph,
1032+
subst,
1033+
column,
1034+
call_agg_type.clone(),
1035+
alias,
1036+
out_expr_var,
1037+
out_alias_var,
1038+
);
9281039

9291040
return true;
9301041
}

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,3 +1465,56 @@ WHERE
14651465
assert_eq!(request.measures.unwrap().len(), 1);
14661466
assert_eq!(request.dimensions.unwrap().len(), 0);
14671467
}
1468+
1469+
/// MIN(avg_measure) should get pushed to Cube with replaced measure
1470+
#[tokio::test]
1471+
async fn wrapper_min_from_avg_measure() {
1472+
if !Rewriter::sql_push_down_enabled() {
1473+
return;
1474+
}
1475+
init_testing_logger();
1476+
1477+
let query_plan = convert_select_to_query_plan(
1478+
// language=PostgreSQL
1479+
r#"
1480+
SELECT
1481+
MIN(avgPrice)
1482+
FROM
1483+
MultiTypeCube
1484+
"#
1485+
.to_string(),
1486+
DatabaseProtocol::PostgreSQL,
1487+
)
1488+
.await;
1489+
1490+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1491+
println!(
1492+
"Physical plan: {}",
1493+
displayable(physical_plan.as_ref()).indent()
1494+
);
1495+
1496+
assert_eq!(
1497+
query_plan
1498+
.as_logical_plan()
1499+
.find_cube_scan_wrapped_sql()
1500+
.request,
1501+
TransportLoadRequestQuery {
1502+
measures: Some(vec![json!({
1503+
"cubeName": "MultiTypeCube",
1504+
"alias": "min_multitypecub",
1505+
"expr": {
1506+
"type": "PatchMeasure",
1507+
"sourceMeasure": "MultiTypeCube.avgPrice",
1508+
"replaceAggregationType": "min",
1509+
"addFilters": [],
1510+
},
1511+
"groupingSet": null,
1512+
})
1513+
.to_string(),]),
1514+
dimensions: Some(vec![]),
1515+
segments: Some(vec![]),
1516+
order: Some(vec![]),
1517+
..Default::default()
1518+
}
1519+
);
1520+
}

rust/cubesql/cubesql/src/transport/ext.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ pub trait V1CubeMetaMeasureExt {
1010

1111
fn is_same_agg_type(&self, expect_agg_type: &str, disable_strict_match: bool) -> bool;
1212

13+
/// Returns `true` when this measure's own aggregation type may be replaced with
/// `query_agg_type`; `disable_strict_match` relaxes the check.
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool;
14+
1315
fn get_sql_type(&self) -> ColumnType;
1416
}
1517

@@ -45,6 +47,31 @@ impl V1CubeMetaMeasureExt for CubeMetaMeasure {
4547
}
4648
}
4749

50+
// This should be aligned with BaseMeasure.preparePatchedMeasure
51+
// See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16
52+
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool {
53+
if disable_strict_match {
54+
return true;
55+
}
56+
let Some(agg_type) = &self.agg_type else {
57+
return false;
58+
};
59+
60+
match (agg_type.as_str(), query_agg_type) {
61+
(
62+
"sum" | "avg" | "min" | "max",
63+
"sum" | "avg" | "min" | "max" | "count_distinct" | "count_distinct_approx",
64+
) => true,
65+
66+
(
67+
"count_distinct" | "count_distinct_approx",
68+
"count_distinct" | "count_distinct_approx",
69+
) => true,
70+
71+
_ => false,
72+
}
73+
}
74+
4875
fn get_sql_type(&self) -> ColumnType {
4976
let from_type = match &self.r#type.to_lowercase().as_str() {
5077
&"number" => ColumnType::Double,

0 commit comments

Comments
 (0)