Skip to content

Commit faf2fa5

Browse files
committed
[WIP][FIXUP][SPLIT] rework measure filters rewrite
1 parent 6171a42 commit faf2fa5

File tree

2 files changed

+157
-19
lines changed

2 files changed

+157
-19
lines changed

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 130 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,7 @@ impl WrapperRules {
360360
PATCH_MEASURE_UDAF_NAME,
361361
vec![
362362
column_expr("?measure_column"),
363-
// TODO support doing both: changing agg type and adding filters
364-
literal_null(),
363+
"?replace_agg_type".to_string(),
365364
wrapper_pushdown_replacer(
366365
// = is a proper way to filter here:
367366
// CASE NULL WHEN ... will return null
@@ -374,7 +373,15 @@ impl WrapperRules {
374373
),
375374
"?out_measure_alias",
376375
),
377-
self.transform_filtered_measure("?aggr_expr", "?literal", "?out_measure_alias"),
376+
self.transform_filtered_measure(
377+
"?aggr_expr",
378+
"?literal",
379+
"?measure_column",
380+
"?fun",
381+
"?cube_members",
382+
"?replace_agg_type",
383+
"?out_measure_alias",
384+
),
378385
),
379386
rewrite(
380387
"wrapper-pull-up-aggregation-over-filtered-measure",
@@ -936,20 +943,31 @@ impl WrapperRules {
936943
egraph: &mut CubeEGraph,
937944
subst: &mut Subst,
938945
column: Column,
939-
call_agg_type: String,
946+
call_agg_type: Option<String>,
940947
alias: String,
941-
out_expr_var: Var,
948+
out_expr_var: Option<Var>,
949+
out_replace_agg_type: Option<Var>,
942950
out_alias_var: Var,
943951
) {
944952
let column_expr_column = egraph.add(LogicalPlanLanguage::ColumnExprColumn(
945-
ColumnExprColumn(column.clone()),
953+
ColumnExprColumn(column),
946954
));
947955
let column_expr = egraph.add(LogicalPlanLanguage::ColumnExpr([column_expr_column]));
948-
let new_aggregation_value = egraph.add(LogicalPlanLanguage::LiteralExprValue(
949-
LiteralExprValue(ScalarValue::Utf8(Some(call_agg_type))),
950-
));
956+
let new_aggregation_value = match call_agg_type {
957+
Some(call_agg_type) => egraph.add(LogicalPlanLanguage::LiteralExprValue(
958+
LiteralExprValue(ScalarValue::Utf8(Some(call_agg_type))),
959+
)),
960+
None => egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue(
961+
ScalarValue::Null,
962+
))),
963+
};
951964
let new_aggregation_expr =
952965
egraph.add(LogicalPlanLanguage::LiteralExpr([new_aggregation_value]));
966+
967+
if let Some(out_replace_agg_type) = out_replace_agg_type {
968+
subst.insert(out_replace_agg_type, new_aggregation_expr);
969+
}
970+
953971
let add_filters_value = egraph.add(LogicalPlanLanguage::LiteralExprValue(
954972
LiteralExprValue(ScalarValue::Null),
955973
));
@@ -967,7 +985,9 @@ impl WrapperRules {
967985
udaf_args_expr,
968986
]));
969987

970-
subst.insert(out_expr_var, udaf_expr);
988+
if let Some(out_expr_var) = out_expr_var {
989+
subst.insert(out_expr_var, udaf_expr);
990+
}
971991

972992
let alias_expr_alias = egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(
973993
alias.clone(),
@@ -1068,9 +1088,10 @@ impl WrapperRules {
10681088
egraph,
10691089
subst,
10701090
column,
1071-
call_agg_type.clone(),
1091+
Some(call_agg_type.clone()),
10721092
alias,
1073-
out_expr_var,
1093+
Some(out_expr_var),
1094+
None,
10741095
out_alias_var,
10751096
);
10761097

@@ -1129,12 +1150,23 @@ impl WrapperRules {
11291150
&self,
11301151
aggr_expr_var: &'static str,
11311152
literal_var: &'static str,
1153+
column_var: &'static str,
1154+
fun_name_var: &'static str,
1155+
cube_members_var: &'static str,
1156+
replace_agg_type_var: &'static str,
11321157
out_measure_alias_var: &'static str,
11331158
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
11341159
let aggr_expr_var = var!(aggr_expr_var);
11351160
let literal_var = var!(literal_var);
1161+
let column_var = var!(column_var);
1162+
let fun_name_var = var!(fun_name_var);
1163+
let cube_members_var = var!(cube_members_var);
1164+
let replace_agg_type_var = var!(replace_agg_type_var);
11361165
let out_measure_alias_var = var!(out_measure_alias_var);
11371166

1167+
let meta = self.meta_context.clone();
1168+
let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match();
1169+
11381170
move |egraph, subst| {
11391171
match &egraph[subst[literal_var]].data.constant {
11401172
Some(ConstantFolding::Scalar(_)) => {
@@ -1145,17 +1177,96 @@ impl WrapperRules {
11451177
}
11461178
}
11471179

1148-
// TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
1149-
// TODO support both changing agg fun and add filter
1150-
11511180
let Some(alias) = original_expr_name(egraph, subst[aggr_expr_var]) else {
11521181
return false;
11531182
};
1154-
let alias_expr_alias =
1155-
egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(alias)));
1156-
subst.insert(out_measure_alias_var, alias_expr_alias);
11571183

1158-
true
1184+
for fun in var_iter!(egraph[subst[fun_name_var]], AggregateFunctionExprFun)
1185+
.cloned()
1186+
.collect::<Vec<_>>()
1187+
{
1188+
let call_agg_type = MemberRules::get_agg_type(Some(&fun), false);
1189+
1190+
let column_iter = var_iter!(egraph[subst[column_var]], ColumnExprColumn)
1191+
.cloned()
1192+
.collect::<Vec<_>>();
1193+
1194+
if let Some(member_names_to_expr) = &mut egraph
1195+
.index_mut(subst[cube_members_var])
1196+
.data
1197+
.member_name_to_expr
1198+
{
1199+
for column in column_iter {
1200+
if let Some((&(Some(ref member), _, _), _)) =
1201+
LogicalPlanData::do_find_member_by_alias(
1202+
member_names_to_expr,
1203+
&column.name,
1204+
)
1205+
{
1206+
if let Some(measure) = meta.find_measure_with_name(member) {
1207+
if !measure.allow_add_filter(call_agg_type.as_deref()) {
1208+
continue;
1209+
}
1210+
1211+
let Some(call_agg_type) = &call_agg_type else {
1212+
// call_agg_type is None, rewrite as is
1213+
Self::insert_patch_measure(
1214+
egraph,
1215+
subst,
1216+
column,
1217+
None,
1218+
alias,
1219+
None,
1220+
Some(replace_agg_type_var),
1221+
out_measure_alias_var,
1222+
);
1223+
1224+
return true;
1225+
};
1226+
1227+
if measure
1228+
.is_same_agg_type(call_agg_type, disable_strict_agg_type_match)
1229+
{
1230+
Self::insert_patch_measure(
1231+
egraph,
1232+
subst,
1233+
column,
1234+
None,
1235+
alias,
1236+
None,
1237+
Some(replace_agg_type_var),
1238+
out_measure_alias_var,
1239+
);
1240+
1241+
return true;
1242+
}
1243+
1244+
if measure.allow_replace_agg_type(
1245+
call_agg_type,
1246+
disable_strict_agg_type_match,
1247+
) {
1248+
Self::insert_patch_measure(
1249+
egraph,
1250+
subst,
1251+
column,
1252+
Some(call_agg_type.clone()),
1253+
alias,
1254+
None,
1255+
Some(replace_agg_type_var),
1256+
out_measure_alias_var,
1257+
);
1258+
1259+
return true;
1260+
}
1261+
}
1262+
}
1263+
}
1264+
}
1265+
}
1266+
1267+
false
1268+
1269+
// TODO share code with Self::pushdown_measure: locate cube and measure, check that ?fun matches measure, etc
11591270
}
11601271
}
11611272
}

rust/cubesql/cubesql/src/transport/ext.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub trait V1CubeMetaMeasureExt {
1212

1313
fn allow_replace_agg_type(&self, query_agg_type: &str, disable_strict_match: bool) -> bool;
1414

15+
fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool;
16+
1517
fn get_sql_type(&self) -> ColumnType;
1618
}
1719

@@ -72,6 +74,31 @@ impl V1CubeMetaMeasureExt for CubeMetaMeasure {
7274
}
7375
}
7476

77+
// This should be aligned with BaseMeasure.preparePatchedMeasure
78+
// TODO proper reference
79+
// See packages/cubejs-schema-compiler/src/adapter/BaseMeasure.ts:16
80+
fn allow_add_filter(&self, query_agg_type: Option<&str>) -> bool {
81+
let Some(agg_type) = &self.agg_type else {
82+
return false;
83+
};
84+
85+
let agg_type = match query_agg_type {
86+
Some(query_agg_type) => query_agg_type,
87+
None => agg_type,
88+
};
89+
90+
match agg_type {
91+
"sum"
92+
| "avg"
93+
| "min"
94+
| "max"
95+
| "count"
96+
| "count_distinct"
97+
| "count_distinct_approx" => true,
98+
_ => false,
99+
}
100+
}
101+
75102
fn get_sql_type(&self) -> ColumnType {
76103
let from_type = match &self.r#type.to_lowercase().as_str() {
77104
&"number" => ColumnType::Double,

0 commit comments

Comments
 (0)