Skip to content

Commit 02ceacc

Browse files
committed
fix(cubesql): Generate typed null literals
This is to avoid expressions like SUM(NULL), which are ambiguous in PostgreSQL
1 parent 84f90c0 commit 02ceacc

File tree

2 files changed

+153
-33
lines changed

2 files changed

+153
-33
lines changed

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 114 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ pub struct SqlGenerationResult {
514514
static DATE_PART_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new("^[A-Za-z_ ]+$").unwrap());
515515

516516
macro_rules! generate_sql_for_timestamp {
517-
(@generic $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
517+
(@generic $literal:ident, $value:ident, $value_block:expr, $sql_generator:expr, $sql_query:expr) => {
518518
if let Some($value) = $value {
519519
let value = $value_block.to_rfc3339_opts(SecondsFormat::Millis, true);
520520
(
@@ -530,27 +530,27 @@ macro_rules! generate_sql_for_timestamp {
530530
$sql_query,
531531
)
532532
} else {
533-
("NULL".to_string(), $sql_query)
533+
(Self::generate_null_for_literal($sql_generator, &$literal)?, $sql_query)
534534
}
535535
};
536-
($value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
536+
($literal:ident, $value:ident, timestamp, $sql_generator:expr, $sql_query:expr) => {
537537
generate_sql_for_timestamp!(
538-
@generic $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
538+
@generic $literal, $value, { Utc.timestamp_opt($value as i64, 0).unwrap() }, $sql_generator, $sql_query
539539
)
540540
};
541-
($value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
541+
($literal:ident, $value:ident, timestamp_millis_opt, $sql_generator:expr, $sql_query:expr) => {
542542
generate_sql_for_timestamp!(
543-
@generic $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
543+
@generic $literal, $value, { Utc.timestamp_millis_opt($value as i64).unwrap() }, $sql_generator, $sql_query
544544
)
545545
};
546-
($value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
546+
($literal:ident, $value:ident, timestamp_micros, $sql_generator:expr, $sql_query:expr) => {
547547
generate_sql_for_timestamp!(
548-
@generic $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
548+
@generic $literal, $value, { Utc.timestamp_micros($value as i64).unwrap() }, $sql_generator, $sql_query
549549
)
550550
};
551-
($value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
551+
($literal:ident, $value:ident, $method:ident, $sql_generator:expr, $sql_query:expr) => {
552552
generate_sql_for_timestamp!(
553-
@generic $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
553+
@generic $literal, $value, { Utc.$method($value as i64) }, $sql_generator, $sql_query
554554
)
555555
};
556556
}
@@ -1599,6 +1599,27 @@ impl CubeScanWrapperNode {
15991599
.map_err(|e| DataFusionError::Internal(format!("Can't generate SQL for type: {}", e)))
16001600
}
16011601

1602+
/// Renders a SQL `NULL` literal, cast to the SQL type generated for `data_type` when one is given.
///
/// With `data_type == None` the bare string `"NULL"` is returned unchanged.
/// Otherwise the SQL type name comes from `Self::generate_sql_type` and the
/// string `"NULL"` is passed through `Self::generate_sql_cast_expr`.
///
/// # Errors
///
/// Propagates any error returned by `generate_sql_type` or `generate_sql_cast_expr`.
fn generate_typed_null(
1603+
sql_generator: Arc<dyn SqlGenerator>,
1604+
data_type: Option<DataType>,
1605+
) -> result::Result<String, DataFusionError> {
1606+
// No type information available: fall back to an untyped NULL.
let Some(data_type) = data_type else {
1607+
return Ok("NULL".to_string());
1608+
};
1609+
1610+
// The generator is cloned because it is needed again for the cast expression below.
let sql_type = Self::generate_sql_type(sql_generator.clone(), data_type)?;
1611+
let result = Self::generate_sql_cast_expr(sql_generator, "NULL".to_string(), sql_type)?;
1612+
Ok(result)
1613+
}
1614+
1615+
/// Renders a typed SQL `NULL` for a scalar literal.
///
/// The data type is taken from `value.get_datatype()` and forwarded to
/// `Self::generate_typed_null`, so the result is always produced through the
/// typed (cast) path rather than the bare `"NULL"` fallback.
///
/// # Errors
///
/// Propagates any error returned by `generate_typed_null`.
fn generate_null_for_literal(
1616+
sql_generator: Arc<dyn SqlGenerator>,
1617+
value: &ScalarValue,
1618+
) -> result::Result<String, DataFusionError> {
1619+
let data_type = value.get_datatype();
1620+
Self::generate_typed_null(sql_generator, Some(data_type))
1621+
}
1622+
16021623
/// This function is async to be able to call to JS land,
16031624
/// in case some SQL generation could not be done through Jinja
16041625
pub fn generate_sql_for_expr<'ctx>(
@@ -2076,15 +2097,25 @@ impl CubeScanWrapperNode {
20762097
))
20772098
})
20782099
})
2079-
.unwrap_or(Ok("NULL".to_string()))?,
2100+
.transpose()?
2101+
.map_or_else(
2102+
|| Self::generate_null_for_literal(sql_generator, &literal),
2103+
Ok,
2104+
)?,
20802105
sql_query,
20812106
),
20822107
ScalarValue::Float32(f) => (
2083-
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
2108+
f.map(|f| format!("{f}")).map_or_else(
2109+
|| Self::generate_null_for_literal(sql_generator, &literal),
2110+
Ok,
2111+
)?,
20842112
sql_query,
20852113
),
20862114
ScalarValue::Float64(f) => (
2087-
f.map(|f| format!("{}", f)).unwrap_or("NULL".to_string()),
2115+
f.map(|f| format!("{f}")).map_or_else(
2116+
|| Self::generate_null_for_literal(sql_generator, &literal),
2117+
Ok,
2118+
)?,
20882119
sql_query,
20892120
),
20902121
ScalarValue::Decimal128(x, precision, scale) => {
@@ -2104,49 +2135,76 @@ impl CubeScanWrapperNode {
21042135
data_type,
21052136
)?
21062137
} else {
2107-
"NULL".to_string()
2138+
Self::generate_null_for_literal(sql_generator, &literal)?
21082139
},
21092140
sql_query,
21102141
)
21112142
}
21122143
ScalarValue::Int8(x) => (
2113-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2144+
x.map(|x| format!("{x}")).map_or_else(
2145+
|| Self::generate_null_for_literal(sql_generator, &literal),
2146+
Ok,
2147+
)?,
21142148
sql_query,
21152149
),
21162150
ScalarValue::Int16(x) => (
2117-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2151+
x.map(|x| format!("{x}")).map_or_else(
2152+
|| Self::generate_null_for_literal(sql_generator, &literal),
2153+
Ok,
2154+
)?,
21182155
sql_query,
21192156
),
21202157
ScalarValue::Int32(x) => (
2121-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2158+
x.map(|x| format!("{x}")).map_or_else(
2159+
|| Self::generate_null_for_literal(sql_generator, &literal),
2160+
Ok,
2161+
)?,
21222162
sql_query,
21232163
),
21242164
ScalarValue::Int64(x) => (
2125-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2165+
x.map(|x| format!("{x}")).map_or_else(
2166+
|| Self::generate_null_for_literal(sql_generator, &literal),
2167+
Ok,
2168+
)?,
21262169
sql_query,
21272170
),
21282171
ScalarValue::UInt8(x) => (
2129-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2172+
x.map(|x| format!("{x}")).map_or_else(
2173+
|| Self::generate_null_for_literal(sql_generator, &literal),
2174+
Ok,
2175+
)?,
21302176
sql_query,
21312177
),
21322178
ScalarValue::UInt16(x) => (
2133-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2179+
x.map(|x| format!("{x}")).map_or_else(
2180+
|| Self::generate_null_for_literal(sql_generator, &literal),
2181+
Ok,
2182+
)?,
21342183
sql_query,
21352184
),
21362185
ScalarValue::UInt32(x) => (
2137-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2186+
x.map(|x| format!("{x}")).map_or_else(
2187+
|| Self::generate_null_for_literal(sql_generator, &literal),
2188+
Ok,
2189+
)?,
21382190
sql_query,
21392191
),
21402192
ScalarValue::UInt64(x) => (
2141-
x.map(|x| format!("{}", x)).unwrap_or("NULL".to_string()),
2193+
x.map(|x| format!("{x}")).map_or_else(
2194+
|| Self::generate_null_for_literal(sql_generator, &literal),
2195+
Ok,
2196+
)?,
21422197
sql_query,
21432198
),
21442199
ScalarValue::Utf8(x) => {
21452200
if x.is_some() {
21462201
let param_index = sql_query.add_value(x);
21472202
(format!("${}$", param_index), sql_query)
21482203
} else {
2149-
("NULL".into(), sql_query)
2204+
(
2205+
Self::generate_typed_null(sql_generator, Some(DataType::Utf8))?,
2206+
sql_query,
2207+
)
21502208
}
21512209
}
21522210
// ScalarValue::LargeUtf8(_) => {}
@@ -2187,42 +2245,54 @@ impl CubeScanWrapperNode {
21872245
sql_query,
21882246
)
21892247
} else {
2190-
("NULL".to_string(), sql_query)
2248+
(
2249+
Self::generate_null_for_literal(sql_generator, &literal)?,
2250+
sql_query,
2251+
)
21912252
}
21922253
}
21932254
// ScalarValue::Date64(_) => {}
21942255

21952256
// generate_sql_for_timestamp will call Utc constructors, so only support UTC zone for now
21962257
// DataFusion can return "UTC" for stuff like `NOW()` during constant folding
2197-
ScalarValue::TimestampSecond(s, tz)
2258+
ScalarValue::TimestampSecond(s, ref tz)
21982259
if matches!(tz.as_deref(), None | Some("UTC")) =>
21992260
{
2200-
generate_sql_for_timestamp!(s, timestamp, sql_generator, sql_query)
2261+
generate_sql_for_timestamp!(
2262+
literal,
2263+
s,
2264+
timestamp,
2265+
sql_generator,
2266+
sql_query
2267+
)
22012268
}
2202-
ScalarValue::TimestampMillisecond(ms, tz)
2269+
ScalarValue::TimestampMillisecond(ms, ref tz)
22032270
if matches!(tz.as_deref(), None | Some("UTC")) =>
22042271
{
22052272
generate_sql_for_timestamp!(
2273+
literal,
22062274
ms,
22072275
timestamp_millis_opt,
22082276
sql_generator,
22092277
sql_query
22102278
)
22112279
}
2212-
ScalarValue::TimestampMicrosecond(ms, tz)
2280+
ScalarValue::TimestampMicrosecond(ms, ref tz)
22132281
if matches!(tz.as_deref(), None | Some("UTC")) =>
22142282
{
22152283
generate_sql_for_timestamp!(
2284+
literal,
22162285
ms,
22172286
timestamp_micros,
22182287
sql_generator,
22192288
sql_query
22202289
)
22212290
}
2222-
ScalarValue::TimestampNanosecond(nanoseconds, tz)
2291+
ScalarValue::TimestampNanosecond(nanoseconds, ref tz)
22232292
if matches!(tz.as_deref(), None | Some("UTC")) =>
22242293
{
22252294
generate_sql_for_timestamp!(
2295+
literal,
22262296
nanoseconds,
22272297
timestamp_nanos,
22282298
sql_generator,
@@ -2246,7 +2316,10 @@ impl CubeScanWrapperNode {
22462316
sql_query,
22472317
)
22482318
} else {
2249-
("NULL".to_string(), sql_query)
2319+
(
2320+
Self::generate_null_for_literal(sql_generator, &literal)?,
2321+
sql_query,
2322+
)
22502323
}
22512324
}
22522325
ScalarValue::IntervalDayTime(x) => {
@@ -2256,7 +2329,10 @@ impl CubeScanWrapperNode {
22562329
let generated_sql = decomposed.generate_interval_sql(&templates)?;
22572330
(generated_sql, sql_query)
22582331
} else {
2259-
("NULL".to_string(), sql_query)
2332+
(
2333+
Self::generate_null_for_literal(sql_generator, &literal)?,
2334+
sql_query,
2335+
)
22602336
}
22612337
}
22622338
ScalarValue::IntervalMonthDayNano(x) => {
@@ -2266,11 +2342,16 @@ impl CubeScanWrapperNode {
22662342
let generated_sql = decomposed.generate_interval_sql(&templates)?;
22672343
(generated_sql, sql_query)
22682344
} else {
2269-
("NULL".to_string(), sql_query)
2345+
(
2346+
Self::generate_null_for_literal(sql_generator, &literal)?,
2347+
sql_query,
2348+
)
22702349
}
22712350
}
22722351
// ScalarValue::Struct(_, _) => {}
2273-
ScalarValue::Null => ("NULL".to_string(), sql_query),
2352+
ScalarValue::Null => {
2353+
(Self::generate_typed_null(sql_generator, None)?, sql_query)
2354+
}
22742355
x => {
22752356
return Err(DataFusionError::Internal(format!(
22762357
"Can't generate SQL for literal: {:?}",

rust/cubesql/cubesql/src/compile/test/test_wrapper.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,3 +1456,42 @@ WHERE
14561456
assert_eq!(request.measures.unwrap().len(), 1);
14571457
assert_eq!(request.dimensions.unwrap().len(), 0);
14581458
}
1459+
1460+
/// Checks that a NULL literal inside an aggregate is emitted with an explicit cast
/// in the SQL generated for the wrapper.
///
/// The test returns early (and so passes trivially) when SQL push down is disabled.
#[tokio::test]
1461+
async fn wrapper_typed_null() {
1462+
if !Rewriter::sql_push_down_enabled() {
1463+
return;
1464+
}
1465+
init_testing_logger();
1466+
1467+
// NOTE(review): `NULLIF(0.0, 0.0)` is presumably constant-folded into a float NULL
// literal before SQL generation — confirm against the planner.
let query_plan = convert_select_to_query_plan(
1468+
// language=PostgreSQL
1469+
r#"
1470+
SELECT
1471+
dim_str0,
1472+
AVG(avgPrice),
1473+
CASE
1474+
WHEN SUM((NULLIF(0.0, 0.0))) IS NOT NULL THEN SUM((NULLIF(0.0, 0.0)))
1475+
ELSE 0
1476+
END
1477+
FROM MultiTypeCube
1478+
GROUP BY 1
1479+
;"#
1480+
.to_string(),
1481+
DatabaseProtocol::PostgreSQL,
1482+
)
1483+
.await;
1484+
1485+
// Building the physical plan must succeed; the plan is printed for debugging only.
let physical_plan = query_plan.as_physical_plan().await.unwrap();
1486+
println!(
1487+
"Physical plan: {}",
1488+
displayable(physical_plan.as_ref()).indent()
1489+
);
1490+
1491+
// The generated SQL must aggregate over a typed NULL, not a bare `SUM(NULL)`.
assert!(query_plan
1492+
.as_logical_plan()
1493+
.find_cube_scan_wrapped_sql()
1494+
.wrapped_sql
1495+
.sql
1496+
.contains("SUM(CAST(NULL AS DOUBLE))"));
1497+
}

0 commit comments

Comments
 (0)