Skip to content

Commit 2b7cc30

Browse files
committed
fix(cubesql): Fix CASE type with NULL values
1 parent 64a7ebd commit 2b7cc30

File tree

18 files changed

+144
-65
lines changed

18 files changed

+144
-65
lines changed

packages/cubejs-backend-native/Cargo.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/Cargo.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/cubesql/cubesql/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ homepage = "https://cube.dev"
1010

1111
[dependencies]
1212
arc-swap = "1"
13-
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "a0b4a6d2953c67857a3e24343fb2cba8ce2297cd", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
13+
datafusion = { git = 'https://github.com/cube-js/arrow-datafusion.git', rev = "3c85ef6583587f5b0b037be5810e979bede9c7dc", default-features = false, features = ["regex_expressions", "unicode_expressions"] }
1414
anyhow = "1.0"
1515
thiserror = "1.0.50"
1616
cubeclient = { path = "../cubeclient" }

rust/cubesql/cubesql/e2e/tests/snapshots/e2e__tests__postgres__pg_test_types.snap

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
11
---
22
source: cubesql/e2e/tests/postgres.rs
3-
assertion_line: 297
43
expression: "self.print_query_result(res, with_description, true).await"
54
---
6-
Utf8(NULL) type: 25 (text)
5+
NULL type: 25 (text)
76
f32 type: 700 (float4)
87
f64 type: 701 (float8)
98
i16 type: 21 (int2)
@@ -27,8 +26,8 @@ interval_month_day_nano type: 1186 (interval)
2726
str_arr type: 1009 (_text)
2827
i64_arr type: 1016 (_int8)
2928
f64_arr type: 1022 (_float8)
30-
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
31-
| Utf8(NULL) | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | d0 | d2 | d5 | d10 | date | tsmp | interval_year_month | interval_day_time | interval_month_day_nano | str_arr | i64_arr | f64_arr |
32-
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
33-
| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | 1 | 1.25 | 1.25000 | 1.2500000000 | 2022-04-25 | 2022-04-25 16:25:01.164774 | 1 year 1 mons | 01:30:00 | 1 year 1 mons 1 days 01:30:00 | test1,test2 | 1,2,3 | 1.2,2.3,3.4 |
34-
+------------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
29+
+------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
30+
| NULL | f32 | f64 | i16 | u16 | i32 | u32 | i64 | u64 | bool_true | bool_false | str | d0 | d2 | d5 | d10 | date | tsmp | interval_year_month | interval_day_time | interval_month_day_nano | str_arr | i64_arr | f64_arr |
31+
+------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+
32+
| NULL | 1.234 | 1.234 | 1 | 1 | 1 | 1 | 1 | 1 | true | false | test | 1 | 1.25 | 1.25000 | 1.2500000000 | 2022-04-25 | 2022-04-25 16:25:01.164774 | 1 year 1 mons | 01:30:00 | 1 year 1 mons 1 days 01:30:00 | test1,test2 | 1,2,3 | 1.2,2.3,3.4 |
33+
+------+-------+-------+-----+-----+-----+-----+-----+-----+-----------+------------+------+----+------+---------+--------------+------------+----------------------------+---------------------+-------------------+-------------------------------+-------------+---------+-------------+

rust/cubesql/cubesql/src/compile/engine/df/coerce.rs

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
1010
| DataType::Float16
1111
| DataType::Float32
1212
| DataType::Float64
13+
| DataType::Null
1314
)
1415
}
1516

@@ -33,6 +34,9 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
3334
}
3435

3536
match (lhs_type, rhs_type) {
37+
(_, DataType::Null) => Some(lhs_type.clone()),
38+
(DataType::Null, _) => Some(rhs_type.clone()),
39+
//
3640
(_, DataType::UInt64) => Some(DataType::UInt64),
3741
(DataType::UInt64, _) => Some(DataType::UInt64),
3842
//
@@ -50,6 +54,9 @@ pub fn if_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
5054
}
5155

5256
let hack_ty = match (lhs_type, rhs_type) {
57+
(_, DataType::Null) => Some(lhs_type.clone()),
58+
(DataType::Null, _) => Some(rhs_type.clone()),
59+
//
5360
(DataType::Utf8, DataType::UInt64) => Some(DataType::Utf8),
5461
(DataType::Utf8, DataType::Int64) => Some(DataType::Utf8),
5562
//
@@ -69,6 +76,9 @@ pub fn least_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
6976
}
7077

7178
let hack_ty = match (lhs_type, rhs_type) {
79+
(_, DataType::Null) => Some(lhs_type.clone()),
80+
(DataType::Null, _) => Some(rhs_type.clone()),
81+
//
7282
(DataType::Utf8, DataType::UInt64) => Some(DataType::Utf8),
7383
(DataType::Utf8, DataType::Int64) => Some(DataType::Utf8),
7484
//

rust/cubesql/cubesql/src/compile/engine/df/columar.rs

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,13 +9,23 @@ use std::sync::Arc;
99

1010
macro_rules! if_then_else {
1111
($BUILDER_TYPE:ty, $ARRAY_TYPE:ty, $BOOLS:expr, $TRUE:expr, $FALSE:expr) => {{
12-
let true_values = $TRUE
12+
let true_values = if $TRUE.data_type() == &DataType::Null {
13+
Arc::new(<$ARRAY_TYPE>::from(vec![None; $TRUE.len()]))
14+
} else {
15+
$TRUE
16+
};
17+
let true_values = true_values
1318
.as_ref()
1419
.as_any()
1520
.downcast_ref::<$ARRAY_TYPE>()
1621
.expect("true_values downcast failed");
1722

18-
let false_values = $FALSE
23+
let false_values = if $FALSE.data_type() == &DataType::Null {
24+
Arc::new(<$ARRAY_TYPE>::from(vec![None; $FALSE.len()]))
25+
} else {
26+
$FALSE
27+
};
28+
let false_values = false_values
1929
.as_ref()
2030
.as_any()
2131
.downcast_ref::<$ARRAY_TYPE>()

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1334,6 +1334,7 @@ impl CubeScanWrapperNode {
13341334
}
13351335
// ScalarValue::IntervalMonthDayNano(_) => {}
13361336
// ScalarValue::Struct(_, _) => {}
1337+
ScalarValue::Null => ("NULL".to_string(), sql_query),
13371338
x => {
13381339
return Err(DataFusionError::Internal(format!(
13391340
"Can't generate SQL for literal: {:?}",

rust/cubesql/cubesql/src/compile/engine/udf.rs

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -395,17 +395,31 @@ pub fn create_isnull_udf() -> ScalarUDF {
395395
Arc::new(builder.finish()) as ArrayRef
396396
}
397397
2 => {
398-
if args[0].data_type() != &DataType::Utf8 || args[1].data_type() != &DataType::Utf8
399-
{
400-
return Err(DataFusionError::Internal(format!(
401-
"isnull with 2 arguments supports only (Utf8, Utf8), actual: ({}, {})",
402-
args[0].data_type(),
403-
args[1].data_type(),
404-
)));
405-
}
398+
let expr = match args[0].data_type() {
399+
DataType::Utf8 => Arc::clone(&args[0]),
400+
DataType::Null => cast(&args[0], &DataType::Utf8)?,
401+
_ => {
402+
return Err(DataFusionError::Internal(format!(
403+
"isnull with 2 arguments supports only (Utf8, Utf8), actual: ({}, {})",
404+
args[0].data_type(),
405+
args[1].data_type(),
406+
)))
407+
}
408+
};
409+
let replacement = match args[1].data_type() {
410+
DataType::Utf8 => Arc::clone(&args[1]),
411+
DataType::Null => cast(&args[1], &DataType::Utf8)?,
412+
_ => {
413+
return Err(DataFusionError::Internal(format!(
414+
"isnull with 2 arguments supports only (Utf8, Utf8), actual: ({}, {})",
415+
args[0].data_type(),
416+
args[1].data_type(),
417+
)))
418+
}
419+
};
406420

407-
let exprs = downcast_string_arg!(&args[0], "expr", i32);
408-
let replacements = downcast_string_arg!(&args[1], "replacement", i32);
421+
let exprs = downcast_string_arg!(expr, "expr", i32);
422+
let replacements = downcast_string_arg!(replacement, "replacement", i32);
409423

410424
let result = exprs
411425
.iter()

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21233,4 +21233,31 @@ limit
2123321233
assert!(sql.contains("-(EXTRACT(YEAR FROM"));
2123421234
assert!(sql.contains("* INTERVAL '1 DAY'"));
2123521235
}
21236+
21237+
#[tokio::test]
21238+
async fn test_case_mixed_values_with_null() -> Result<(), CubeError> {
21239+
init_logger();
21240+
21241+
insta::assert_snapshot!(
21242+
"test_case_mixed_values_with_null",
21243+
execute_query(
21244+
"
21245+
SELECT LEFT(ACOS(
21246+
CASE i
21247+
WHEN 0 THEN NULL
21248+
ELSE (i::float / 10.0)
21249+
END
21250+
)::text, 10) AS acos
21251+
FROM (
21252+
SELECT generate_series(0, 5) AS i
21253+
) AS t
21254+
"
21255+
.to_string(),
21256+
DatabaseProtocol::PostgreSQL
21257+
)
21258+
.await?
21259+
);
21260+
21261+
Ok(())
21262+
}
2123621263
}

rust/cubesql/cubesql/src/compile/rewrite/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,7 @@ use datafusion::{
1717
JoinConstraint, JoinType, Operator,
1818
},
1919
physical_plan::{
20-
aggregates::AggregateFunction, functions::BuiltinScalarFunction,
21-
window_functions::WindowFunction,
20+
aggregates::AggregateFunction, functions::BuiltinScalarFunction, windows::WindowFunction,
2221
},
2322
scalar::ScalarValue,
2423
};

0 commit comments

Comments
 (0)