Skip to content

Commit ada0afd

Browse files
authored
fix(cubesql): Quicksight AVG measures support (#6323)
1 parent 838c7cf commit ada0afd

File tree

3 files changed

+205
-20
lines changed

3 files changed

+205
-20
lines changed

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 50 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3732,16 +3732,12 @@ ORDER BY \"COUNT(count)\" DESC"
37323732
async fn test_select_error() {
37333733
let variants = vec![
37343734
(
3735-
"SELECT COUNT(maxPrice) FROM KibanaSampleDataEcommerce".to_string(),
3736-
CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'maxPrice' is 'MAX()' but 'COUNT()' was provided. Please check logs for additional information.".to_string()),
3735+
"SELECT AVG(maxPrice) FROM KibanaSampleDataEcommerce".to_string(),
3736+
CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'maxPrice' is 'MAX()' but 'AVG()' was provided. Please check logs for additional information.".to_string()),
37373737
),
37383738
(
3739-
"SELECT COUNT(someNumber) FROM NumberCube".to_string(),
3740-
CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'someNumber' is 'MEASURE()' but 'COUNT()' was provided. Please check logs for additional information.".to_string()),
3741-
),
3742-
(
3743-
"SELECT COUNT(order_date) FROM KibanaSampleDataEcommerce".to_string(),
3744-
CompilationError::user("Error during rewrite: Dimension 'order_date' was used with the aggregate function 'COUNT()'. Please use a measure instead. Please check logs for additional information.".to_string()),
3739+
"SELECT AVG(someNumber) FROM NumberCube".to_string(),
3740+
CompilationError::user("Error during rewrite: Measure aggregation type doesn't match. The aggregation type for 'someNumber' is 'MEASURE()' but 'AVG()' was provided. Please check logs for additional information.".to_string()),
37453741
),
37463742
];
37473743

@@ -10397,6 +10393,52 @@ ORDER BY \"COUNT(count)\" DESC"
1039710393
Ok(())
1039810394
}
1039910395

10396+
#[tokio::test]
10397+
async fn test_quicksight_dense_rank() -> Result<(), CubeError> {
10398+
init_logger();
10399+
10400+
let query_plan = convert_select_to_query_plan(
10401+
r#"
10402+
SELECT "faabeaae-5980-4f8f-a5ba-12f56f191f1e.order_date", "isotherrow_1", "faabeaae-5980-4f8f-a5ba-12f56f191f1e.avgPrice_avg", "$otherbucket_group_count", "count"
10403+
FROM (
10404+
SELECT "$f4" AS "faabeaae-5980-4f8f-a5ba-12f56f191f1e.order_date", "$f5", "$f6" AS "isotherrow_1", SUM("$weighted_avg_unit_4") AS "faabeaae-5980-4f8f-a5ba-12f56f191f1e.avgPrice_avg", COUNT(*) AS "$otherbucket_group_count", SUM("count") AS "count"
10405+
FROM (
10406+
SELECT "count", CASE WHEN "$RANK_1" > 2500 THEN NULL ELSE "faabeaae-5980-4f8f-a5ba-12f56f191f1e.order_date" END AS "$f4", CASE WHEN "$RANK_1" > 2500 THEN NULL ELSE "$RANK_1" END AS "$f5", CASE WHEN "$RANK_1" > 2500 THEN 1 ELSE 0 END AS "$f6", CAST("$weighted_avg_count_3" AS FLOAT) / NULLIF(CAST(SUM("$weighted_avg_count_3") OVER (PARTITION BY CASE WHEN "$RANK_1" > 2500 THEN NULL ELSE "faabeaae-5980-4f8f-a5ba-12f56f191f1e.order_date" END, CASE WHEN "$RANK_1" > 2500 THEN NULL ELSE "$RANK_1" END, CASE WHEN "$RANK_1" > 2500 THEN 1 ELSE 0 END) AS FLOAT), 0) * "faabeaae-5980-4f8f-a5ba-12f56f191f1e.avgPrice_avg" AS "$weighted_avg_unit_4"
10407+
FROM (
10408+
SELECT "order_date" AS "faabeaae-5980-4f8f-a5ba-12f56f191f1e.order_date", COUNT(*) AS "count", AVG("avgPrice") AS "faabeaae-5980-4f8f-a5ba-12f56f191f1e.avgPrice_avg", DENSE_RANK() OVER (ORDER BY AVG("avgPrice") DESC NULLS LAST, "order_date" NULLS FIRST) AS "$RANK_1", COUNT("avgPrice") AS "$weighted_avg_count_3"
10409+
FROM "public"."KibanaSampleDataEcommerce"
10410+
GROUP BY "order_date"
10411+
) AS "t"
10412+
) AS "t0"
10413+
GROUP BY "$f4", "$f5", "$f6"
10414+
ORDER BY "$f5" NULLS FIRST
10415+
) AS "t1"
10416+
;"#.to_string(),
10417+
DatabaseProtocol::PostgreSQL,
10418+
).await;
10419+
10420+
let logical_plan = query_plan.as_logical_plan();
10421+
10422+
assert_eq!(
10423+
logical_plan.find_cube_scan().request,
10424+
V1LoadRequestQuery {
10425+
measures: Some(vec![
10426+
"KibanaSampleDataEcommerce.count".to_string(),
10427+
"KibanaSampleDataEcommerce.avgPrice".to_string()
10428+
]),
10429+
segments: Some(vec![]),
10430+
dimensions: Some(vec!["KibanaSampleDataEcommerce.order_date".to_string()]),
10431+
time_dimensions: None,
10432+
order: None,
10433+
limit: None,
10434+
offset: None,
10435+
filters: None,
10436+
}
10437+
);
10438+
10439+
Ok(())
10440+
}
10441+
1040010442
#[tokio::test]
1040110443
async fn test_localtimestamp() -> Result<(), CubeError> {
1040210444
// TODO: the value will be different with the introduction of TZ support

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -911,6 +911,18 @@ impl MemberRules {
911911
),
912912
true,
913913
));
914+
rules.extend(find_matching_old_member_with_count(
915+
"agg-fun-default-count-alias",
916+
alias_expr(
917+
agg_fun_expr(
918+
"Count",
919+
vec![literal_expr("?any")],
920+
"AggregateFunctionExprDistinct:false",
921+
),
922+
"?alias",
923+
),
924+
true,
925+
));
914926
rules.extend(find_matching_old_member(
915927
"agg-fun-with-cast",
916928
// TODO need to check data_type if we can remove the cast
@@ -1017,6 +1029,17 @@ impl MemberRules {
10171029
Some("?distinct"),
10181030
None,
10191031
));
1032+
rules.push(pushdown_measure_rewrite(
1033+
"member-pushdown-replacer-agg-fun-default-count-alias",
1034+
alias_expr(
1035+
agg_fun_expr("?fun_name", vec![literal_expr("?any")], "?distinct"),
1036+
"?alias",
1037+
),
1038+
measure_expr("?name", "?old_alias"),
1039+
Some("?fun_name"),
1040+
Some("?distinct"),
1041+
None,
1042+
));
10201043

10211044
rules.push(transforming_chain_rewrite(
10221045
"member-pushdown-date-trunc",
@@ -3257,7 +3280,7 @@ impl MemberRules {
32573280
}
32583281
}
32593282

3260-
fn get_agg_type(fun: Option<&AggregateFunction>, distinct: bool) -> Option<String> {
3283+
pub fn get_agg_type(fun: Option<&AggregateFunction>, distinct: bool) -> Option<String> {
32613284
fun.map(|fun| {
32623285
match fun {
32633286
AggregateFunction::Count => {

rust/cubesql/cubesql/src/compile/rewrite/rules/split.rs

Lines changed: 131 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ use crate::{
2020
OuterProjectionSplitReplacerAliasToCube, ProjectionAlias, ScalarFunctionExprFun,
2121
},
2222
},
23-
transport::V1CubeMetaExt,
23+
transport::{V1CubeMetaExt, V1CubeMetaMeasureExt},
2424
var, var_iter, CubeError,
2525
};
2626
use datafusion::{
@@ -1603,14 +1603,15 @@ impl RewriteRules for SplitRules {
16031603
),
16041604
),
16051605
// Aggregate function
1606-
transforming_rewrite(
1606+
transforming_chain_rewrite(
16071607
"split-push-down-aggr-fun-inner-replacer",
16081608
inner_aggregate_split_replacer(
1609-
agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"),
1609+
"?aggr_expr",
16101610
"?cube",
16111611
),
1612-
agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"),
1613-
self.transform_inner_measure("?cube", Some("?column")),
1612+
vec![("?aggr_expr", agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"))],
1613+
"?out_expr".to_string(),
1614+
self.transform_inner_measure("?cube", Some("?column"), Some("?aggr_expr"), Some("?fun"), Some("?distinct"), Some("?out_expr")),
16141615
),
16151616
transforming_rewrite(
16161617
"split-push-down-aggr-fun-inner-replacer-simple-count",
@@ -1619,7 +1620,7 @@ impl RewriteRules for SplitRules {
16191620
"?cube",
16201621
),
16211622
agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"),
1622-
self.transform_inner_measure("?cube", None),
1623+
self.transform_inner_measure("?cube", None, None, None, None, None),
16231624
),
16241625
transforming_rewrite(
16251626
"split-push-down-aggr-fun-inner-replacer-missing-count",
@@ -1638,7 +1639,17 @@ impl RewriteRules for SplitRules {
16381639
agg_fun_expr("?fun", vec![column_expr("?column")], "?distinct"),
16391640
)],
16401641
"?alias".to_string(),
1641-
self.transform_outer_projection_aggr_fun("?cube", "?expr", "?column", "?alias"),
1642+
self.transform_outer_projection_aggr_fun("?cube", "?expr", Some("?column"), "?alias"),
1643+
),
1644+
transforming_chain_rewrite(
1645+
"split-push-down-aggr-fun-outer-replacer-simple-count",
1646+
outer_projection_split_replacer("?expr", "?cube"),
1647+
vec![(
1648+
"?expr",
1649+
agg_fun_expr("?fun", vec![literal_expr("?literal")], "?distinct"),
1650+
)],
1651+
"?alias".to_string(),
1652+
self.transform_outer_projection_aggr_fun("?cube", "?expr", None, "?alias"),
16421653
),
16431654
transforming_chain_rewrite(
16441655
"split-push-down-aggr-fun-outer-aggr-replacer",
@@ -4511,9 +4522,17 @@ impl SplitRules {
45114522
&self,
45124523
cube_expr_var: &'static str,
45134524
column_var: Option<&'static str>,
4525+
aggr_expr_var: Option<&'static str>,
4526+
fun_var: Option<&'static str>,
4527+
distinct_var: Option<&'static str>,
4528+
out_expr_var: Option<&'static str>,
45144529
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
45154530
let cube_expr_var = var!(cube_expr_var);
45164531
let column_var = column_var.map(|column_var| var!(column_var));
4532+
let aggr_expr_var = aggr_expr_var.map(|v| var!(v));
4533+
let fun_var = fun_var.map(|v| var!(v));
4534+
let distinct_var = distinct_var.map(|v| var!(v));
4535+
let out_expr_var = out_expr_var.map(|v| var!(v));
45174536
let meta = self.cube_context.meta.clone();
45184537
move |egraph, subst| {
45194538
for alias_to_cube in var_iter!(
@@ -4533,7 +4552,99 @@ impl SplitRules {
45334552
)])
45344553
{
45354554
if let Some((_, cube)) = meta.find_cube_by_column(&alias_to_cube, &column) {
4536-
if cube.lookup_measure(&column.name).is_some() {
4555+
if let Some(measure) = cube.lookup_measure(&column.name) {
4556+
if let Some((((fun_var, distinct_var), out_expr_var), aggr_expr_var)) =
4557+
fun_var
4558+
.zip(distinct_var)
4559+
.zip(out_expr_var)
4560+
.zip(aggr_expr_var)
4561+
{
4562+
for fun in
4563+
var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun)
4564+
{
4565+
for distinct in var_iter!(
4566+
egraph[subst[distinct_var]],
4567+
AggregateFunctionExprDistinct
4568+
) {
4569+
// If COUNT wraps a non-count measure, allow the split and just count rows
4570+
// TODO it might be worth allowing all other non-matching aggregate functions
4571+
if let AggregateFunction::Count = fun {
4572+
if !distinct {
4573+
let agg_type = MemberRules::get_agg_type(
4574+
Some(&fun),
4575+
*distinct,
4576+
);
4577+
if !measure.is_same_agg_type(&agg_type.unwrap()) {
4578+
if let Some(expr_name) = original_expr_name(
4579+
egraph,
4580+
subst[aggr_expr_var],
4581+
) {
4582+
let measure_fun = egraph.add(
4583+
LogicalPlanLanguage::AggregateFunctionExprFun(
4584+
AggregateFunctionExprFun(AggregateFunction::Count)
4585+
),
4586+
);
4587+
4588+
let measure_distinct = egraph.add(
4589+
LogicalPlanLanguage::AggregateFunctionExprDistinct(
4590+
AggregateFunctionExprDistinct(false)
4591+
),
4592+
);
4593+
let tail = egraph.add(
4594+
LogicalPlanLanguage::AggregateFunctionExprArgs(
4595+
vec![],
4596+
),
4597+
);
4598+
4599+
let literal_expr_value = egraph.add(
4600+
LogicalPlanLanguage::LiteralExprValue(
4601+
LiteralExprValue(
4602+
ScalarValue::Int64(None),
4603+
),
4604+
),
4605+
);
4606+
4607+
let column_expr = egraph.add(
4608+
LogicalPlanLanguage::LiteralExpr([
4609+
literal_expr_value,
4610+
]),
4611+
);
4612+
let args = egraph.add(
4613+
LogicalPlanLanguage::AggregateFunctionExprArgs(
4614+
vec![column_expr, tail],
4615+
),
4616+
);
4617+
let aggr_expr = egraph.add(
4618+
LogicalPlanLanguage::AggregateFunctionExpr(
4619+
[measure_fun, args, measure_distinct],
4620+
),
4621+
);
4622+
let alias = egraph.add(
4623+
LogicalPlanLanguage::AliasExprAlias(
4624+
AliasExprAlias(expr_name),
4625+
),
4626+
);
4627+
4628+
let alias_expr = egraph.add(
4629+
LogicalPlanLanguage::AliasExpr([
4630+
aggr_expr, alias,
4631+
]),
4632+
);
4633+
subst.insert(out_expr_var, alias_expr);
4634+
return true;
4635+
}
4636+
}
4637+
}
4638+
}
4639+
}
4640+
}
4641+
}
4642+
4643+
if let Some((aggr_expr_var, out_expr_var)) =
4644+
aggr_expr_var.zip(out_expr_var)
4645+
{
4646+
subst.insert(out_expr_var, subst[aggr_expr_var]);
4647+
}
45374648
return true;
45384649
}
45394650
}
@@ -5379,12 +5490,12 @@ impl SplitRules {
53795490
&self,
53805491
cube_var: &'static str,
53815492
original_expr_var: &'static str,
5382-
column_var: &'static str,
5493+
column_var: Option<&'static str>,
53835494
alias_expr_var: &'static str,
53845495
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
53855496
let cube_var = var!(cube_var);
53865497
let original_expr_var = var!(original_expr_var);
5387-
let column_var = var!(column_var);
5498+
let column_var = column_var.map(|v| var!(v));
53885499
let alias_expr_var = var!(alias_expr_var);
53895500
let meta = self.cube_context.meta.clone();
53905501
move |egraph, subst| {
@@ -5395,7 +5506,16 @@ impl SplitRules {
53955506
.cloned()
53965507
{
53975508
if let Some(name) = original_expr_name(egraph, subst[original_expr_var]) {
5398-
for column in var_iter!(egraph[subst[column_var]], ColumnExprColumn).cloned() {
5509+
for column in column_var
5510+
.map(|column_var| {
5511+
var_iter!(egraph[subst[column_var]], ColumnExprColumn)
5512+
.map(|c| c.clone())
5513+
.collect()
5514+
})
5515+
.unwrap_or(vec![Column::from_name(
5516+
MemberRules::default_count_measure_name(),
5517+
)])
5518+
{
53995519
if let Some((_, cube)) = meta.find_cube_by_column(&alias_to_cube, &column) {
54005520
if cube.lookup_measure(&column.name).is_some() {
54015521
let alias = egraph.add(LogicalPlanLanguage::ColumnExprColumn(

0 commit comments

Comments
 (0)