diff --git a/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs b/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs index bbf3427b31cd5..92fac8681dc1b 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs @@ -11,10 +11,13 @@ use datafusion::{ Limit, Partitioning, Projection, Repartition, Sort, Subquery, TableScan, TableUDFs, Union, Values, Window, }, - union_with_alias, Column, DFSchema, ExprSchemable, LogicalPlan, LogicalPlanBuilder, - Operator, + union_with_alias, Column, DFSchema, ExprRewritable, ExprSchemable, LogicalPlan, + LogicalPlanBuilder, Operator, + }, + optimizer::{ + optimizer::{OptimizerConfig, OptimizerRule}, + simplify_expressions::ConstEvaluator, }, - optimizer::optimizer::{OptimizerConfig, OptimizerRule}, scalar::ScalarValue, sql::planner::ContextProvider, }; @@ -27,7 +30,9 @@ use crate::compile::{engine::CubeContext, rewrite::rules::utils::DatePartToken}; /// Currently this includes replacing: /// - literal granularities in `DatePart` and `DateTrunc` functions /// with their normalized equivalents -/// - replacing `DATE - DATE` expressions with `DATEDIFF` equivalent +/// - `DATE - DATE` expressions with `DATEDIFF` equivalent +/// - binary operations between a literal string and an expression +/// of a different type to a string casted to that type pub struct PlanNormalize<'a> { cube_ctx: &'a CubeContext, } @@ -1189,8 +1194,10 @@ fn grouping_set_normalize( } /// Recursively normalizes binary expressions. -/// Currently this includes replacing `DATE - DATE` expressions -/// with respective `DATEDIFF` function calls. +/// Currently this includes replacing: +/// - `DATE - DATE` expressions with respective `DATEDIFF` function calls +/// - binary operations between a literal string and an expression +/// of a different type to a string casted to that type fn binary_expr_normalize( optimizer: &PlanNormalize, left: &Expr, @@ -1221,10 +1228,9 @@ fn binary_expr_normalize( // can be rewritten to something else, a binary variation still exists and would be picked // for SQL push down generation either way. This creates an issue in dialects // other than Postgres that would return INTERVAL on `DATE - DATE` expression. - if left.get_type(schema)? == DataType::Date32 - && op == Operator::Minus - && right.get_type(schema)? == DataType::Date32 - { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + if left_type == DataType::Date32 && op == Operator::Minus && right_type == DataType::Date32 { let fun = optimizer .cube_ctx .get_function_meta("datediff") @@ -1241,5 +1247,84 @@ fn binary_expr_normalize( return Ok(Expr::ScalarUDF { fun, args }); } - Ok(Expr::BinaryExpr { left, op, right }) + // Check if one side of the binary expression is a literal string. If that's the case, + // attempt to cast the string to other type based on the operator and type on the other side. + // If none of the sides is a literal string, the normalization is complete. + let (other_type, literal_on_the_left) = match (left.as_ref(), right.as_ref()) { + (_, Expr::Literal(ScalarValue::Utf8(Some(_)))) => (left_type, false), + (Expr::Literal(ScalarValue::Utf8(Some(_))), _) => (right_type, true), + _ => return Ok(Expr::BinaryExpr { left, op, right }), + }; + let Some(cast_type) = binary_expr_cast_literal(&op, &other_type) else { + return Ok(Expr::BinaryExpr { left, op, right }); + }; + if literal_on_the_left { + let new_left = evaluate_expr(optimizer, left.cast_to(&cast_type, schema)?)?; + Ok(Expr::BinaryExpr { + left: Box::new(new_left), + op, + right, + }) + } else { + let new_right = evaluate_expr(optimizer, right.cast_to(&cast_type, schema)?)?; + Ok(Expr::BinaryExpr { + left, + op, + right: Box::new(new_right), + }) + } +} + +/// Returns the type a literal string should be casted to based on the operator +/// and the type on the other side of the binary expression. +/// If no casting is needed, returns `None`. +fn binary_expr_cast_literal(op: &Operator, other_type: &DataType) -> Option { + if other_type == &DataType::Utf8 { + // If the other side is a string, casting is never required + return None; + } + + match op { + // Comparison operators should cast strings to the other side type + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom => Some(other_type.clone()), + // Arithmetic operators should cast strings to the other side type + Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::Exponentiate => Some(other_type.clone()), + // Logical operators operate only on booleans + Operator::And | Operator::Or => Some(DataType::Boolean), + // LIKE and regexes operate only on strings, no casting needed + Operator::Like + | Operator::NotLike + | Operator::ILike + | Operator::NotILike + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch => None, + // Bitwise oprators should cast strings to the other side type + Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft => Some(other_type.clone()), + // String concat allows string on either side, no casting needed + Operator::StringConcat => None, + } +} + +/// Evaluates an expression to a constant if possible. +fn evaluate_expr(optimizer: &PlanNormalize, expr: Expr) -> Result { + let execution_props = &optimizer.cube_ctx.state.execution_props; + let mut const_evaluator = ConstEvaluator::new(execution_props); + expr.rewrite(&mut const_evaluator) } diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index e6bee5af10214..193a41d385c10 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -15756,7 +15756,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), member: Some("KibanaSampleDataEcommerce.order_date".to_string()), operator: Some("inDateRange".to_string()), values: Some(vec![ - "2019-01-01 00:00:00.0".to_string(), + "2019-01-01T00:00:00.000Z".to_string(), "2019-12-31T23:59:59.999Z".to_string(), ]), or: None, @@ -15766,7 +15766,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), member: Some("KibanaSampleDataEcommerce.order_date".to_string()), operator: Some("inDateRange".to_string()), values: Some(vec![ - "2021-01-01 00:00:00.0".to_string(), + "2021-01-01T00:00:00.000Z".to_string(), "2021-12-31T23:59:59.999Z".to_string(), ]), or: None, @@ -17576,7 +17576,6 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; - println!("Generated SQL: {}", sql); assert!(sql.contains("2025-01-01")); assert!(sql.contains("customer_gender} IN (SELECT")); @@ -17586,4 +17585,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), displayable(physical_plan.as_ref()).indent() ); } + + #[tokio::test] + async fn test_string_literal_auto_cast() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + r#" + SELECT id + FROM KibanaSampleDataEcommerce + WHERE + LOWER(customer_gender) != 'unknown' + AND has_subscription = 'TRUE' + GROUP BY 1 + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; + assert!(sql.contains("${KibanaSampleDataEcommerce.has_subscription} = TRUE")); + assert!(!sql.contains("'TRUE'")); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs index c34cc60e43762..a8d5aa697f348 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs @@ -3610,6 +3610,9 @@ impl FilterRules { let year = match year { ScalarValue::Int64(Some(year)) => year, ScalarValue::Int32(Some(year)) => year as i64, + ScalarValue::Float64(Some(year)) if (1000.0..=9999.0).contains(&year) => { + year.round() as i64 + } ScalarValue::Utf8(Some(ref year_str)) if year_str.len() == 4 => { if let Ok(year) = year_str.parse::() { year diff --git a/rust/cubesql/cubesql/src/compile/test/test_filters.rs b/rust/cubesql/cubesql/src/compile/test/test_filters.rs index e07b5dbababaa..e2cb545d7c597 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_filters.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_filters.rs @@ -52,7 +52,7 @@ GROUP BY V1LoadRequestQueryFilterItem { member: Some("MultiTypeCube.dim_date0".to_string()), operator: Some("afterDate".to_string()), - values: Some(vec!["2019-01-01 00:00:00".to_string()]), + values: Some(vec!["2019-01-01T00:00:00.000Z".to_string()]), or: None, and: None, },