Skip to content

Commit 1dfa10d

Browse files
authored
fix(cubesql): Generate typed null literals (#9238)
This avoids expressions like SUM(NULL), which are ambiguous in PostgreSQL.
1 parent 75095e1 commit 1dfa10d

File tree

2 files changed

+153
-33
lines changed

2 files changed

+153
-33
lines changed

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 114 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ pub struct SqlGenerationResult {
511511
static DATE_PART_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new("^[A-Za-z_ ]+$").unwrap());
512512

513513
macro_rules! generate_sql_for_timestamp {
514-
(@generic $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
514+
(@generic $literal:ident, $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
515515
if let Some($value) = $value {
516516
let value = $value_block.to_rfc3339_opts(SecondsFormat::Millis, true);
517517
(
@@ -527,27 +527,27 @@ macro_rules! generate_sql_for_timestamp {
527527
$sql_query,
528528
)
529529
} else {
530-
("NULL".to_string(), $sql_query)
530+
(Self::generate_null_for_literal($sql_generator, &$literal)?, $sql_query)
531531
}
532532
};
533-
($value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
533+
($literal:ident, $value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
534534
generate_sql_for_timestamp!(
535-
@generic $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
535+
@generic $literal, $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
536536
)
537537
};
538-
($value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
538+
($literal:ident, $value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
539539
generate_sql_for_timestamp!(
540-
@generic $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
540+
@generic $literal, $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
541541
)
542542
};
543-
($value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
543+
($literal:ident, $value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
544544
generate_sql_for_timestamp!(
545-
@generic $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
545+
@generic $literal, $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
546546
)
547547
};
548-
($value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
548+
($literal:ident, $value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
549549
generate_sql_for_timestamp!(
550-
@generic $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
550+
@generic $literal, $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
551551
)
552552
};
553553
}
@@ -1606,6 +1606,27 @@ impl CubeScanWrapperNode {
16061606
.map_err(|e| DataFusionError::Internal(format!("Can't generate SQL for type: {}", e)))
16071607
}
16081608

1609+
/// Renders a SQL `NULL`, optionally wrapped in a cast to a concrete type.
///
/// When `data_type` is `None`, the bare string `"NULL"` is returned.
/// Otherwise the type is rendered with `Self::generate_sql_type` and the
/// literal `"NULL"` is passed through `Self::generate_sql_cast_expr`, so the
/// emitted expression carries an explicit type.
///
/// # Errors
///
/// Propagates any `DataFusionError` returned by `generate_sql_type` or
/// `generate_sql_cast_expr`.
fn generate_typed_null(
1610+
sql_generator: Arc<dyn SqlGenerator>,
1611+
data_type: Option<DataType>,
1612+
) -> result::Result<String, DataFusionError> {
1613+
// No type information available: fall back to an untyped NULL.
let Some(data_type) = data_type else {
1614+
return Ok("NULL".to_string());
1615+
};
1616+
1617+
// The generator is needed twice (type rendering, then cast rendering),
// hence the `Arc` clone for the first call.
let sql_type = Self::generate_sql_type(sql_generator.clone(), data_type)?;
1618+
let result = Self::generate_sql_cast_expr(sql_generator, "NULL".to_string(), sql_type)?;
1619+
Ok(result)
1620+
}
1621+
1622+
/// Renders a typed SQL `NULL` whose type is taken from a scalar literal.
///
/// The data type is read from `value` via `ScalarValue::get_datatype` and
/// forwarded to `Self::generate_typed_null`, so the result is always built
/// with a concrete type (the `None` branch of `generate_typed_null` is never
/// taken from here).
///
/// # Errors
///
/// Propagates any `DataFusionError` returned by `generate_typed_null`.
fn generate_null_for_literal(
1623+
sql_generator: Arc<dyn SqlGenerator>,
1624+
value: &ScalarValue,
1625+
) -> result::Result<String, DataFusionError> {
1626+
// Only the literal's type is used here; its payload is not inspected.
let data_type = value.get_datatype();
1627+
Self::generate_typed_null(sql_generator, Some(data_type))
1628+
}
1629+
16091630
/// This function is async to be able to call to JS land,
16101631
/// in case some SQL generation could not be done through Jinja
16111632
pub fn generate_sql_for_expr<'ctx>(
@@ -2083,15 +2104,25 @@ impl CubeScanWrapperNode {
20832104
))
20842105
})
20852106
})
2086-
.unwrap_or(Ok("NULL".to_string()))?,
2107+
.transpose()?
2108+
.map_or_else(
2109+
|| Self::generate_null_for_literal(sql_generator, &literal),
2110+
Ok,
2111+
)?,
20872112
sql_query,
20882113
),
20892114
ScalarValue::Float32(f) => (
2090-
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
2115+
f.map(|f| format!("{f}")).map_or_else(
2116+
|| Self::generate_null_for_literal(sql_generator, &literal),
2117+
Ok,
2118+
)?,
20912119
sql_query,
20922120
),
20932121
ScalarValue::Float64(f) => (
2094-
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
2122+
f.map(|f| format!("{f}")).map_or_else(
2123+
|| Self::generate_null_for_literal(sql_generator, &literal),
2124+
Ok,
2125+
)?,
20952126
sql_query,
20962127
),
20972128
ScalarValue::Decimal128(x, precision, scale) => {
@@ -2111,49 +2142,76 @@ impl CubeScanWrapperNode {
21112142
data_type,
21122143
)?
21132144
} else {
2114-
"NULL".to_string()
2145+
Self::generate_null_for_literal(sql_generator, &literal)?
21152146
},
21162147
sql_query,
21172148
)
21182149
}
21192150
ScalarValue::Int8(x) => (
2120-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2151+
x.map(|x| format!("{x}")).map_or_else(
2152+
|| Self::generate_null_for_literal(sql_generator, &literal),
2153+
Ok,
2154+
)?,
21212155
sql_query,
21222156
),
21232157
ScalarValue::Int16(x) => (
2124-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2158+
x.map(|x| format!("{x}")).map_or_else(
2159+
|| Self::generate_null_for_literal(sql_generator, &literal),
2160+
Ok,
2161+
)?,
21252162
sql_query,
21262163
),
21272164
ScalarValue::Int32(x) => (
2128-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2165+
x.map(|x| format!("{x}")).map_or_else(
2166+
|| Self::generate_null_for_literal(sql_generator, &literal),
2167+
Ok,
2168+
)?,
21292169
sql_query,
21302170
),
21312171
ScalarValue::Int64(x) => (
2132-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2172+
x.map(|x| format!("{x}")).map_or_else(
2173+
|| Self::generate_null_for_literal(sql_generator, &literal),
2174+
Ok,
2175+
)?,
21332176
sql_query,
21342177
),
21352178
ScalarValue::UInt8(x) => (
2136-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2179+
x.map(|x| format!("{x}")).map_or_else(
2180+
|| Self::generate_null_for_literal(sql_generator, &literal),
2181+
Ok,
2182+
)?,
21372183
sql_query,
21382184
),
21392185
ScalarValue::UInt16(x) => (
2140-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2186+
x.map(|x| format!("{x}")).map_or_else(
2187+
|| Self::generate_null_for_literal(sql_generator, &literal),
2188+
Ok,
2189+
)?,
21412190
sql_query,
21422191
),
21432192
ScalarValue::UInt32(x) => (
2144-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2193+
x.map(|x| format!("{x}")).map_or_else(
2194+
|| Self::generate_null_for_literal(sql_generator, &literal),
2195+
Ok,
2196+
)?,
21452197
sql_query,
21462198
),
21472199
ScalarValue::UInt64(x) => (
2148-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2200+
x.map(|x| format!("{x}")).map_or_else(
2201+
|| Self::generate_null_for_literal(sql_generator, &literal),
2202+
Ok,
2203+
)?,
21492204
sql_query,
21502205
),
21512206
ScalarValue::Utf8(x) => {
21522207
if x.is_some() {
21532208
let param_index = sql_query.add_value(x);
21542209
(format!("${}$", param_index), sql_query)
21552210
} else {
2156-
("NULL".into(), sql_query)
2211+
(
2212+
Self::generate_typed_null(sql_generator, Some(DataType::Utf8))?,
2213+
sql_query,
2214+
)
21572215
}
21582216
}
21592217
// ScalarValue::LargeUtf8(_) => {}
@@ -2194,42 +2252,54 @@ impl CubeScanWrapperNode {
21942252
sql_query,
21952253
)
21962254
} else {
2197-
("NULL".to_string(), sql_query)
2255+
(
2256+
Self::generate_null_for_literal(sql_generator, &literal)?,
2257+
sql_query,
2258+
)
21982259
}
21992260
}
22002261
// ScalarValue::Date64(_) => {}
22012262

22022263
// generate_sql_for_timestamp will call Utc constructors, so only support UTC zone for now
22032264
// DataFusion can return "UTC" for stuff like `NOW()` during constant folding
2204-
ScalarValue::TimestampSecond(s, tz)
2265+
ScalarValue::TimestampSecond(s, ref tz)
22052266
if matches!(tz.as_deref(), None | Some("UTC")) =>
22062267
{
2207-
generate_sql_for_timestamp!(s, timestamp, sql_generator, sql_query)
2268+
generate_sql_for_timestamp!(
2269+
literal,
2270+
s,
2271+
timestamp,
2272+
sql_generator,
2273+
sql_query
2274+
)
22082275
}
2209-
ScalarValue::TimestampMillisecond(ms, tz)
2276+
ScalarValue::TimestampMillisecond(ms, ref tz)
22102277
if matches!(tz.as_deref(), None | Some("UTC")) =>
22112278
{
22122279
generate_sql_for_timestamp!(
2280+
literal,
22132281
ms,
22142282
timestamp_millis_opt,
22152283
sql_generator,
22162284
sql_query
22172285
)
22182286
}
2219-
ScalarValue::TimestampMicrosecond(ms, tz)
2287+
ScalarValue::TimestampMicrosecond(ms, ref tz)
22202288
if matches!(tz.as_deref(), None | Some("UTC")) =>
22212289
{
22222290
generate_sql_for_timestamp!(
2291+
literal,
22232292
ms,
22242293
timestamp_micros,
22252294
sql_generator,
22262295
sql_query
22272296
)
22282297
}
2229-
ScalarValue::TimestampNanosecond(nanoseconds, tz)
2298+
ScalarValue::TimestampNanosecond(nanoseconds, ref tz)
22302299
if matches!(tz.as_deref(), None | Some("UTC")) =>
22312300
{
22322301
generate_sql_for_timestamp!(
2302+
literal,
22332303
nanoseconds,
22342304
timestamp_nanos,
22352305
sql_generator,
@@ -2253,7 +2323,10 @@ impl CubeScanWrapperNode {
22532323
sql_query,
22542324
)
22552325
} else {
2256-
("NULL".to_string(), sql_query)
2326+
(
2327+
Self::generate_null_for_literal(sql_generator, &literal)?,
2328+
sql_query,
2329+
)
22572330
}
22582331
}
22592332
ScalarValue::IntervalDayTime(x) => {
@@ -2263,7 +2336,10 @@ impl CubeScanWrapperNode {
22632336
let generated_sql = decomposed.generate_interval_sql(&templates)?;
22642337
(generated_sql, sql_query)
22652338
} else {
2266-
("NULL".to_string(), sql_query)
2339+
(
2340+
Self::generate_null_for_literal(sql_generator, &literal)?,
2341+
sql_query,
2342+
)
22672343
}
22682344
}
22692345
ScalarValue::IntervalMonthDayNano(x) => {
@@ -2273,11 +2349,16 @@ impl CubeScanWrapperNode {
22732349
let generated_sql = decomposed.generate_interval_sql(&templates)?;
22742350
(generated_sql, sql_query)
22752351
} else {
2276-
("NULL".to_string(), sql_query)
2352+
(
2353+
Self::generate_null_for_literal(sql_generator, &literal)?,
2354+
sql_query,
2355+
)
22772356
}
22782357
}
22792358
// ScalarValue::Struct(_, _) => {}
2280-
ScalarValue::Null => ("NULL".to_string(), sql_query),
2359+
ScalarValue::Null => {
2360+
(Self::generate_typed_null(sql_generator, None)?, sql_query)
2361+
}
22812362
x => {
22822363
return Err(DataFusionError::Internal(format!(
22832364
"Can't generate SQL for literal: {:?}",

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,3 +1559,42 @@ async fn wrapper_cast_limit_explicit_members() {
15591559
assert_eq!(request.measures.unwrap().len(), 1);
15601560
assert_eq!(request.dimensions.unwrap().len(), 0);
15611561
}
1562+
1563+
/// Checks that a NULL argument to `SUM` in wrapped SQL is emitted with an
/// explicit cast (`SUM(CAST(NULL AS DOUBLE))`) rather than as a bare `NULL`.
#[tokio::test]
1564+
async fn wrapper_typed_null() {
1565+
// The assertion inspects wrapped SQL, so the test is a no-op when SQL
// push down is disabled.
if !Rewriter::sql_push_down_enabled() {
1566+
return;
1567+
}
1568+
init_testing_logger();
1569+
1570+
// NOTE(review): presumably `NULLIF(0.0, 0.0)` reaches SQL generation as a
// NULL Float64 literal — confirm against the planner.
let query_plan = convert_select_to_query_plan(
1571+
// language=PostgreSQL
1572+
r#"
1573+
SELECT
1574+
dim_str0,
1575+
AVG(avgPrice),
1576+
CASE
1577+
WHEN SUM((NULLIF(0.0, 0.0))) IS NOT NULL THEN SUM((NULLIF(0.0, 0.0)))
1578+
ELSE 0
1579+
END
1580+
FROM MultiTypeCube
1581+
GROUP BY 1
1582+
;"#
1583+
.to_string(),
1584+
DatabaseProtocol::PostgreSQL,
1585+
)
1586+
.await;
1587+
1588+
// `unwrap` makes any physical planning failure panic the test.
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1589+
println!(
1590+
"Physical plan: {}",
1591+
displayable(physical_plan.as_ref()).indent()
1592+
);
1593+
1594+
// The generated wrapped SQL must contain the typed NULL inside SUM.
assert!(query_plan
1595+
.as_logical_plan()
1596+
.find_cube_scan_wrapped_sql()
1597+
.wrapped_sql
1598+
.sql
1599+
.contains("SUM(CAST(NULL AS DOUBLE))"));
1600+
}

0 commit comments

Comments
 (0)