From b3e648e16d6b746efcd766bc775cbfc6d2598f33 Mon Sep 17 00:00:00 2001
From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com>
Date: Tue, 12 Aug 2025 02:39:29 +0400
Subject: [PATCH] fix(cubesql): Improve SQL push down for Athena/Presto

Signed-off-by: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com>
---
 .../src/adapter/PrestodbQuery.ts              |   4 +-
 .../engine/df/optimizers/plan_normalize.rs    | 265 ++++++++++++++----
 rust/cubesql/cubesql/src/compile/mod.rs       | 166 ++++++++++-
 .../cubesql/src/compile/query_engine.rs       |   2 +-
 .../src/compile/rewrite/rules/dates.rs        |  38 +++
 5 files changed, 413 insertions(+), 62 deletions(-)

diff --git a/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts b/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts
index 43bdd06da45b2..3d4810c66b621 100644
--- a/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts
+++ b/packages/cubejs-schema-compiler/src/adapter/PrestodbQuery.ts
@@ -157,8 +157,8 @@ export class PrestodbQuery extends BaseQuery {
     templates.expressions.timestamp_literal = 'from_iso8601_timestamp(\'{{ value }}\')';
     // Presto requires concat types to be VARCHAR
     templates.expressions.binary = '{% if op == \'||\' %}' +
-      'CAST({{ left }} AS VARCHAR) || CAST({{ right }} AS VARCHAR)' +
-      '{% else %}{{ left }} {{ op }} {{ right }}{% endif %}';
+      '(CAST({{ left }} AS VARCHAR) || CAST({{ right }} AS VARCHAR))' +
+      '{% else %}({{ left }} {{ op }} {{ right }}){% endif %}';
     delete templates.expressions.ilike;
     templates.types.string = 'VARCHAR';
     templates.types.float = 'REAL';
diff --git a/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs b/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs
index 57c7c97da72c7..bbf3427b31cd5 100644
--- a/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs
+++ b/rust/cubesql/cubesql/src/compile/engine/df/optimizers/plan_normalize.rs
@@ -1,6 +1,7 @@
 use std::{collections::HashMap, sync::Arc};
 
 use datafusion::{
+    arrow::datatypes::DataType,
     error::{DataFusionError, Result},
     logical_expr::{BuiltinScalarFunction, Expr, GroupingSet, Like},
     logical_plan::{
@@ -10,28 +11,34 @@ use datafusion::{
             Limit, Partitioning, Projection, Repartition, Sort, Subquery, TableScan, TableUDFs,
             Union, Values, Window,
         },
-        union_with_alias, Column, DFSchema, LogicalPlan, LogicalPlanBuilder,
+        union_with_alias, Column, DFSchema, ExprSchemable, LogicalPlan, LogicalPlanBuilder,
+        Operator,
     },
     optimizer::optimizer::{OptimizerConfig, OptimizerRule},
     scalar::ScalarValue,
+    sql::planner::ContextProvider,
 };
 
-use crate::compile::rewrite::rules::utils::DatePartToken;
+use crate::compile::{engine::CubeContext, rewrite::rules::utils::DatePartToken};
 
 /// PlanNormalize optimizer rule walks through the query and applies transformations
 /// to normalize the logical plan structure and expressions.
 ///
-/// Currently this includes replacing literal granularities in `DatePart` and `DateTrunc` functions
-/// with their normalized equivalents.
+/// Currently this includes replacing:
+/// - literal granularities in `DatePart` and `DateTrunc` functions
+///   with their normalized equivalents
+/// - `DATE - DATE` expressions with their `DATEDIFF` equivalent
+pub struct PlanNormalize<'a> {
+    cube_ctx: &'a CubeContext,
+}
 
-impl PlanNormalize {
-    pub fn new() -> Self {
-        Self {}
+impl<'a> PlanNormalize<'a> {
+    pub fn new(cube_ctx: &'a CubeContext) -> Self {
+        Self { cube_ctx }
     }
 }
 
-impl OptimizerRule for PlanNormalize {
+impl OptimizerRule for PlanNormalize<'_> {
     fn optimize(
         &self,
         plan: &LogicalPlan,
@@ -63,9 +70,12 @@ fn plan_normalize(
             alias,
         }) => {
             let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
+            let schema = input.schema();
             let new_expr = expr
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             let alias = alias.clone();
@@ -93,8 +103,14 @@ fn plan_normalize(
         LogicalPlan::Filter(Filter { predicate, input }) => {
             let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
-            let predicate =
-                expr_normalize(optimizer, predicate, remapped_columns, optimizer_config)?;
+            let schema = input.schema();
+            let predicate = expr_normalize(
+                optimizer,
+                predicate,
+                schema,
+                remapped_columns,
+                optimizer_config,
+            )?;
 
             LogicalPlanBuilder::from(input).filter(predicate)?.build()
         }
@@ -105,9 +121,12 @@ fn plan_normalize(
             schema: _,
         }) => {
             let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
+            let schema = input.schema();
             let new_window_expr = window_expr
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
 
             for (window_expr, new_window_expr) in window_expr.iter().zip(new_window_expr.iter()) {
@@ -132,13 +151,18 @@ fn plan_normalize(
             schema: _,
         }) => {
             let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
+            let schema = input.schema();
             let new_group_expr = group_expr
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             let new_aggr_expr = aggr_expr
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
 
             *remapped_columns = HashMap::new();
@@ -167,11 +191,14 @@ fn plan_normalize(
         }
 
         LogicalPlan::Sort(Sort { expr, input }) => {
+            let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
+            let schema = input.schema();
             let expr = expr
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
-            let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
 
             LogicalPlanBuilder::from(input).sort(expr)?.build()
         }
@@ -262,13 +289,20 @@ fn plan_normalize(
             partitioning_scheme,
         }) => {
             let input = plan_normalize(optimizer, input, remapped_columns, optimizer_config)?;
+            let schema = input.schema();
             let partitioning_scheme = match partitioning_scheme {
                 Partitioning::RoundRobinBatch(n) => Partitioning::RoundRobinBatch(*n),
                Partitioning::Hash(exprs, n) => {
                    let exprs = exprs
                        .iter()
                        .map(|expr| {
-                            expr_normalize(optimizer, expr, remapped_columns, optimizer_config)
+                            expr_normalize(
+                                optimizer,
+                                expr,
+                                schema,
+                                remapped_columns,
+                                optimizer_config,
+                            )
                        })
                        .collect::<Result<Vec<_>>>()?;
                    Partitioning::Hash(exprs, *n)
                }
@@ -321,7 +355,15 @@ fn plan_normalize(
             let projected_schema = Arc::clone(projected_schema);
             let filters = filters
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(
+                        optimizer,
+                        expr,
+                        &projected_schema,
+                        remapped_columns,
+                        optimizer_config,
+                    )
+                })
                 .collect::<Result<Vec<_>>>()?;
             let fetch = *fetch;
@@ -391,13 +433,19 @@ fn plan_normalize(
 
         p @ LogicalPlan::DropTable(_) => Ok(p.clone()),
 
-        LogicalPlan::Values(Values { schema: _, values }) => {
+        LogicalPlan::Values(Values { schema, values }) => {
             let values = values
                 .iter()
                 .map(|row| {
                     row.iter()
                         .map(|expr| {
-                            expr_normalize(optimizer, expr, remapped_columns, optimizer_config)
+                            expr_normalize(
+                                optimizer,
+                                expr,
+                                schema,
+                                remapped_columns,
+                                optimizer_config,
+                            )
                         })
                         .collect::<Result<Vec<_>>>()
                 })
@@ -448,11 +496,14 @@ fn plan_normalize(
                 remapped_columns,
                 optimizer_config,
             )?);
+            let schema = input.schema();
             let new_expr = expr
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
-            let schema = build_table_udf_schema(&input, &new_expr)?;
+            let new_schema = build_table_udf_schema(&input, &new_expr)?;
 
             for (expr, new_expr) in expr.iter().zip(new_expr.iter()) {
                 let old_name = expr.name(&DFSchema::empty())?;
@@ -467,7 +518,7 @@ fn plan_normalize(
             Ok(LogicalPlan::TableUDFs(TableUDFs {
                 expr: new_expr,
                 input,
-                schema,
+                schema: new_schema,
             }))
         }
@@ -492,6 +543,7 @@ fn plan_normalize(
 fn expr_normalize(
     optimizer: &PlanNormalize,
     expr: &Expr,
+    schema: &DFSchema,
     remapped_columns: &HashMap<Column, Column>,
     optimizer_config: &OptimizerConfig,
 ) -> Result<Expr> {
@@ -500,6 +552,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -522,22 +575,15 @@ fn expr_normalize(
         e @ Expr::Literal(..) => Ok(e.clone()),
 
-        Expr::BinaryExpr { left, op, right } => {
-            let left = Box::new(expr_normalize(
-                optimizer,
-                left,
-                remapped_columns,
-                optimizer_config,
-            )?);
-            let op = *op;
-            let right = Box::new(expr_normalize(
-                optimizer,
-                right,
-                remapped_columns,
-                optimizer_config,
-            )?);
-            Ok(Expr::BinaryExpr { left, op, right })
-        }
+        Expr::BinaryExpr { left, op, right } => binary_expr_normalize(
+            optimizer,
+            left,
+            op,
+            right,
+            schema,
+            remapped_columns,
+            optimizer_config,
+        ),
 
         Expr::AnyExpr {
             left,
@@ -548,6 +594,7 @@ fn expr_normalize(
             let left = Box::new(expr_normalize(
                 optimizer,
                 left,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -555,6 +602,7 @@ fn expr_normalize(
             let right = Box::new(expr_normalize(
                 optimizer,
                 right,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -577,12 +625,14 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let pattern = Box::new(expr_normalize(
                 optimizer,
                 pattern,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -605,12 +655,14 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let pattern = Box::new(expr_normalize(
                 optimizer,
                 pattern,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -633,12 +685,14 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let pattern = Box::new(expr_normalize(
                 optimizer,
                 pattern,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -655,6 +709,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -665,6 +720,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -675,6 +731,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -685,6 +742,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -695,12 +753,14 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let key = Box::new(expr_normalize(
                 optimizer,
                 key,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -716,6 +776,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -723,12 +784,14 @@ fn expr_normalize(
             let low = Box::new(expr_normalize(
                 optimizer,
                 low,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let high = Box::new(expr_normalize(
                 optimizer,
                 high,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -751,6 +814,7 @@ fn expr_normalize(
                     Ok::<_, DataFusionError>(Box::new(expr_normalize(
                         optimizer,
                         e,
+                        schema,
                         remapped_columns,
                         optimizer_config,
                     )?))
@@ -763,12 +827,14 @@ fn expr_normalize(
                         Box::new(expr_normalize(
                             optimizer,
                             when,
+                            schema,
                             remapped_columns,
                             optimizer_config,
                         )?),
                         Box::new(expr_normalize(
                             optimizer,
                             then,
+                            schema,
                             remapped_columns,
                             optimizer_config,
                         )?),
@@ -781,6 +847,7 @@ fn expr_normalize(
                     Ok::<_, DataFusionError>(Box::new(expr_normalize(
                         optimizer,
                         e,
+                        schema,
                         remapped_columns,
                         optimizer_config,
                     )?))
@@ -797,6 +864,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -808,6 +876,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -823,6 +892,7 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -840,6 +910,7 @@ fn expr_normalize(
                 optimizer,
                 fun,
                 args,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?;
@@ -850,7 +921,9 @@ fn expr_normalize(
             let fun = Arc::clone(fun);
             let args = args
                 .iter()
-                .map(|arg| expr_normalize(optimizer, arg, remapped_columns, optimizer_config))
+                .map(|arg| {
+                    expr_normalize(optimizer, arg, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             Ok(Expr::ScalarUDF { fun, args })
         }
@@ -859,7 +932,9 @@ fn expr_normalize(
             let fun = Arc::clone(fun);
             let args = args
                 .iter()
-                .map(|arg| expr_normalize(optimizer, arg, remapped_columns, optimizer_config))
+                .map(|arg| {
+                    expr_normalize(optimizer, arg, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             Ok(Expr::TableUDF { fun, args })
         }
@@ -873,14 +948,18 @@ fn expr_normalize(
             let fun = fun.clone();
             let args = args
                 .iter()
-                .map(|arg| expr_normalize(optimizer, arg, remapped_columns, optimizer_config))
+                .map(|arg| {
+                    expr_normalize(optimizer, arg, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             let distinct = *distinct;
             let within_group = within_group
                 .as_ref()
                 .map(|expr| {
                     expr.iter()
-                        .map(|e| expr_normalize(optimizer, e, remapped_columns, optimizer_config))
+                        .map(|e| {
+                            expr_normalize(optimizer, e, schema, remapped_columns, optimizer_config)
+                        })
                         .collect::<Result<Vec<_>>>()
                 })
                 .transpose()?;
@@ -902,15 +981,21 @@ fn expr_normalize(
             let fun = fun.clone();
             let args = args
                 .iter()
-                .map(|arg| expr_normalize(optimizer, arg, remapped_columns, optimizer_config))
+                .map(|arg| {
+                    expr_normalize(optimizer, arg, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             let partition_by = partition_by
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             let order_by = order_by
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             let window_frame = *window_frame;
             Ok(Expr::WindowFunction {
@@ -926,7 +1011,9 @@ fn expr_normalize(
             let fun = Arc::clone(fun);
             let args = args
                 .iter()
-                .map(|arg| expr_normalize(optimizer, arg, remapped_columns, optimizer_config))
+                .map(|arg| {
+                    expr_normalize(optimizer, arg, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             Ok(Expr::AggregateUDF { fun, args })
         }
@@ -939,12 +1026,13 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let list = list
                 .iter()
-                .map(|e| expr_normalize(optimizer, e, remapped_columns, optimizer_config))
+                .map(|e| expr_normalize(optimizer, e, schema, remapped_columns, optimizer_config))
                 .collect::<Result<Vec<_>>>()?;
             let negated = *negated;
             Ok(Expr::InList {
@@ -962,12 +1050,14 @@ fn expr_normalize(
             let expr = Box::new(expr_normalize(
                 optimizer,
                 expr,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
             let subquery = Box::new(expr_normalize(
                 optimizer,
                 subquery,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?);
@@ -987,6 +1077,7 @@ fn expr_normalize(
             let grouping_set = grouping_set_normalize(
                 optimizer,
                 grouping_set,
+                schema,
                 remapped_columns,
                 optimizer_config,
             )?;
@@ -1015,13 +1106,14 @@ fn expr_normalize(
 fn scalar_function_normalize(
     optimizer: &PlanNormalize,
     fun: &BuiltinScalarFunction,
     args: &[Expr],
+    schema: &DFSchema,
     remapped_columns: &HashMap<Column, Column>,
     optimizer_config: &OptimizerConfig,
 ) -> Result<(BuiltinScalarFunction, Vec<Expr>)> {
     let fun = fun.clone();
     let mut args = args
         .iter()
-        .map(|arg| expr_normalize(optimizer, arg, remapped_columns, optimizer_config))
+        .map(|arg| expr_normalize(optimizer, arg, schema, remapped_columns, optimizer_config))
         .collect::<Result<Vec<_>>>()?;
 
     // If the function is `DatePart` or `DateTrunc` and the first argument is a literal string,
@@ -1048,6 +1140,7 @@ fn scalar_function_normalize(
 fn grouping_set_normalize(
     optimizer: &PlanNormalize,
     grouping_set: &GroupingSet,
+    schema: &DFSchema,
     remapped_columns: &HashMap<Column, Column>,
     optimizer_config: &OptimizerConfig,
 ) -> Result<GroupingSet> {
@@ -1055,7 +1148,9 @@ fn grouping_set_normalize(
         GroupingSet::Rollup(exprs) => {
             let exprs = exprs
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             Ok(GroupingSet::Rollup(exprs))
         }
@@ -1063,7 +1158,9 @@ fn grouping_set_normalize(
         GroupingSet::Cube(exprs) => {
             let exprs = exprs
                 .iter()
-                .map(|expr| expr_normalize(optimizer, expr, remapped_columns, optimizer_config))
+                .map(|expr| {
+                    expr_normalize(optimizer, expr, schema, remapped_columns, optimizer_config)
+                })
                 .collect::<Result<Vec<_>>>()?;
             Ok(GroupingSet::Cube(exprs))
         }
@@ -1075,7 +1172,13 @@ fn grouping_set_normalize(
                 Ok(exprs
                     .iter()
                     .map(|expr| {
-                        expr_normalize(optimizer, expr, remapped_columns, optimizer_config)
+                        expr_normalize(
+                            optimizer,
+                            expr,
+                            schema,
+                            remapped_columns,
+                            optimizer_config,
+                        )
                     })
                     .collect::<Result<Vec<_>>>()?)
             })
@@ -1084,3 +1187,59 @@ fn grouping_set_normalize(
         }
     }
 }
+
+/// Recursively normalizes binary expressions.
+/// Currently this includes replacing `DATE - DATE` expressions
+/// with their respective `DATEDIFF` function calls.
+fn binary_expr_normalize(
+    optimizer: &PlanNormalize,
+    left: &Expr,
+    op: &Operator,
+    right: &Expr,
+    schema: &DFSchema,
+    remapped_columns: &HashMap<Column, Column>,
+    optimizer_config: &OptimizerConfig,
+) -> Result<Expr> {
+    let left = Box::new(expr_normalize(
+        optimizer,
+        left,
+        schema,
+        remapped_columns,
+        optimizer_config,
+    )?);
+    let op = *op;
+    let right = Box::new(expr_normalize(
+        optimizer,
+        right,
+        schema,
+        remapped_columns,
+        optimizer_config,
+    )?);
+
+    // Check if the expression is `DATE - DATE` and replace it with a `DATEDIFF` call
+    // that has the same semantics. The rationale for doing this in the optimizer rather
+    // than in rewrites is that even if the expression is rewritten to something else,
+    // the binary variation still exists and would be picked for SQL push down generation
+    // either way. This creates an issue in dialects other than Postgres, which return
+    // an INTERVAL for a `DATE - DATE` expression.
+    if left.get_type(schema)? == DataType::Date32
+        && op == Operator::Minus
+        && right.get_type(schema)? == DataType::Date32
+    {
+        let fun = optimizer
+            .cube_ctx
+            .get_function_meta("datediff")
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "Unable to find 'datediff' function in cube context".to_string(),
+                )
+            })?;
+        let args = vec![
+            Expr::Literal(ScalarValue::Utf8(Some("day".to_string()))),
+            *right,
+            *left,
+        ];
+        return Ok(Expr::ScalarUDF { fun, args });
+    }
+
+    Ok(Expr::BinaryExpr { left, op, right })
+}
diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs
index 2b3066d80ecaa..595c7b958af2f 100644
--- a/rust/cubesql/cubesql/src/compile/mod.rs
+++ b/rust/cubesql/cubesql/src/compile/mod.rs
@@ -14092,7 +14092,7 @@ ORDER BY "source"."str0" ASC
     async fn test_thoughtspot_pg_extract_day_of_quarter() {
         init_testing_logger();
 
-        let logical_plan = convert_select_to_query_plan(
+        let query_plan = convert_select_to_query_plan(
             r#"
 SELECT (CAST("ta_1"."order_date" AS date) - CAST((CAST(EXTRACT(YEAR FROM "ta_1"."order_date") || '-' || EXTRACT(MONTH FROM "ta_1"."order_date") || '-01' AS DATE) + (((MOD(CAST((EXTRACT(MONTH FROM "ta_1"."order_date") - 1) AS numeric), 3) + 1) - 1) * -1) * INTERVAL '1 month') AS date) + 1) AS "ca_1",
@@ -14106,8 +14106,22 @@ ORDER BY "source"."str0" ASC
             .to_string(),
             DatabaseProtocol::PostgreSQL,
         )
-        .await
-        .as_logical_plan();
+        .await;
+        let logical_plan = query_plan.as_logical_plan();
+
+        if Rewriter::sql_push_down_enabled() {
+            let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
+            assert!(sql.contains("DATEDIFF(day,"));
+            assert!(sql.contains("EXTRACT(year"));
+            assert!(sql.contains("EXTRACT(month"));
+
+            let physical_plan = query_plan.as_physical_plan().await.unwrap();
+            println!(
+                "Physical plan: {}",
+                displayable(physical_plan.as_ref()).indent()
+            );
+            return;
+        }
 
         assert_eq!(
             logical_plan.find_cube_scan().request,
@@ -17249,9 +17263,9 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
             DatabaseProtocol::PostgreSQL,
             vec![(
                 "expressions/binary".to_string(),
-                "'{% if op == \'||\' %}CAST({{ left }} AS VARCHAR) || \
-                CAST({{ right }} AS VARCHAR)\
-                {% else %}{{ left }} {{ op }} {{ right }}{% endif %}'"
+                "{% if op == \'||\' %}(CAST({{ left }} AS VARCHAR) || \
+                CAST({{ right }} AS VARCHAR))\
+                {% else %}({{ left }} {{ op }} {{ right }}){% endif %}"
                     .to_string(),
             )],
         )
@@ -17314,4 +17328,144 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
         let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
         assert!(sql.contains("DATE_DIFF('day', "));
     }
+
+    #[tokio::test]
+    async fn test_athena_binary_expr_brackets() {
+        if !Rewriter::sql_push_down_enabled() {
+            return;
+        }
+        init_testing_logger();
+
+        let query_plan = convert_select_to_query_plan_customized(
+            r#"
+            SELECT
+                CAST(
+                    EXTRACT(YEAR FROM "ta_1"."order_date") || '-' ||
+                    ((FLOOR(((EXTRACT(MONTH FROM "ta_1"."order_date") - 1) / NULLIF(3, 0))) * 3) + 1)
+                    || '-01' AS DATE
+                ) AS "ca_1",
+                COALESCE(sum("ta_1"."sumPrice"), 0) AS "ca_2"
+            FROM "ovr"."public"."KibanaSampleDataEcommerce" AS "ta_1"
+            GROUP BY "ca_1"
+            ORDER BY "ca_1" ASC NULLS LAST
+            LIMIT 10000
+            "#
+            .to_string(),
+            DatabaseProtocol::PostgreSQL,
+            vec![(
+                "expressions/binary".to_string(),
+                "{% if op == \'||\' %}(CAST({{ left }} AS VARCHAR) || \
+                CAST({{ right }} AS VARCHAR))\
+                {% else %}({{ left }} {{ op }} {{ right }}){% endif %}"
+                    .to_string(),
+            )],
+        )
+        .await;
+
+        let physical_plan = query_plan.as_physical_plan().await.unwrap();
+        println!(
+            "Physical plan: {}",
+            displayable(physical_plan.as_ref()).indent()
+        );
+
+        let logical_plan = query_plan.as_logical_plan();
+        let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
+        assert!(sql.contains(" - 1) / 3)"));
+    }
+
+    #[tokio::test]
+    async fn test_athena_date_part_over_age() {
+        if !Rewriter::sql_push_down_enabled() {
+            return;
+        }
+        init_testing_logger();
+
+        let query_plan = convert_select_to_query_plan_customized(
+            r#"
+            SELECT
+                DATE_TRUNC('MONTH', CAST("ta_1"."order_date" AS date)) AS "ca_1",
+                COALESCE(sum("ta_1"."sumPrice"), 0) AS "ca_2",
+                min(CAST(
+                    DATE_PART('year', AGE("ta_1"."order_date", DATE '1970-01-01')) * 12 +
+                    DATE_PART('month', AGE("ta_1"."order_date", DATE '1970-01-01'))
+                    AS int
+                )) AS "ca_3",
+                min(
+                    (MOD(CAST((EXTRACT(MONTH FROM "ta_1"."order_date") - 1) AS numeric), 3) + 1)
+                ) AS "ca_4",
+                min(CEIL((EXTRACT(MONTH FROM "ta_1"."order_date") / NULLIF(3.0, 0.0)))) AS "ca_5"
+            FROM "ovr"."public"."KibanaSampleDataEcommerce" AS "ta_1"
+            GROUP BY "ca_1"
+            ORDER BY "ca_1" ASC NULLS LAST
+            LIMIT 5000
+            "#
+            .to_string(),
+            DatabaseProtocol::PostgreSQL,
+            vec![(
+                "functions/DATEDIFF".to_string(),
+                "DATE_DIFF('{{ date_part }}', {{ args[1] }}, {{ args[2] }})".to_string(),
+            )],
+        )
+        .await;
+
+        let physical_plan = query_plan.as_physical_plan().await.unwrap();
+        println!(
+            "Physical plan: {}",
+            displayable(physical_plan.as_ref()).indent()
+        );
+
+        let logical_plan = query_plan.as_logical_plan();
+        let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
+        assert!(sql.contains("DATE_DIFF('month', "));
+    }
+
+    #[tokio::test]
+    async fn test_athena_date_minus_date() {
+        if !Rewriter::sql_push_down_enabled() {
+            return;
+        }
+        init_testing_logger();
+
+        let query_plan = convert_select_to_query_plan_customized(
+            r#"
+            SELECT
+                DATE_TRUNC('week', "ta_1"."order_date") AS "ca_1",
+                COALESCE(sum("ta_1"."sumPrice"), 0) AS "ca_2",
+                min((CEIL((((
+                    CAST("ta_1"."order_date" AS date) - CAST(DATE '1970-01-01' AS date) + 1 + 7) - 4
+                ) / NULLIF(7.0, 0.0))) - 1)) AS "ca_3",
+                min(FLOOR(((
+                    EXTRACT(DAY FROM (
+                        ("ta_1"."order_date") + ((4 - (MOD(CAST((
+                            CAST("ta_1"."order_date" AS date) - CAST(DATE '1970-01-01' AS date) + 3
+                        ) AS numeric), 7) + 1))) * INTERVAL '1 day')) + 6
+                ) / NULLIF(7, 0)))) AS "ca_4",
+                min(
+                    (MOD(CAST((EXTRACT(MONTH FROM "ta_1"."order_date") - 1) AS numeric), 3) + 1)
+                ) AS "ca_6",
+                min(CEIL((EXTRACT(MONTH FROM "ta_1"."order_date") / NULLIF(3.0, 0.0)))) AS "ca_7"
+            FROM "ovr"."public"."KibanaSampleDataEcommerce" AS "ta_1"
+            GROUP BY "ca_1"
+            ORDER BY "ca_1" ASC NULLS LAST
+            LIMIT 5000
+            "#
+            .to_string(),
+            DatabaseProtocol::PostgreSQL,
+            vec![(
+                "functions/DATEDIFF".to_string(),
+                "DATE_DIFF('{{ date_part }}', {{ args[1] }}, {{ args[2] }})".to_string(),
+            )],
+        )
+        .await;
+
+        let physical_plan = query_plan.as_physical_plan().await.unwrap();
+        println!(
+            "Physical plan: {}",
+            displayable(physical_plan.as_ref()).indent()
+        );
+
+        let logical_plan = query_plan.as_logical_plan();
+        let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
+        assert!(sql.contains("DATE_DIFF('day', "));
+    }
 }
diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs
index 62f4e56d806ce..705b001d484e7 100644
--- a/rust/cubesql/cubesql/src/compile/query_engine.rs
+++ b/rust/cubesql/cubesql/src/compile/query_engine.rs
@@ -140,7 +140,7 @@ pub trait QueryEngine {
         let optimizer_config = OptimizerConfig::new();
         let optimizers: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![
-            Arc::new(PlanNormalize::new()),
+            Arc::new(PlanNormalize::new(&cube_ctx)),
             Arc::new(ProjectionDropOut::new()),
             Arc::new(FilterPushDown::new()),
             Arc::new(SortPushDown::new()),
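
Aside: the `DATE - DATE` rewrite in binary_expr_normalize() above is argument-order-sensitive: `left - right` becomes `datediff('day', right, left)`, since `DATEDIFF(part, start, end)` counts from `start` to `end`. The standalone sketch below mirrors that transformation on a toy expression type; `SqlType`, the toy `Expr`, and `normalize` are illustrative stand-ins, not cubesql's actual DataFusion types.

// Toy model of the `DATE - DATE` -> DATEDIFF normalization; the types are
// stand-ins for DataFusion's Expr and DFSchema machinery.
#[derive(Debug, Clone, PartialEq)]
enum SqlType {
    Date,
    Interval,
    Int,
}

#[derive(Debug, Clone)]
enum Expr {
    Column(String, SqlType),
    Binary { left: Box<Expr>, op: char, right: Box<Expr> },
    ScalarFn { name: String, args: Vec<Expr> },
    Literal(String),
}

fn expr_type(expr: &Expr) -> SqlType {
    match expr {
        Expr::Column(_, t) => t.clone(),
        // `DATE - DATE` yields INTERVAL in Postgres, which is exactly the
        // case the optimizer rewrites away for other dialects.
        Expr::Binary { .. } => SqlType::Interval,
        Expr::ScalarFn { .. } | Expr::Literal(_) => SqlType::Int,
    }
}

/// Rewrites `DATE - DATE` into `datediff('day', right, left)`. Note the
/// argument order: DATEDIFF counts from its second argument (start) to its
/// third (end), so `left - right` maps to (right, left).
fn normalize(expr: Expr) -> Expr {
    match expr {
        Expr::Binary { left, op: '-', right }
            if expr_type(&left) == SqlType::Date && expr_type(&right) == SqlType::Date =>
        {
            Expr::ScalarFn {
                name: "datediff".to_string(),
                args: vec![Expr::Literal("day".to_string()), *right, *left],
            }
        }
        other => other,
    }
}

fn main() {
    let diff = Expr::Binary {
        left: Box::new(Expr::Column("order_date".into(), SqlType::Date)),
        op: '-',
        right: Box::new(Expr::Column("ship_date".into(), SqlType::Date)),
    };
    // Prints a datediff('day', ship_date, order_date) call.
    println!("{:?}", normalize(diff));
}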
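
Aside: the parenthesized `expressions/binary` template in PrestodbQuery.ts matters because templates compose recursively: without the outer brackets, splicing `EXTRACT(MONTH FROM d) - 1` into a `/ 3` parent would render with the wrong precedence, which is what `test_athena_binary_expr_brackets` asserts via `" - 1) / 3)"`. Below is a minimal sketch of the two template branches in plain Rust; `render_binary` is a hypothetical helper, not the Jinja engine cube.js actually uses.

/// Mimics the updated `expressions/binary` template for Presto:
/// `{% if op == '||' %}(CAST({{ left }} AS VARCHAR) || CAST({{ right }} AS VARCHAR))
///  {% else %}({{ left }} {{ op }} {{ right }}){% endif %}`
fn render_binary(left: &str, op: &str, right: &str) -> String {
    if op == "||" {
        // Presto requires concat operands to be VARCHAR.
        format!("(CAST({left} AS VARCHAR) || CAST({right} AS VARCHAR))")
    } else {
        // Wrapping every binary expression keeps precedence intact when the
        // result is substituted into a parent expression.
        format!("({left} {op} {right})")
    }
}

fn main() {
    // Build (EXTRACT(MONTH FROM d) - 1) / 3 bottom-up, as nested templates would.
    let month = "EXTRACT(MONTH FROM d)";
    let minus = render_binary(month, "-", "1");
    let div = render_binary(&minus, "/", "3");
    // Without the added brackets this would render as
    // EXTRACT(MONTH FROM d) - 1 / 3, which parses as month - (1 / 3).
    assert_eq!(div, "((EXTRACT(MONTH FROM d) - 1) / 3)");
    println!("{div}");
}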
diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs
index b0b7702be2ce5..38520eaa10d02 100644
--- a/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs
+++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs
@@ -409,6 +409,44 @@ impl RewriteRules for DateRules {
                     "?new_granularity",
                 ),
             ),
+            // The AGE function is a popular choice for this kind of date arithmetic,
+            // but most dialects do not support it in SQL push down.
+            transforming_rewrite_with_root(
+                "thoughtspot-date-part-over-age-as-datediff-month",
+                binary_expr(
+                    binary_expr(
+                        self.fun_expr(
+                            "DatePart",
+                            vec![
+                                literal_string("year"),
+                                udf_expr("age", vec!["?newer_date", "?older_date"]),
+                            ],
+                        ),
+                        "*",
+                        literal_int(12),
+                    ),
+                    "+",
+                    self.fun_expr(
+                        "DatePart",
+                        vec![
+                            literal_string("month"),
+                            udf_expr("age", vec!["?newer_date", "?older_date"]),
+                        ],
+                    ),
+                ),
+                alias_expr(
+                    udf_expr(
+                        "datediff",
+                        vec![
+                            literal_string("month"),
+                            "?older_date".to_string(),
+                            "?newer_date".to_string(),
+                        ],
+                    ),
+                    "?alias",
+                ),
+                self.transform_root_alias("?alias"),
+            ),
         ]
     }
 }
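
Aside: the `thoughtspot-date-part-over-age-as-datediff-month` rule above relies on folding AGE's calendar interval into a single month count: `DATE_PART('year', AGE(newer, older)) * 12 + DATE_PART('month', AGE(newer, older))`. A minimal sketch of that arithmetic follows, with a plain (years, months) pair standing in for AGE's interval output; the exact month-boundary semantics of a dialect's DATEDIFF may still differ from AGE on partial months, which is the trade-off the rewrite accepts to make the expression pushable.

/// Folds Postgres AGE output into a single month count, the way the matched
/// SQL pattern does: DATE_PART('year', AGE(n, o)) * 12 + DATE_PART('month', AGE(n, o)).
/// The (years, months) pair is a stand-in for AGE's calendar interval.
fn months_from_age(age_years: i64, age_months: i64) -> i64 {
    age_years * 12 + age_months
}

fn main() {
    // AGE('2025-08-12', '1970-01-01') = 55 years 7 mons 11 days; the pattern
    // folds it to 667 months, which datediff('month', older, newer) expresses
    // directly and which dialects can push down.
    assert_eq!(months_from_age(55, 7), 667);
    println!("months: {}", months_from_age(55, 7));
}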