Skip to content

Commit 671e067

Browse files
authored
fix(cubesql): Transform IN filter with one value to = with all expressions
1 parent f135933 commit 671e067

File tree

3 files changed

+59
-33
lines changed

3 files changed

+59
-33
lines changed

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24325,4 +24325,46 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
2432524325

2432624326
Ok(())
2432724327
}
24328+
24329+
#[tokio::test]
24330+
async fn test_filter_time_dimension_equals_as_date_range() {
24331+
init_logger();
24332+
24333+
let logical_plan = convert_select_to_query_plan(
24334+
r#"
24335+
SELECT
24336+
measure(count) AS cnt,
24337+
date_trunc('month', order_date) AS dt
24338+
FROM KibanaSampleDataEcommerce
24339+
WHERE order_date IN (to_timestamp('2019-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US'))
24340+
GROUP BY 2
24341+
;"#
24342+
.to_string(),
24343+
DatabaseProtocol::PostgreSQL,
24344+
)
24345+
.await
24346+
.as_logical_plan();
24347+
24348+
assert_eq!(
24349+
logical_plan.find_cube_scan().request,
24350+
V1LoadRequestQuery {
24351+
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
24352+
dimensions: Some(vec![]),
24353+
segments: Some(vec![]),
24354+
time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension {
24355+
dimension: "KibanaSampleDataEcommerce.order_date".to_string(),
24356+
granularity: Some("month".to_string()),
24357+
date_range: Some(json!(vec![
24358+
"2019-01-01T00:00:00.000Z".to_string(),
24359+
"2019-01-01T00:00:00.000Z".to_string()
24360+
]))
24361+
}]),
24362+
order: None,
24363+
limit: None,
24364+
offset: None,
24365+
filters: None,
24366+
ungrouped: None,
24367+
}
24368+
)
24369+
}
2432824370
}

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,10 @@ fn inlist_expr(expr: impl Display, list: impl Display, negated: impl Display) ->
15361536
format!("(InListExpr {} {} {})", expr, list, negated)
15371537
}
15381538

1539+
fn inlist_expr_list(exprs: Vec<impl Display>) -> String {
1540+
flat_list_expr("InListExprList", exprs, true)
1541+
}
1542+
15391543
fn insubquery_expr(expr: impl Display, subquery: impl Display, negated: impl Display) -> String {
15401544
format!("(InSubqueryExpr {} {} {})", expr, subquery, negated)
15411545
}

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ use crate::{
77
column_expr, cube_scan, cube_scan_filters, cube_scan_filters_empty_tail, cube_scan_members,
88
dimension_expr, expr_column_name, filter, filter_member, filter_op, filter_op_filters,
99
filter_op_filters_empty_tail, filter_replacer, filter_simplify_replacer, fun_expr,
10-
fun_expr_args_legacy, fun_expr_var_arg, inlist_expr, is_not_null_expr, is_null_expr,
11-
like_expr, limit, list_rewrite, literal_bool, literal_expr, literal_int, literal_string,
12-
measure_expr, member_name_to_expr_by_alias, negative_expr, not_expr, projection, rewrite,
10+
fun_expr_args_legacy, fun_expr_var_arg, inlist_expr, inlist_expr_list, is_not_null_expr,
11+
is_null_expr, like_expr, limit, list_rewrite, literal_bool, literal_expr, literal_int,
12+
literal_string, measure_expr, member_name_to_expr_by_alias, negative_expr, not_expr,
13+
projection, rewrite,
1314
rewriter::RewriteRules,
1415
scalar_fun_expr_args_empty_tail, segment_member, time_dimension_date_range_replacer,
1516
time_dimension_expr, transform_original_expr_to_alias, transforming_chain_rewrite,
@@ -394,18 +395,18 @@ impl RewriteRules for FilterRules {
394395
transforming_rewrite(
395396
"in-filter-equal",
396397
filter_replacer(
397-
inlist_expr("?expr", "?list", "?negated"),
398+
inlist_expr("?expr", inlist_expr_list(vec!["?elem"]), "?negated"),
398399
"?alias_to_cube",
399400
"?members",
400401
"?filter_aliases",
401402
),
402403
filter_replacer(
403-
"?binary_expr",
404+
binary_expr("?expr", "?op", "?elem"),
404405
"?alias_to_cube",
405406
"?members",
406407
"?filter_aliases",
407408
),
408-
self.transform_filter_in_to_equal("?expr", "?list", "?negated", "?binary_expr"),
409+
self.transform_filter_in_to_equal("?negated", "?op"),
409410
),
410411
transforming_rewrite(
411412
"filter-in-list-datetrunc",
@@ -3288,45 +3289,24 @@ impl FilterRules {
32883289
// Transform ?expr IN (?literal) to ?expr = ?literal
32893290
fn transform_filter_in_to_equal(
32903291
&self,
3291-
expr_val: &'static str,
3292-
list_var: &'static str,
32933292
negated_var: &'static str,
3294-
return_binary_expr_var: &'static str,
3293+
op_var: &'static str,
32953294
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
3296-
let expr_val = var!(expr_val);
3297-
let list_var = var!(list_var);
32983295
let negated_var = var!(negated_var);
3299-
let return_binary_expr_var = var!(return_binary_expr_var);
3296+
let op_var = var!(op_var);
33003297

33013298
move |egraph, subst| {
3302-
let expr_id = subst[expr_val];
3303-
let scalar = match &egraph[subst[list_var]].data.constant_in_list {
3304-
Some(list) if list.len() == 1 => list[0].clone(),
3305-
_ => return false,
3306-
};
3307-
33083299
for negated in var_iter!(egraph[subst[negated_var]], InListExprNegated) {
33093300
let operator = if *negated {
33103301
Operator::NotEq
33113302
} else {
33123303
Operator::Eq
33133304
};
3314-
let operator =
3315-
egraph.add(LogicalPlanLanguage::BinaryExprOp(BinaryExprOp(operator)));
3316-
3317-
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExprValue(
3318-
LiteralExprValue(scalar),
3319-
));
3320-
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([literal_expr]));
3321-
3322-
let return_binary_expr = egraph.add(LogicalPlanLanguage::BinaryExpr([
3323-
expr_id,
3324-
operator,
3325-
literal_expr,
3326-
]));
3327-
3328-
subst.insert(return_binary_expr_var, return_binary_expr);
33293305

3306+
subst.insert(
3307+
op_var,
3308+
egraph.add(LogicalPlanLanguage::BinaryExprOp(BinaryExprOp(operator))),
3309+
);
33303310
return true;
33313311
}
33323312

0 commit comments

Comments
 (0)