Skip to content

Commit 67c027b

Browse files
authored
feat(cubesql): Automatically cast literal strings in binary expressions (#9938)
1 parent b778b76 commit 67c027b

File tree

4 files changed

+135
-15
lines changed

4 files changed

+135
-15
lines changed

rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs

Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ use datafusion::{
1111
Limit, Partitioning, Projection, Repartition, Sort, Subquery, TableScan, TableUDFs,
1212
Union, Values, Window,
1313
},
14-
union_with_alias, Column, DFSchema, ExprSchemable, LogicalPlan, LogicalPlanBuilder,
15-
Operator,
14+
union_with_alias, Column, DFSchema, ExprRewritable, ExprSchemable, LogicalPlan,
15+
LogicalPlanBuilder, Operator,
16+
},
17+
optimizer::{
18+
optimizer::{OptimizerConfig, OptimizerRule},
19+
simplify_expressions::ConstEvaluator,
1620
},
17-
optimizer::optimizer::{OptimizerConfig, OptimizerRule},
1821
scalar::ScalarValue,
1922
sql::planner::ContextProvider,
2023
};
@@ -27,7 +30,9 @@ use crate::compile::{engine::CubeContext, rewrite::rules::utils::DatePartToken};
2730
/// Currently this includes replacing:
2831
/// - literal granularities in `DatePart` and `DateTrunc` functions
2932
/// with their normalized equivalents
30-
/// - replacing `DATE - DATE` expressions with `DATEDIFF` equivalent
33+
/// - `DATE - DATE` expressions with `DATEDIFF` equivalent
34+
/// - binary operations between a literal string and an expression
35+
/// of a different type with the literal string cast to that type
3136
pub struct PlanNormalize<'a> {
3237
cube_ctx: &'a CubeContext,
3338
}
@@ -1189,8 +1194,10 @@ fn grouping_set_normalize(
11891194
}
11901195

11911196
/// Recursively normalizes binary expressions.
1192-
/// Currently this includes replacing `DATE - DATE` expressions
1193-
/// with respective `DATEDIFF` function calls.
1197+
/// Currently this includes replacing:
1198+
/// - `DATE - DATE` expressions with respective `DATEDIFF` function calls
1199+
/// - binary operations between a literal string and an expression
1200+
/// of a different type with the literal string cast to that type
11941201
fn binary_expr_normalize(
11951202
optimizer: &PlanNormalize,
11961203
left: &Expr,
@@ -1221,10 +1228,9 @@ fn binary_expr_normalize(
12211228
// can be rewritten to something else, a binary variation still exists and would be picked
12221229
// for SQL push down generation either way. This creates an issue in dialects
12231230
// other than Postgres that would return INTERVAL on `DATE - DATE` expression.
1224-
if left.get_type(schema)? == DataType::Date32
1225-
&& op == Operator::Minus
1226-
&& right.get_type(schema)? == DataType::Date32
1227-
{
1231+
let left_type = left.get_type(schema)?;
1232+
let right_type = right.get_type(schema)?;
1233+
if left_type == DataType::Date32 && op == Operator::Minus && right_type == DataType::Date32 {
12281234
let fun = optimizer
12291235
.cube_ctx
12301236
.get_function_meta("datediff")
@@ -1241,5 +1247,84 @@ fn binary_expr_normalize(
12411247
return Ok(Expr::ScalarUDF { fun, args });
12421248
}
12431249

1244-
Ok(Expr::BinaryExpr { left, op, right })
1250+
// Check if one side of the binary expression is a literal string. If that's the case,
1251+
// attempt to cast the string to other type based on the operator and type on the other side.
1252+
// If none of the sides is a literal string, the normalization is complete.
1253+
let (other_type, literal_on_the_left) = match (left.as_ref(), right.as_ref()) {
1254+
(_, Expr::Literal(ScalarValue::Utf8(Some(_)))) => (left_type, false),
1255+
(Expr::Literal(ScalarValue::Utf8(Some(_))), _) => (right_type, true),
1256+
_ => return Ok(Expr::BinaryExpr { left, op, right }),
1257+
};
1258+
let Some(cast_type) = binary_expr_cast_literal(&op, &other_type) else {
1259+
return Ok(Expr::BinaryExpr { left, op, right });
1260+
};
1261+
if literal_on_the_left {
1262+
let new_left = evaluate_expr(optimizer, left.cast_to(&cast_type, schema)?)?;
1263+
Ok(Expr::BinaryExpr {
1264+
left: Box::new(new_left),
1265+
op,
1266+
right,
1267+
})
1268+
} else {
1269+
let new_right = evaluate_expr(optimizer, right.cast_to(&cast_type, schema)?)?;
1270+
Ok(Expr::BinaryExpr {
1271+
left,
1272+
op,
1273+
right: Box::new(new_right),
1274+
})
1275+
}
1276+
}
1277+
1278+
/// Returns the type a literal string should be casted to based on the operator
1279+
/// and the type on the other side of the binary expression.
1280+
/// If no casting is needed, returns `None`.
1281+
fn binary_expr_cast_literal(op: &Operator, other_type: &DataType) -> Option<DataType> {
1282+
if other_type == &DataType::Utf8 {
1283+
// If the other side is a string, casting is never required
1284+
return None;
1285+
}
1286+
1287+
match op {
1288+
// Comparison operators should cast strings to the other side type
1289+
Operator::Eq
1290+
| Operator::NotEq
1291+
| Operator::Lt
1292+
| Operator::LtEq
1293+
| Operator::Gt
1294+
| Operator::GtEq
1295+
| Operator::IsDistinctFrom
1296+
| Operator::IsNotDistinctFrom => Some(other_type.clone()),
1297+
// Arithmetic operators should cast strings to the other side type
1298+
Operator::Plus
1299+
| Operator::Minus
1300+
| Operator::Multiply
1301+
| Operator::Divide
1302+
| Operator::Modulo
1303+
| Operator::Exponentiate => Some(other_type.clone()),
1304+
// Logical operators operate only on booleans
1305+
Operator::And | Operator::Or => Some(DataType::Boolean),
1306+
// LIKE and regexes operate only on strings, no casting needed
1307+
Operator::Like
1308+
| Operator::NotLike
1309+
| Operator::ILike
1310+
| Operator::NotILike
1311+
| Operator::RegexMatch
1312+
| Operator::RegexIMatch
1313+
| Operator::RegexNotMatch
1314+
| Operator::RegexNotIMatch => None,
1315+
// Bitwise oprators should cast strings to the other side type
1316+
Operator::BitwiseAnd
1317+
| Operator::BitwiseOr
1318+
| Operator::BitwiseShiftRight
1319+
| Operator::BitwiseShiftLeft => Some(other_type.clone()),
1320+
// String concat allows string on either side, no casting needed
1321+
Operator::StringConcat => None,
1322+
}
1323+
}
1324+
1325+
/// Evaluates an expression to a constant if possible.
1326+
fn evaluate_expr(optimizer: &PlanNormalize, expr: Expr) -> Result<Expr> {
1327+
let execution_props = &optimizer.cube_ctx.state.execution_props;
1328+
let mut const_evaluator = ConstEvaluator::new(execution_props);
1329+
expr.rewrite(&mut const_evaluator)
12451330
}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15756,7 +15756,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1575615756
member: Some("KibanaSampleDataEcommerce.order_date".to_string()),
1575715757
operator: Some("inDateRange".to_string()),
1575815758
values: Some(vec![
15759-
"2019-01-01 00:00:00.0".to_string(),
15759+
"2019-01-01T00:00:00.000Z".to_string(),
1576015760
"2019-12-31T23:59:59.999Z".to_string(),
1576115761
]),
1576215762
or: None,
@@ -15766,7 +15766,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1576615766
member: Some("KibanaSampleDataEcommerce.order_date".to_string()),
1576715767
operator: Some("inDateRange".to_string()),
1576815768
values: Some(vec![
15769-
"2021-01-01 00:00:00.0".to_string(),
15769+
"2021-01-01T00:00:00.000Z".to_string(),
1577015770
"2021-12-31T23:59:59.999Z".to_string(),
1577115771
]),
1577215772
or: None,
@@ -17576,7 +17576,6 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1757617576

1757717577
let logical_plan = query_plan.as_logical_plan();
1757817578
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
17579-
println!("Generated SQL: {}", sql);
1758017579
assert!(sql.contains("2025-01-01"));
1758117580
assert!(sql.contains("customer_gender} IN (SELECT"));
1758217581

@@ -17586,4 +17585,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1758617585
displayable(physical_plan.as_ref()).indent()
1758717586
);
1758817587
}
17588+
17589+
#[tokio::test]
17590+
async fn test_string_literal_auto_cast() {
17591+
if !Rewriter::sql_push_down_enabled() {
17592+
return;
17593+
}
17594+
init_testing_logger();
17595+
17596+
let query_plan = convert_select_to_query_plan(
17597+
r#"
17598+
SELECT id
17599+
FROM KibanaSampleDataEcommerce
17600+
WHERE
17601+
LOWER(customer_gender) != 'unknown'
17602+
AND has_subscription = 'TRUE'
17603+
GROUP BY 1
17604+
"#
17605+
.to_string(),
17606+
DatabaseProtocol::PostgreSQL,
17607+
)
17608+
.await;
17609+
17610+
let logical_plan = query_plan.as_logical_plan();
17611+
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
17612+
assert!(sql.contains("${KibanaSampleDataEcommerce.has_subscription} = TRUE"));
17613+
assert!(!sql.contains("'TRUE'"));
17614+
17615+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
17616+
println!(
17617+
"Physical plan: {}",
17618+
displayable(physical_plan.as_ref()).indent()
17619+
);
17620+
}
1758917621
}

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3610,6 +3610,9 @@ impl FilterRules {
36103610
let year = match year {
36113611
ScalarValue::Int64(Some(year)) => year,
36123612
ScalarValue::Int32(Some(year)) => year as i64,
3613+
ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(&year) => {
3614+
year.round() as i64
3615+
}
36133616
ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => {
36143617
if let Ok(year) = year_str.parse::<i64>() {
36153618
year

rust/cubesql/cubesql/src/compile/test/test_filters.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ GROUP BY
5252
V1LoadRequestQueryFilterItem {
5353
member: Some("MultiTypeCube.dim_date0".to_string()),
5454
operator: Some("afterDate".to_string()),
55-
values: Some(vec!["2019-01-01 00:00:00".to_string()]),
55+
values: Some(vec!["2019-01-01T00:00:00.000Z".to_string()]),
5656
or: None,
5757
and: None,
5858
},

0 commit comments

Comments
 (0)