Skip to content

Commit ea1fa9c

Browse files
committed
feat(cubesql): Support Utf8 * Interval expression
1 parent 08bffcb commit ea1fa9c

File tree

8 files changed

+108
-50
lines changed

8 files changed

+108
-50
lines changed

packages/cubejs-backend-native/Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/cubesql/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ homepage = "https://cube.dev"
1010

1111
[dependencies]
1212
arc-swap = "1"
13-
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "3fb79c32defa4ec39ccf658b6185704d1980f228", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
13+
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "5833202f5a9b69fded3cd45ffc9041ca5c404d33", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
1414
anyhow = "1.0"
1515
thiserror = "1.0"
1616
cubeclient = { path = "../cubeclient" }

rust/cubesql/cubesql/src/compile/engine/udf.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use datafusion::{
1111
StructBuilder, TimestampMicrosecondArray, TimestampMillisecondArray,
1212
TimestampNanosecondArray, TimestampSecondArray, UInt32Builder,
1313
},
14-
compute::{cast, concat},
14+
compute::{cast, cast_with_options, concat, CastOptions},
1515
datatypes::{
1616
DataType, Date32Type, Field, Float64Type, Int32Type, Int64Type, IntervalDayTimeType,
1717
IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampNanosecondType, UInt32Type,
@@ -1129,7 +1129,7 @@ macro_rules! date_math_udf {
11291129
let intervals = downcast_primitive_arg!(&$ARGS[1], "interval", $SECOND_ARG_TYPE);
11301130
let mut builder = TimestampNanosecondArray::builder(timestamps.len());
11311131
for i in 0..timestamps.len() {
1132-
if timestamps.is_null(i) {
1132+
if timestamps.is_null(i) || intervals.is_null(i) {
11331133
builder.append_null()?;
11341134
} else {
11351135
let timestamp = timestamps.value_as_datetime(i).unwrap();
@@ -1348,7 +1348,20 @@ pub fn create_interval_mul_udf() -> ScalarUDF {
13481348
let fun = make_scalar_function(move |args: &[ArrayRef]| {
13491349
assert!(args.len() == 2);
13501350

1351-
let multiplicands = downcast_primitive_arg!(args[1], "multiplicand", Int64Type);
1351+
let multiplicands = match args[1].data_type() {
1352+
DataType::Utf8 => {
1353+
cast_with_options(&args[1], &DataType::Int64, &CastOptions { safe: false })?
1354+
}
1355+
DataType::Int64 => Arc::clone(&args[1]),
1356+
t => {
1357+
return Err(DataFusionError::Execution(format!(
1358+
"unsupported multiplicand type {}",
1359+
t
1360+
)))
1361+
}
1362+
};
1363+
1364+
let multiplicands = downcast_primitive_arg!(multiplicands, "multiplicand", Int64Type);
13521365

13531366
match &args[0].data_type() {
13541367
DataType::Interval(IntervalUnit::YearMonth) => {
@@ -1425,6 +1438,18 @@ pub fn create_interval_mul_udf() -> ScalarUDF {
14251438
DataType::Interval(IntervalUnit::MonthDayNano),
14261439
DataType::Int64,
14271440
]),
1441+
TypeSignature::Exact(vec![
1442+
DataType::Interval(IntervalUnit::YearMonth),
1443+
DataType::Utf8,
1444+
]),
1445+
TypeSignature::Exact(vec![
1446+
DataType::Interval(IntervalUnit::DayTime),
1447+
DataType::Utf8,
1448+
]),
1449+
TypeSignature::Exact(vec![
1450+
DataType::Interval(IntervalUnit::MonthDayNano),
1451+
DataType::Utf8,
1452+
]),
14281453
],
14291454
Volatility::Immutable,
14301455
),

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20035,4 +20035,20 @@ ORDER BY \"COUNT(count)\" DESC"
2003520035
.sql;
2003620036
assert!(sql.contains("+ '7 day'::interval"));
2003720037
}
20038+
20039+
#[tokio::test]
20040+
async fn test_string_multiply_interval() -> Result<(), CubeError> {
20041+
init_logger();
20042+
20043+
insta::assert_snapshot!(
20044+
"test_string_multiply_interval",
20045+
execute_query(
20046+
"SELECT NULL * INTERVAL '1 day' n, '5' * INTERVAL '1 day' d5".to_string(),
20047+
DatabaseProtocol::PostgreSQL
20048+
)
20049+
.await?
20050+
);
20051+
20052+
Ok(())
20053+
}
2003820054
}

rust/cubesql/cubesql/src/compile/rewrite/analysis.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,10 +805,29 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
805805

806806
fn modify(egraph: &mut EGraph<LogicalPlanLanguage, Self>, id: Id) {
807807
if let Some(ConstantFolding::Scalar(c)) = &egraph[id].data.constant {
808+
// TODO: ideally all constants should be aliased, but this requires
809+
// rewrites to extract `.data.constant` instead of `literal_expr`.
810+
let alias_name = if c.is_null() {
811+
egraph[id]
812+
.data
813+
.original_expr
814+
.as_ref()
815+
.map(|expr| expr.name(&DFSchema::empty()).unwrap())
816+
} else {
817+
None
818+
};
808819
let c = c.clone();
809820
let value = egraph.add(LogicalPlanLanguage::LiteralExprValue(LiteralExprValue(c)));
810821
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([value]));
811-
egraph.union(id, literal_expr);
822+
if let Some(alias_name) = alias_name {
823+
let alias = egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(
824+
alias_name,
825+
)));
826+
let alias_expr = egraph.add(LogicalPlanLanguage::AliasExpr([literal_expr, alias]));
827+
egraph.union(id, alias_expr);
828+
} else {
829+
egraph.union(id, literal_expr);
830+
}
812831
}
813832
}
814833
}

rust/cubesql/cubesql/src/compile/rewrite/rules/dates.rs

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ use crate::{
33
compile::{
44
engine::provider::CubeContext,
55
rewrite::{
6-
agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, binary_expr, cast_expr,
7-
cast_expr_explicit, column_expr, fun_expr, literal_expr, literal_int, literal_string,
8-
negative_expr, rewrite, rewriter::RewriteRules, to_day_interval_expr,
9-
transforming_rewrite, transforming_rewrite_with_root, udf_expr, AliasExprAlias,
10-
CastExprDataType, LiteralExprValue, LogicalPlanLanguage,
6+
agg_fun_expr, alias_expr,
7+
analysis::{ConstantFolding, LogicalPlanAnalysis},
8+
binary_expr, cast_expr, cast_expr_explicit, column_expr, fun_expr, literal_expr,
9+
literal_int, literal_string, negative_expr, rewrite,
10+
rewriter::RewriteRules,
11+
to_day_interval_expr, transforming_rewrite, transforming_rewrite_with_root, udf_expr,
12+
AliasExprAlias, CastExprDataType, LiteralExprValue, LogicalPlanLanguage,
1113
},
1214
},
1315
var, var_iter,
@@ -331,55 +333,40 @@ impl RewriteRules for DateRules {
331333
),
332334
transforming_rewrite(
333335
"binary-expr-interval-add-right",
334-
binary_expr("?left", "+", literal_expr("?interval")),
335-
udf_expr(
336-
"date_add",
337-
vec!["?left".to_string(), literal_expr("?interval")],
338-
),
336+
binary_expr("?left", "+", "?interval"),
337+
udf_expr("date_add", vec!["?left", "?interval"]),
339338
self.transform_interval_binary_expr("?interval"),
340339
),
341340
transforming_rewrite(
342341
"binary-expr-interval-add-left",
343-
binary_expr(literal_expr("?interval"), "+", "?right"),
344-
udf_expr(
345-
"date_add",
346-
vec!["?right".to_string(), literal_expr("?interval")],
347-
),
342+
binary_expr("?interval", "+", "?right"),
343+
udf_expr("date_add", vec!["?right", "?interval"]),
348344
self.transform_interval_binary_expr("?interval"),
349345
),
350346
transforming_rewrite(
351347
"binary-expr-interval-sub",
352-
binary_expr("?left", "-", literal_expr("?interval")),
353-
udf_expr(
354-
"date_sub",
355-
vec!["?left".to_string(), literal_expr("?interval")],
356-
),
348+
binary_expr("?left", "-", "?interval"),
349+
udf_expr("date_sub", vec!["?left", "?interval"]),
357350
self.transform_interval_binary_expr("?interval"),
358351
),
359352
transforming_rewrite(
360353
"binary-expr-interval-mul-right",
361-
binary_expr("?left", "*", literal_expr("?interval")),
362-
udf_expr(
363-
"interval_mul",
364-
vec![literal_expr("?interval"), "?left".to_string()],
365-
),
354+
binary_expr("?left", "*", "?interval"),
355+
udf_expr("interval_mul", vec!["?interval", "?left"]),
366356
self.transform_interval_binary_expr("?interval"),
367357
),
368358
transforming_rewrite(
369359
"binary-expr-interval-mul-left",
370-
binary_expr(literal_expr("?interval"), "*", "?right"),
371-
udf_expr(
372-
"interval_mul",
373-
vec![literal_expr("?interval"), "?right".to_string()],
374-
),
360+
binary_expr("?interval", "*", "?right"),
361+
udf_expr("interval_mul", vec!["?interval", "?right"]),
375362
self.transform_interval_binary_expr("?interval"),
376363
),
377364
transforming_rewrite(
378365
"binary-expr-interval-neg",
379-
negative_expr(literal_expr("?interval")),
366+
negative_expr("?interval"),
380367
udf_expr(
381368
"interval_mul",
382-
vec![literal_expr("?interval"), literal_int(-1)],
369+
vec!["?interval".to_string(), literal_int(-1)],
383370
),
384371
self.transform_interval_binary_expr("?interval"),
385372
),
@@ -499,7 +486,9 @@ impl DateRules {
499486
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
500487
let interval_var = var!(interval_var);
501488
move |egraph, subst| {
502-
for interval in var_iter!(egraph[subst[interval_var]], LiteralExprValue) {
489+
if let Some(ConstantFolding::Scalar(interval)) =
490+
&egraph[subst[interval_var]].data.constant
491+
{
503492
match interval {
504493
ScalarValue::IntervalYearMonth(_)
505494
| ScalarValue::IntervalDayTime(_)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
source: cubesql/src/compile/mod.rs
3+
expression: "execute_query(\"SELECT NULL * INTERVAL '1 day' n, '5' * INTERVAL '1 day' d5\".to_string(),\n DatabaseProtocol::PostgreSQL).await?"
4+
---
5+
+------+------------------------------------------------+
6+
| n | d5 |
7+
+------+------------------------------------------------+
8+
| NULL | 0 years 0 mons 5 days 0 hours 0 mins 0.00 secs |
9+
+------+------------------------------------------------+

0 commit comments

Comments
 (0)