Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -2134,10 +2134,6 @@ export class BaseQuery {
if (m.expressionName && !collectedMeasures.length && !m.isMemberExpression) {
throw new UserError(`Subquery measure ${m.expressionName} should reference at least one member`);
}
if (!collectedMeasures.length && m.isMemberExpression && m.query.allCubeNames.length > 1 && m.measureSql() === 'COUNT(*)') {
const cubeName = m.expressionCubeName ? `\`${m.expressionCubeName}\` ` : '';
throw new UserError(`The query contains \`COUNT(*)\` expression but cube/view ${cubeName}is missing \`count\` measure`);
}

if (collectedMeasures.length === 0 && m.isMemberExpression) {
// `m` is member expression measure, but does not reference any other measure
Expand Down
22 changes: 22 additions & 0 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3824,6 +3824,21 @@ impl<'ctx, 'mem> CollectMembersVisitor<'ctx, 'mem> {

Ok(())
}

/// Marks every member of the ungrouped scan node as used.
///
/// `COUNT(*)` names no specific column, so it is treated as referencing all
/// members of the ungrouped scan node. Literal fields are skipped.
///
/// Always returns `Ok(())`.
fn handle_count_rows(&mut self) -> Result<()> {
    let scan_fields = &self.push_to_cube_context.ungrouped_scan_node.member_fields;
    for field in scan_fields {
        // Only `Member` entries are recorded; `Literal` entries add nothing.
        if let MemberField::Member(scan_member) = field {
            self.used_members.insert(scan_member.member.clone());
        }
    }
    Ok(())
}
}

impl<'ctx, 'mem> ExpressionVisitor for CollectMembersVisitor<'ctx, 'mem> {
Expand All @@ -3832,6 +3847,13 @@ impl<'ctx, 'mem> ExpressionVisitor for CollectMembersVisitor<'ctx, 'mem> {
Expr::Column(ref c) => {
self.handle_column(c)?;
}
Expr::AggregateFunction {
fun: AggregateFunction::Count,
args,
..
} if args.len() == 1 && matches!(args[0], Expr::Literal(_)) => {
self.handle_count_rows()?;
}
_ => {}
}

Expand Down
36 changes: 36 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17468,4 +17468,40 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
assert!(sql.contains("DATE_DIFF('day', "));
}

/// Checks that `COUNT(*)` over a subquery joining two cubes yields wrapped SQL
/// that still contains `COUNT(*)` and mentions both cube names, and that a
/// physical plan can be built for it.
#[tokio::test]
async fn test_count_over_joined_cubes() {
    // The assertions inspect wrapped SQL, so the test is a no-op without SQL push down.
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_testing_logger();

    let query = r#"
SELECT COUNT(*)
FROM (
SELECT
t1.id AS id,
t2.read AS read
FROM KibanaSampleDataEcommerce t1
LEFT JOIN Logs t2 ON t1.__cubeJoinField = t2.__cubeJoinField
) t
"#;
    let query_plan =
        convert_select_to_query_plan(query.to_string(), DatabaseProtocol::PostgreSQL).await;

    let logical_plan = query_plan.as_logical_plan();
    let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
    assert!(sql.contains("COUNT(*)"));
    assert!(sql.contains("KibanaSampleDataEcommerce"));
    assert!(sql.contains("Logs"));

    // Building the physical plan must succeed; it is printed for debugging only.
    let physical_plan = query_plan.as_physical_plan().await.unwrap();
    println!(
        "Physical plan: {}",
        displayable(physical_plan.as_ref()).indent()
    );
}
}
13 changes: 12 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ impl MemberRules {
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
let member_pushdown_replacer_alias_to_cube_var =
var!(member_pushdown_replacer_alias_to_cube_var);
let column_var = match column_to_search {
let column_var = match &column_to_search {
ColumnToSearch::Var(column_var) => Some(var!(column_var)),
ColumnToSearch::DefaultCount => None,
};
Expand All @@ -2088,6 +2088,17 @@ impl MemberRules {
});

for alias_to_cube in alias_to_cubes {
// Do not push down COUNT(*) if there are joined cubes
if matches!(column_to_search, ColumnToSearch::DefaultCount) {
let joined_cubes = alias_to_cube
.iter()
.map(|(_, cube_name)| cube_name)
.collect::<HashSet<_>>();
if joined_cubes.len() > 1 {
continue;
}
}

let column_iter = match column_var {
Some(column_var) => var_iter!(egraph[subst[column_var]], ColumnExprColumn)
.cloned()
Expand Down
177 changes: 102 additions & 75 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use crate::{
wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct,
AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn,
ListType, LiteralExprValue, LogicalPlanData, LogicalPlanLanguage,
WrappedSelectPushToCube, WrapperReplacerContextPushToCube,
WrappedSelectPushToCube, WrapperReplacerContextAliasToCube,
WrapperReplacerContextPushToCube,
},
},
copy_flag,
Expand All @@ -26,7 +27,7 @@ use crate::{
};
use datafusion::{logical_plan::Column, scalar::ScalarValue};
use egg::{Subst, Var};
use std::ops::IndexMut;
use std::{collections::HashSet, ops::IndexMut};

impl WrapperRules {
pub fn aggregate_rules(&self, rules: &mut Vec<CubeRewrite>) {
Expand Down Expand Up @@ -290,6 +291,7 @@ impl WrapperRules {
"?cube_members",
"?out_measure_expr",
"?out_measure_alias",
"?alias_to_cube",
),
)
},
Expand Down Expand Up @@ -1035,97 +1037,119 @@ impl WrapperRules {
cube_members_var: Var,
out_expr_var: Var,
out_alias_var: Var,
alias_to_cube_var: Var,
meta: &MetaContext,
disable_strict_agg_type_match: bool,
) -> bool {
let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) else {
return false;
};

for fun in fun_name_var
.map(|fun_var| {
var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun)
.map(|fun| Some(fun.clone()))
.collect()
})
.unwrap_or(vec![None])
for alias_to_cube in var_iter!(
egraph[subst[alias_to_cube_var]],
WrapperReplacerContextAliasToCube
)
.cloned()
.collect::<Vec<_>>()
{
for distinct in distinct_var
.map(|distinct_var| {
var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct)
.map(|d| *d)
// Do not push down COUNT(*) if there are joined cubes
let is_count_rows = column_var.is_none();
if is_count_rows {
let joined_cubes = alias_to_cube
.iter()
.map(|(_, cube_name)| cube_name)
.collect::<HashSet<_>>();
if joined_cubes.len() > 1 {
continue;
}
}

for fun in fun_name_var
.map(|fun_var| {
var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun)
.map(|fun| Some(fun.clone()))
.collect()
})
.unwrap_or(vec![false])
.unwrap_or(vec![None])
{
let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct);
for distinct in distinct_var
.map(|distinct_var| {
var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct)
.map(|d| *d)
.collect()
})
.unwrap_or(vec![false])
{
let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct);

let column_iter = if let Some(column_var) = column_var {
var_iter!(egraph[subst[column_var]], ColumnExprColumn)
.cloned()
.collect()
} else {
vec![Column::from_name(MemberRules::default_count_measure_name())]
};
let column_iter = if let Some(column_var) = column_var {
var_iter!(egraph[subst[column_var]], ColumnExprColumn)
.cloned()
.collect()
} else {
vec![Column::from_name(MemberRules::default_count_measure_name())]
};

if let Some(member_names_to_expr) = &mut egraph
.index_mut(subst[cube_members_var])
.data
.member_name_to_expr
{
for column in column_iter {
if let Some((&(Some(ref member), _, _), _)) =
LogicalPlanData::do_find_member_by_alias(
member_names_to_expr,
&column.name,
)
{
if let Some(measure) = meta.find_measure_with_name(member) {
let Some(call_agg_type) = &call_agg_type else {
// call_agg_type is None, rewrite as is
Self::insert_regular_measure(
egraph,
subst,
column,
alias,
out_expr_var,
out_alias_var,
);
if let Some(member_names_to_expr) = &mut egraph
.index_mut(subst[cube_members_var])
.data
.member_name_to_expr
{
for column in column_iter {
if let Some((&(Some(ref member), _, _), _)) =
LogicalPlanData::do_find_member_by_alias(
member_names_to_expr,
&column.name,
)
{
if let Some(measure) = meta.find_measure_with_name(member) {
let Some(call_agg_type) = &call_agg_type else {
// call_agg_type is None, rewrite as is
Self::insert_regular_measure(
egraph,
subst,
column,
alias,
out_expr_var,
out_alias_var,
);

return true;
};
return true;
};

if measure
.is_same_agg_type(call_agg_type, disable_strict_agg_type_match)
{
Self::insert_regular_measure(
egraph,
subst,
column,
alias,
out_expr_var,
out_alias_var,
);
if measure.is_same_agg_type(
call_agg_type,
disable_strict_agg_type_match,
) {
Self::insert_regular_measure(
egraph,
subst,
column,
alias,
out_expr_var,
out_alias_var,
);

return true;
}
return true;
}

if measure.allow_replace_agg_type(
call_agg_type,
disable_strict_agg_type_match,
) {
Self::insert_patch_measure(
egraph,
subst,
column,
Some(call_agg_type.clone()),
alias,
Some(out_expr_var),
None,
out_alias_var,
);
if measure.allow_replace_agg_type(
call_agg_type,
disable_strict_agg_type_match,
) {
Self::insert_patch_measure(
egraph,
subst,
column,
Some(call_agg_type.clone()),
alias,
Some(out_expr_var),
None,
out_alias_var,
);

return true;
return true;
}
}
}
}
Expand All @@ -1148,6 +1172,7 @@ impl WrapperRules {
cube_members_var: &'static str,
out_expr_var: &'static str,
out_alias_var: &'static str,
alias_to_cube_var: &'static str,
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
let original_expr_var = var!(original_expr_var);
let column_var = column_var.map(|v| var!(v));
Expand All @@ -1157,6 +1182,7 @@ impl WrapperRules {
let cube_members_var = var!(cube_members_var);
let out_expr_var = var!(out_expr_var);
let out_alias_var = var!(out_alias_var);
let alias_to_cube_var = var!(alias_to_cube_var);
let meta = self.meta_context.clone();
let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match();
move |egraph, subst| {
Expand All @@ -1170,6 +1196,7 @@ impl WrapperRules {
cube_members_var,
out_expr_var,
out_alias_var,
alias_to_cube_var,
&meta,
disable_strict_agg_type_match,
)
Expand Down
Loading