diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 7e03b61dc5bbd..7869d24e37488 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -2134,10 +2134,6 @@ export class BaseQuery { if (m.expressionName && !collectedMeasures.length && !m.isMemberExpression) { throw new UserError(`Subquery measure ${m.expressionName} should reference at least one member`); } - if (!collectedMeasures.length && m.isMemberExpression && m.query.allCubeNames.length > 1 && m.measureSql() === 'COUNT(*)') { - const cubeName = m.expressionCubeName ? `\`${m.expressionCubeName}\` ` : ''; - throw new UserError(`The query contains \`COUNT(*)\` expression but cube/view ${cubeName}is missing \`count\` measure`); - } if (collectedMeasures.length === 0 && m.isMemberExpression) { // `m` is member expression measure, but does not reference any other measure diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index 34862a24eb1e1..ac255039da7be 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -3824,6 +3824,21 @@ impl<'ctx, 'mem> CollectMembersVisitor<'ctx, 'mem> { Ok(()) } + + fn handle_count_rows(&mut self) -> Result<()> { + // COUNT(*) references all members in the ungrouped scan node + for member in &self.push_to_cube_context.ungrouped_scan_node.member_fields { + match member { + MemberField::Member(member) => { + self.used_members.insert(member.member.clone()); + } + MemberField::Literal(_) => { + // Do nothing + } + } + } + Ok(()) + } } impl<'ctx, 'mem> ExpressionVisitor for CollectMembersVisitor<'ctx, 'mem> { @@ -3832,6 +3847,13 @@ impl<'ctx, 'mem> ExpressionVisitor for CollectMembersVisitor<'ctx, 'mem> { Expr::Column(ref c) => { self.handle_column(c)?; } + Expr::AggregateFunction { + fun: AggregateFunction::Count, + args, + .. + } if args.len() == 1 && matches!(args[0], Expr::Literal(_)) => { + self.handle_count_rows()?; + } _ => {} } diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 595c7b958af2f..455a60c7eed3c 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -17468,4 +17468,40 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATE_DIFF('day', ")); } + + #[tokio::test] + async fn test_count_over_joined_cubes() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + r#" + SELECT COUNT(*) + FROM ( + SELECT + t1.id AS id, + t2.read AS read + FROM KibanaSampleDataEcommerce t1 + LEFT JOIN Logs t2 ON t1.__cubeJoinField = t2.__cubeJoinField + ) t + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; + assert!(sql.contains("COUNT(*)")); + assert!(sql.contains("KibanaSampleDataEcommerce")); + assert!(sql.contains("Logs")); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs index 06cec9249abbc..81721abf9575b 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs @@ -2062,7 +2062,7 @@ impl MemberRules { ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { let member_pushdown_replacer_alias_to_cube_var = var!(member_pushdown_replacer_alias_to_cube_var); - let column_var = match column_to_search { + let column_var = match &column_to_search { ColumnToSearch::Var(column_var) => Some(var!(column_var)), ColumnToSearch::DefaultCount => None, }; @@ -2088,6 +2088,17 @@ impl MemberRules { }); for alias_to_cube in alias_to_cubes { + // Do not push down COUNT(*) if there are joined cubes + if matches!(column_to_search, ColumnToSearch::DefaultCount) { + let joined_cubes = alias_to_cube + .iter() + .map(|(_, cube_name)| cube_name) + .collect::>(); + if joined_cubes.len() > 1 { + continue; + } + } + let column_iter = match column_var { Some(column_var) => var_iter!(egraph[subst[column_var]], ColumnExprColumn) .cloned() diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs index 85aba85addbfb..203cd1b2b578c 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs @@ -17,7 +17,8 @@ use crate::{ wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct, AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn, ListType, LiteralExprValue, LogicalPlanData, LogicalPlanLanguage, - WrappedSelectPushToCube, WrapperReplacerContextPushToCube, + WrappedSelectPushToCube, WrapperReplacerContextAliasToCube, + WrapperReplacerContextPushToCube, }, }, copy_flag, @@ -26,7 +27,7 @@ use crate::{ }; use datafusion::{logical_plan::Column, scalar::ScalarValue}; use egg::{Subst, Var}; -use std::ops::IndexMut; +use std::{collections::HashSet, ops::IndexMut}; impl WrapperRules { pub fn aggregate_rules(&self, rules: &mut Vec) { @@ -290,6 +291,7 @@ impl WrapperRules { "?cube_members", "?out_measure_expr", "?out_measure_alias", + "?alias_to_cube", ), ) }, @@ -1035,6 +1037,7 @@ impl WrapperRules { cube_members_var: Var, out_expr_var: Var, out_alias_var: Var, + alias_to_cube_var: Var, meta: &MetaContext, disable_strict_agg_type_match: bool, ) -> bool { @@ -1042,90 +1045,111 @@ impl WrapperRules { return false; }; - for fun in fun_name_var - .map(|fun_var| { - var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun) - .map(|fun| Some(fun.clone())) - .collect() - }) - .unwrap_or(vec![None]) + for alias_to_cube in var_iter!( + egraph[subst[alias_to_cube_var]], + WrapperReplacerContextAliasToCube + ) + .cloned() + .collect::>() { - for distinct in distinct_var - .map(|distinct_var| { - var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct) - .map(|d| *d) + // Do not push down COUNT(*) if there are joined cubes + let is_count_rows = column_var.is_none(); + if is_count_rows { + let joined_cubes = alias_to_cube + .iter() + .map(|(_, cube_name)| cube_name) + .collect::>(); + if joined_cubes.len() > 1 { + continue; + } + } + + for fun in fun_name_var + .map(|fun_var| { + var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun) + .map(|fun| Some(fun.clone())) .collect() }) - .unwrap_or(vec![false]) + .unwrap_or(vec![None]) { - let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct); + for distinct in distinct_var + .map(|distinct_var| { + var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct) + .map(|d| *d) + .collect() + }) + .unwrap_or(vec![false]) + { + let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct); - let column_iter = if let Some(column_var) = column_var { - var_iter!(egraph[subst[column_var]], ColumnExprColumn) - .cloned() - .collect() - } else { - vec![Column::from_name(MemberRules::default_count_measure_name())] - }; + let column_iter = if let Some(column_var) = column_var { + var_iter!(egraph[subst[column_var]], ColumnExprColumn) + .cloned() + .collect() + } else { + vec![Column::from_name(MemberRules::default_count_measure_name())] + }; - if let Some(member_names_to_expr) = &mut egraph - .index_mut(subst[cube_members_var]) - .data - .member_name_to_expr - { - for column in column_iter { - if let Some((&(Some(ref member), _, _), _)) = - LogicalPlanData::do_find_member_by_alias( - member_names_to_expr, - &column.name, - ) - { - if let Some(measure) = meta.find_measure_with_name(member) { - let Some(call_agg_type) = &call_agg_type else { - // call_agg_type is None, rewrite as is - Self::insert_regular_measure( - egraph, - subst, - column, - alias, - out_expr_var, - out_alias_var, - ); + if let Some(member_names_to_expr) = &mut egraph + .index_mut(subst[cube_members_var]) + .data + .member_name_to_expr + { + for column in column_iter { + if let Some((&(Some(ref member), _, _), _)) = + LogicalPlanData::do_find_member_by_alias( + member_names_to_expr, + &column.name, + ) + { + if let Some(measure) = meta.find_measure_with_name(member) { + let Some(call_agg_type) = &call_agg_type else { + // call_agg_type is None, rewrite as is + Self::insert_regular_measure( + egraph, + subst, + column, + alias, + out_expr_var, + out_alias_var, + ); - return true; - }; + return true; + }; - if measure - .is_same_agg_type(call_agg_type, disable_strict_agg_type_match) - { - Self::insert_regular_measure( - egraph, - subst, - column, - alias, - out_expr_var, - out_alias_var, - ); + if measure.is_same_agg_type( + call_agg_type, + disable_strict_agg_type_match, + ) { + Self::insert_regular_measure( + egraph, + subst, + column, + alias, + out_expr_var, + out_alias_var, + ); - return true; - } + return true; + } - if measure.allow_replace_agg_type( - call_agg_type, - disable_strict_agg_type_match, - ) { - Self::insert_patch_measure( - egraph, - subst, - column, - Some(call_agg_type.clone()), - alias, - Some(out_expr_var), - None, - out_alias_var, - ); + if measure.allow_replace_agg_type( + call_agg_type, + disable_strict_agg_type_match, + ) { + Self::insert_patch_measure( + egraph, + subst, + column, + Some(call_agg_type.clone()), + alias, + Some(out_expr_var), + None, + out_alias_var, + ); - return true; + return true; + } } } } @@ -1148,6 +1172,7 @@ impl WrapperRules { cube_members_var: &'static str, out_expr_var: &'static str, out_alias_var: &'static str, + alias_to_cube_var: &'static str, ) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool { let original_expr_var = var!(original_expr_var); let column_var = column_var.map(|v| var!(v)); @@ -1157,6 +1182,7 @@ impl WrapperRules { let cube_members_var = var!(cube_members_var); let out_expr_var = var!(out_expr_var); let out_alias_var = var!(out_alias_var); + let alias_to_cube_var = var!(alias_to_cube_var); let meta = self.meta_context.clone(); let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match(); move |egraph, subst| { @@ -1170,6 +1196,7 @@ impl WrapperRules { cube_members_var, out_expr_var, out_alias_var, + alias_to_cube_var, &meta, disable_strict_agg_type_match, )