Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ use datafusion::{
Limit, Partitioning, Projection, Repartition, Sort, Subquery, TableScan, TableUDFs,
Union, Values, Window,
},
union_with_alias, Column, DFSchema, ExprSchemable, LogicalPlan, LogicalPlanBuilder,
Operator,
union_with_alias, Column, DFSchema, ExprRewritable, ExprSchemable, LogicalPlan,
LogicalPlanBuilder, Operator,
},
optimizer::{
optimizer::{OptimizerConfig, OptimizerRule},
simplify_expressions::ConstEvaluator,
},
optimizer::optimizer::{OptimizerConfig, OptimizerRule},
scalar::ScalarValue,
sql::planner::ContextProvider,
};
Expand All @@ -27,7 +30,9 @@ use crate::compile::{engine::CubeContext, rewrite::rules::utils::DatePartToken};
/// Currently this includes replacing:
/// - literal granularities in `DatePart` and `DateTrunc` functions
/// with their normalized equivalents
/// - replacing `DATE - DATE` expressions with `DATEDIFF` equivalent
/// - `DATE - DATE` expressions with `DATEDIFF` equivalent
/// - binary operations between a literal string and an expression
/// of a different type to a string casted to that type
pub struct PlanNormalize<'a> {
cube_ctx: &'a CubeContext,
}
Expand Down Expand Up @@ -1189,8 +1194,10 @@ fn grouping_set_normalize(
}

/// Recursively normalizes binary expressions.
/// Currently this includes replacing `DATE - DATE` expressions
/// with respective `DATEDIFF` function calls.
/// Currently this includes replacing:
/// - `DATE - DATE` expressions with respective `DATEDIFF` function calls
/// - binary operations between a literal string and an expression
/// of a different type to a string casted to that type
fn binary_expr_normalize(
optimizer: &PlanNormalize,
left: &Expr,
Expand Down Expand Up @@ -1221,10 +1228,9 @@ fn binary_expr_normalize(
// can be rewritten to something else, a binary variation still exists and would be picked
// for SQL push down generation either way. This creates an issue in dialects
// other than Postgres that would return INTERVAL on `DATE - DATE` expression.
if left.get_type(schema)? == DataType::Date32
&& op == Operator::Minus
&& right.get_type(schema)? == DataType::Date32
{
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
if left_type == DataType::Date32 && op == Operator::Minus && right_type == DataType::Date32 {
let fun = optimizer
.cube_ctx
.get_function_meta("datediff")
Expand All @@ -1241,5 +1247,84 @@ fn binary_expr_normalize(
return Ok(Expr::ScalarUDF { fun, args });
}

Ok(Expr::BinaryExpr { left, op, right })
// Check if one side of the binary expression is a literal string. If that's the case,
// attempt to cast the string to other type based on the operator and type on the other side.
// If none of the sides is a literal string, the normalization is complete.
let (other_type, literal_on_the_left) = match (left.as_ref(), right.as_ref()) {
(_, Expr::Literal(ScalarValue::Utf8(Some(_)))) => (left_type, false),
(Expr::Literal(ScalarValue::Utf8(Some(_))), _) => (right_type, true),
_ => return Ok(Expr::BinaryExpr { left, op, right }),
};
let Some(cast_type) = binary_expr_cast_literal(&op, &other_type) else {
return Ok(Expr::BinaryExpr { left, op, right });
};
if literal_on_the_left {
let new_left = evaluate_expr(optimizer, left.cast_to(&cast_type, schema)?)?;
Ok(Expr::BinaryExpr {
left: Box::new(new_left),
op,
right,
})
} else {
let new_right = evaluate_expr(optimizer, right.cast_to(&cast_type, schema)?)?;
Ok(Expr::BinaryExpr {
left,
op,
right: Box::new(new_right),
})
}
}

/// Returns the type a literal string should be casted to based on the operator
/// and the type on the other side of the binary expression.
/// If no casting is needed, returns `None`.
fn binary_expr_cast_literal(op: &Operator, other_type: &DataType) -> Option<DataType> {
if other_type == &DataType::Utf8 {
// If the other side is a string, casting is never required
return None;
}

match op {
// Comparison operators should cast strings to the other side type
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom => Some(other_type.clone()),
// Arithmetic operators should cast strings to the other side type
Operator::Plus
| Operator::Minus
| Operator::Multiply
| Operator::Divide
| Operator::Modulo
| Operator::Exponentiate => Some(other_type.clone()),
// Logical operators operate only on booleans
Operator::And | Operator::Or => Some(DataType::Boolean),
// LIKE and regexes operate only on strings, no casting needed
Operator::Like
| Operator::NotLike
| Operator::ILike
| Operator::NotILike
| Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
| Operator::RegexNotIMatch => None,
// Bitwise oprators should cast strings to the other side type
Operator::BitwiseAnd
| Operator::BitwiseOr
| Operator::BitwiseShiftRight
| Operator::BitwiseShiftLeft => Some(other_type.clone()),
// String concat allows string on either side, no casting needed
Operator::StringConcat => None,
}
}

/// Evaluates an expression to a constant if possible.
fn evaluate_expr(optimizer: &PlanNormalize, expr: Expr) -> Result<Expr> {
let execution_props = &optimizer.cube_ctx.state.execution_props;
let mut const_evaluator = ConstEvaluator::new(execution_props);
expr.rewrite(&mut const_evaluator)
}
38 changes: 35 additions & 3 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15756,7 +15756,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
member: Some("KibanaSampleDataEcommerce.order_date".to_string()),
operator: Some("inDateRange".to_string()),
values: Some(vec![
"2019-01-01 00:00:00.0".to_string(),
"2019-01-01T00:00:00.000Z".to_string(),
"2019-12-31T23:59:59.999Z".to_string(),
]),
or: None,
Expand All @@ -15766,7 +15766,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
member: Some("KibanaSampleDataEcommerce.order_date".to_string()),
operator: Some("inDateRange".to_string()),
values: Some(vec![
"2021-01-01 00:00:00.0".to_string(),
"2021-01-01T00:00:00.000Z".to_string(),
"2021-12-31T23:59:59.999Z".to_string(),
]),
or: None,
Expand Down Expand Up @@ -17576,7 +17576,6 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
println!("Generated SQL: {}", sql);
assert!(sql.contains("2025-01-01"));
assert!(sql.contains("customer_gender} IN (SELECT"));

Expand All @@ -17586,4 +17585,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
displayable(physical_plan.as_ref()).indent()
);
}

#[tokio::test]
async fn test_string_literal_auto_cast() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let query_plan = convert_select_to_query_plan(
r#"
SELECT id
FROM KibanaSampleDataEcommerce
WHERE
LOWER(customer_gender) != 'unknown'
AND has_subscription = 'TRUE'
GROUP BY 1
"#
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
assert!(sql.contains("${KibanaSampleDataEcommerce.has_subscription} = TRUE"));
assert!(!sql.contains("'TRUE'"));

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);
}
}
3 changes: 3 additions & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3610,6 +3610,9 @@ impl FilterRules {
let year = match year {
ScalarValue::Int64(Some(year)) => year,
ScalarValue::Int32(Some(year)) => year as i64,
ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(&year) => {
year.round() as i64
}
ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => {
if let Ok(year) = year_str.parse::<i64>() {
year
Expand Down
2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/src/compile/test/test_filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ GROUP BY
V1LoadRequestQueryFilterItem {
member: Some("MultiTypeCube.dim_date0".to_string()),
operator: Some("afterDate".to_string()),
values: Some(vec!["2019-01-01 00:00:00".to_string()]),
values: Some(vec!["2019-01-01T00:00:00.000Z".to_string()]),
or: None,
and: None,
},
Expand Down
Loading