|
| 1 | +use crate::{ |
| 2 | + compile::rewrite::{ |
| 3 | + aggregate, |
| 4 | + analysis::LogicalPlanAnalysis, |
| 5 | + column_name_to_member_vec, cube_scan_wrapper, original_expr_name, |
| 6 | + rules::{members::MemberRules, wrapper::WrapperRules}, |
| 7 | + transforming_chain_rewrite, transforming_rewrite, wrapped_select, |
| 8 | + wrapped_select_filter_expr_empty_tail, wrapped_select_having_expr_empty_tail, |
| 9 | + wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail, |
| 10 | + wrapped_select_projection_expr_empty_tail, wrapper_pullup_replacer, |
| 11 | + wrapper_pushdown_replacer, AggregateFunctionExprDistinct, AggregateFunctionExprFun, |
| 12 | + AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, WrappedSelectUngrouped, |
| 13 | + WrapperPullupReplacerUngrouped, |
| 14 | + }, |
| 15 | + transport::V1CubeMetaMeasureExt, |
| 16 | + var, var_iter, |
| 17 | +}; |
| 18 | +use datafusion::logical_plan::Column; |
| 19 | +use egg::{EGraph, Rewrite, Subst}; |
| 20 | + |
| 21 | +impl WrapperRules { |
| 22 | + pub fn aggregate_rules( |
| 23 | + &self, |
| 24 | + rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>, |
| 25 | + ) { |
| 26 | + rules.extend(vec![transforming_rewrite( |
| 27 | + "wrapper-push-down-aggregate-to-cube-scan", |
| 28 | + aggregate( |
| 29 | + cube_scan_wrapper( |
| 30 | + wrapper_pullup_replacer( |
| 31 | + "?cube_scan_input", |
| 32 | + "?alias_to_cube", |
| 33 | + "?ungrouped", |
| 34 | + "?cube_members", |
| 35 | + ), |
| 36 | + "CubeScanWrapperFinalized:false", |
| 37 | + ), |
| 38 | + "?group_expr", |
| 39 | + "?aggr_expr", |
| 40 | + "AggregateSplit:false", |
| 41 | + ), |
| 42 | + cube_scan_wrapper( |
| 43 | + wrapped_select( |
| 44 | + "WrappedSelectSelectType:Aggregate", |
| 45 | + wrapper_pullup_replacer( |
| 46 | + wrapped_select_projection_expr_empty_tail(), |
| 47 | + "?alias_to_cube", |
| 48 | + "?ungrouped", |
| 49 | + "?cube_members", |
| 50 | + ), |
| 51 | + wrapper_pushdown_replacer( |
| 52 | + "?group_expr", |
| 53 | + "?alias_to_cube", |
| 54 | + "?ungrouped", |
| 55 | + "?cube_members", |
| 56 | + ), |
| 57 | + wrapper_pushdown_replacer( |
| 58 | + "?aggr_expr", |
| 59 | + "?alias_to_cube", |
| 60 | + "?ungrouped", |
| 61 | + "?cube_members", |
| 62 | + ), |
| 63 | + wrapper_pullup_replacer( |
| 64 | + "?cube_scan_input", |
| 65 | + "?alias_to_cube", |
| 66 | + "?ungrouped", |
| 67 | + "?cube_members", |
| 68 | + ), |
| 69 | + wrapped_select_joins_empty_tail(), |
| 70 | + wrapped_select_filter_expr_empty_tail(), |
| 71 | + wrapped_select_having_expr_empty_tail(), |
| 72 | + "WrappedSelectLimit:None", |
| 73 | + "WrappedSelectOffset:None", |
| 74 | + wrapper_pullup_replacer( |
| 75 | + wrapped_select_order_expr_empty_tail(), |
| 76 | + "?alias_to_cube", |
| 77 | + "?ungrouped", |
| 78 | + "?cube_members", |
| 79 | + ), |
| 80 | + "WrappedSelectAlias:None", |
| 81 | + "?select_ungrouped", |
| 82 | + ), |
| 83 | + "CubeScanWrapperFinalized:false", |
| 84 | + ), |
| 85 | + self.transform_aggregate("?ungrouped", "?select_ungrouped"), |
| 86 | + )]); |
| 87 | + |
| 88 | + // TODO add flag to disable dimension rules |
| 89 | + MemberRules::measure_rewrites( |
| 90 | + rules, |
| 91 | + |name: &'static str, |
| 92 | + aggr_expr: String, |
| 93 | + _measure_expr: String, |
| 94 | + fun_name: Option<&'static str>, |
| 95 | + distinct: Option<&'static str>, |
| 96 | + cast_data_type: Option<&'static str>, |
| 97 | + column: Option<&'static str>| { |
| 98 | + transforming_chain_rewrite( |
| 99 | + &format!("wrapper-{}", name), |
| 100 | + wrapper_pushdown_replacer( |
| 101 | + "?aggr_expr", |
| 102 | + "?alias_to_cube", |
| 103 | + "WrapperPullupReplacerUngrouped:true", |
| 104 | + "?cube_members", |
| 105 | + ), |
| 106 | + vec![("?aggr_expr", aggr_expr)], |
| 107 | + wrapper_pullup_replacer( |
| 108 | + "?measure", |
| 109 | + "?alias_to_cube", |
| 110 | + "WrapperPullupReplacerUngrouped:true", |
| 111 | + "?cube_members", |
| 112 | + ), |
| 113 | + self.pushdown_measure( |
| 114 | + "?aggr_expr", |
| 115 | + column, |
| 116 | + fun_name, |
| 117 | + distinct, |
| 118 | + cast_data_type, |
| 119 | + "?cube_members", |
| 120 | + "?measure", |
| 121 | + ), |
| 122 | + ) |
| 123 | + }, |
| 124 | + ); |
| 125 | + |
| 126 | + Self::list_pushdown_pullup_rules( |
| 127 | + rules, |
| 128 | + "wrapper-aggregate-aggr-expr", |
| 129 | + "AggregateAggrExpr", |
| 130 | + "WrappedSelectAggrExpr", |
| 131 | + ); |
| 132 | + |
| 133 | + Self::list_pushdown_pullup_rules( |
| 134 | + rules, |
| 135 | + "wrapper-aggregate-group-expr", |
| 136 | + "AggregateGroupExpr", |
| 137 | + "WrappedSelectGroupExpr", |
| 138 | + ); |
| 139 | + } |
| 140 | + |
| 141 | + fn transform_aggregate( |
| 142 | + &self, |
| 143 | + ungrouped_var: &'static str, |
| 144 | + select_ungrouped_var: &'static str, |
| 145 | + ) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool { |
| 146 | + let ungrouped_var = var!(ungrouped_var); |
| 147 | + let select_ungrouped_var = var!(select_ungrouped_var); |
| 148 | + move |egraph, subst| { |
| 149 | + for ungrouped in |
| 150 | + var_iter!(egraph[subst[ungrouped_var]], WrapperPullupReplacerUngrouped).cloned() |
| 151 | + { |
| 152 | + subst.insert( |
| 153 | + select_ungrouped_var, |
| 154 | + egraph.add(LogicalPlanLanguage::WrappedSelectUngrouped( |
| 155 | + WrappedSelectUngrouped(ungrouped), |
| 156 | + )), |
| 157 | + ); |
| 158 | + return true; |
| 159 | + } |
| 160 | + false |
| 161 | + } |
| 162 | + } |
| 163 | + |
| 164 | + fn pushdown_measure( |
| 165 | + &self, |
| 166 | + original_expr_var: &'static str, |
| 167 | + column_var: Option<&'static str>, |
| 168 | + fun_name_var: Option<&'static str>, |
| 169 | + distinct_var: Option<&'static str>, |
| 170 | + // TODO support cast push downs |
| 171 | + _cast_data_type_var: Option<&'static str>, |
| 172 | + cube_members_var: &'static str, |
| 173 | + measure_out_var: &'static str, |
| 174 | + ) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool { |
| 175 | + let original_expr_var = var!(original_expr_var); |
| 176 | + let column_var = column_var.map(|v| var!(v)); |
| 177 | + let fun_name_var = fun_name_var.map(|v| var!(v)); |
| 178 | + let distinct_var = distinct_var.map(|v| var!(v)); |
| 179 | + // let cast_data_type_var = cast_data_type_var.map(|v| var!(v)); |
| 180 | + let cube_members_var = var!(cube_members_var); |
| 181 | + let measure_out_var = var!(measure_out_var); |
| 182 | + let cube_context = self.cube_context.clone(); |
| 183 | + move |egraph, subst| { |
| 184 | + if let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) { |
| 185 | + for fun in fun_name_var |
| 186 | + .map(|fun_var| { |
| 187 | + var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun) |
| 188 | + .map(|fun| Some(fun)) |
| 189 | + .collect() |
| 190 | + }) |
| 191 | + .unwrap_or(vec![None]) |
| 192 | + { |
| 193 | + for distinct in distinct_var |
| 194 | + .map(|distinct_var| { |
| 195 | + var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct) |
| 196 | + .map(|d| *d) |
| 197 | + .collect() |
| 198 | + }) |
| 199 | + .unwrap_or(vec![false]) |
| 200 | + { |
| 201 | + let call_agg_type = MemberRules::get_agg_type(fun, distinct); |
| 202 | + |
| 203 | + let column_iter = if let Some(column_var) = column_var { |
| 204 | + var_iter!(egraph[subst[column_var]], ColumnExprColumn) |
| 205 | + .cloned() |
| 206 | + .collect() |
| 207 | + } else { |
| 208 | + vec![Column::from_name(MemberRules::default_count_measure_name())] |
| 209 | + }; |
| 210 | + |
| 211 | + if let Some(member_name_to_expr) = egraph[subst[cube_members_var]] |
| 212 | + .data |
| 213 | + .member_name_to_expr |
| 214 | + .clone() |
| 215 | + { |
| 216 | + let column_name_to_member_name = |
| 217 | + column_name_to_member_vec(member_name_to_expr); |
| 218 | + for column in column_iter { |
| 219 | + if let Some((_, Some(member))) = column_name_to_member_name |
| 220 | + .iter() |
| 221 | + .find(|(cn, _)| cn == &column.name) |
| 222 | + { |
| 223 | + if let Some(measure) = |
| 224 | + cube_context.meta.find_measure_with_name(member.to_string()) |
| 225 | + { |
| 226 | + if call_agg_type.is_none() |
| 227 | + || measure |
| 228 | + .is_same_agg_type(call_agg_type.as_ref().unwrap()) |
| 229 | + { |
| 230 | + let column_expr_column = |
| 231 | + egraph.add(LogicalPlanLanguage::ColumnExprColumn( |
| 232 | + ColumnExprColumn(column.clone()), |
| 233 | + )); |
| 234 | + |
| 235 | + let column_expr = |
| 236 | + egraph.add(LogicalPlanLanguage::ColumnExpr([ |
| 237 | + column_expr_column, |
| 238 | + ])); |
| 239 | + let alias_expr_alias = |
| 240 | + egraph.add(LogicalPlanLanguage::AliasExprAlias( |
| 241 | + AliasExprAlias(alias.clone()), |
| 242 | + )); |
| 243 | + |
| 244 | + let alias_expr = |
| 245 | + egraph.add(LogicalPlanLanguage::AliasExpr([ |
| 246 | + column_expr, |
| 247 | + alias_expr_alias, |
| 248 | + ])); |
| 249 | + |
| 250 | + subst.insert(measure_out_var, alias_expr); |
| 251 | + |
| 252 | + return true; |
| 253 | + } |
| 254 | + } |
| 255 | + } |
| 256 | + } |
| 257 | + } |
| 258 | + } |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + false |
| 263 | + } |
| 264 | + } |
| 265 | +} |
0 commit comments