Skip to content

Commit e86f4be

Browse files
authored
fix(cubesql): Avoid constant folding for current_date() function duri… (#7498)
* fix(cubesql): Avoid constant folding for current_date() function during SQL push down * More accurate test * More accurate test: add check
1 parent 9f30775 commit e86f4be

File tree

7 files changed

+199
-108
lines changed

7 files changed

+199
-108
lines changed

rust/cubesql/cubesql/src/compile/engine/udf.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4787,5 +4787,15 @@ pub fn register_fun_stubs(mut ctx: SessionContext) -> SessionContext {
47874787
rettyp = Utf8
47884788
);
47894789

4790+
register_fun_stub!(
4791+
udf,
4792+
"eval_current_date",
4793+
argc = 0,
4794+
rettyp = Date32,
4795+
vol = Stable
4796+
);
4797+
4798+
register_fun_stub!(udf, "eval_now", argc = 0, rettyp = Timestamp, vol = Stable);
4799+
47904800
ctx
47914801
}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11749,7 +11749,7 @@ ORDER BY \"COUNT(count)\" DESC"
1174911749

1175011750
assert_eq!(
1175111751
logical_plan,
11752-
"Projection: Date32(\"0\") AS COL\
11752+
"Projection: currentdate() AS COL\
1175311753
\n EmptyRelation",
1175411754
);
1175511755

@@ -19277,6 +19277,35 @@ ORDER BY \"COUNT(count)\" DESC"
1927719277
);
1927819278
}
1927919279

19280+
#[tokio::test]
19281+
async fn test_tableau_custom_date_diff() {
19282+
if !Rewriter::sql_push_down_enabled() {
19283+
return;
19284+
}
19285+
init_logger();
19286+
19287+
let query_plan = convert_select_to_query_plan(
19288+
"SELECT SUM(CAST(FLOOR(EXTRACT(EPOCH FROM CAST(CURRENT_DATE() AS TIMESTAMP)) / 86400) - FLOOR(EXTRACT(EPOCH FROM CAST(order_date AS TIMESTAMP)) / 86400) AS BIGINT)) FROM KibanaSampleDataEcommerce a"
19289+
.to_string(),
19290+
DatabaseProtocol::PostgreSQL,
19291+
)
19292+
.await;
19293+
19294+
let logical_plan = query_plan.as_logical_plan();
19295+
assert!(logical_plan
19296+
.find_cube_scan_wrapper()
19297+
.wrapped_sql
19298+
.unwrap()
19299+
.sql
19300+
.contains("CURRENT_DATE()"));
19301+
19302+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
19303+
println!(
19304+
"Physical plan: {}",
19305+
displayable(physical_plan.as_ref()).indent()
19306+
);
19307+
}
19308+
1928019309
#[tokio::test]
1928119310
async fn test_thoughtspot_pg_date_trunc_year() {
1928219311
init_logger();

rust/cubesql/cubesql/src/compile/rewrite/analysis.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ use datafusion::{
2020
},
2121
logical_plan::{Column, DFSchema, Expr},
2222
physical_plan::{
23-
functions::Volatility, planner::DefaultPhysicalPlanner, ColumnarValue, PhysicalPlanner,
23+
functions::{BuiltinScalarFunction, Volatility},
24+
planner::DefaultPhysicalPlanner,
25+
ColumnarValue, PhysicalPlanner,
2426
},
2527
scalar::ScalarValue,
2628
};
@@ -486,7 +488,23 @@ impl LogicalPlanAnalysis {
486488
)
487489
.ok()?;
488490
if let Expr::ScalarUDF { fun, .. } = &expr {
489-
if &fun.name == "str_to_date"
491+
if &fun.name == "eval_now" {
492+
Self::eval_constant_expr(
493+
&egraph,
494+
&Expr::ScalarFunction {
495+
fun: BuiltinScalarFunction::Now,
496+
args: vec![],
497+
},
498+
)
499+
} else if &fun.name == "eval_current_date" {
500+
Self::eval_constant_expr(
501+
&egraph,
502+
&Expr::ScalarFunction {
503+
fun: BuiltinScalarFunction::CurrentDate,
504+
args: vec![],
505+
},
506+
)
507+
} else if &fun.name == "str_to_date"
490508
|| &fun.name == "date_add"
491509
|| &fun.name == "date_sub"
492510
|| &fun.name == "date"
@@ -511,8 +529,12 @@ impl LogicalPlanAnalysis {
511529
.ok()?;
512530

513531
if let Expr::ScalarFunction { fun, .. } = &expr {
514-
if fun.volatility() == Volatility::Immutable
515-
|| fun.volatility() == Volatility::Stable
532+
if (fun.volatility() == Volatility::Immutable
533+
|| fun.volatility() == Volatility::Stable)
534+
&& !matches!(
535+
fun,
536+
BuiltinScalarFunction::CurrentDate | BuiltinScalarFunction::Now
537+
)
516538
{
517539
Self::eval_constant_expr(&egraph, &expr)
518540
} else {
@@ -562,6 +584,7 @@ impl LogicalPlanAnalysis {
562584
Expr::Literal(ScalarValue::Utf8(value)) => match (value, data_type) {
563585
// Timezone set in Config
564586
(Some(_), DataType::Timestamp(_, _)) => (),
587+
(Some(_), DataType::Date32 | DataType::Date64) => (),
565588
_ => return None,
566589
},
567590
_ => (),
@@ -807,15 +830,16 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
807830
if let Some(ConstantFolding::Scalar(c)) = &egraph[id].data.constant {
808831
// TODO: ideally all constants should be aliased, but this requires
809832
// rewrites to extract `.data.constant` instead of `literal_expr`.
810-
let alias_name = if c.is_null() {
811-
egraph[id]
812-
.data
813-
.original_expr
814-
.as_ref()
815-
.map(|expr| expr.name(&DFSchema::empty()).unwrap())
816-
} else {
817-
None
818-
};
833+
let alias_name =
834+
if c.is_null() || matches!(c, ScalarValue::Date32(_) | ScalarValue::Date64(_)) {
835+
egraph[id]
836+
.data
837+
.original_expr
838+
.as_ref()
839+
.map(|expr| expr.name(&DFSchema::empty()).unwrap())
840+
} else {
841+
None
842+
};
819843
let c = c.clone();
820844
let value = egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue(c)));
821845
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([value]));

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ crate::plan_to_language! {
379379
members: Vec<LogicalPlan>,
380380
aliases: Vec<(String, String)>,
381381
},
382-
FilterCastUnwrapReplacer {
382+
FilterSimplifyReplacer {
383383
filters: Vec<LogicalPlan>,
384384
},
385385
OrderReplacer {
@@ -1131,8 +1131,8 @@ fn filter_replacer(
11311131
)
11321132
}
11331133

1134-
fn filter_cast_unwrap_replacer(members: impl Display) -> String {
1135-
format!("(FilterCastUnwrapReplacer {})", members)
1134+
fn filter_simplify_replacer(members: impl Display) -> String {
1135+
format!("(FilterSimplifyReplacer {})", members)
11361136
}
11371137

11381138
fn inner_aggregate_split_replacer(members: impl Display, alias_to_cube: impl Display) -> String {

0 commit comments

Comments
 (0)