Skip to content

Commit 058bcb0

Browse files
skushagraalamb
andauthored
fix: custom nullability for format_string (#19173) (#19190)
## Which issue does this PR close? - Closes #19173 ## What changes are included in this PR? - includes custom nullability for `format_string`. --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent e586ff5 commit 058bcb0

File tree

1 file changed

+54
-9
lines changed

1 file changed

+54
-9
lines changed

datafusion/spark/src/function/string/format_string.rs

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use core::num::FpCategory;
2323

2424
use arrow::{
2525
array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray},
26-
datatypes::DataType,
26+
datatypes::{DataType, Field, FieldRef},
2727
};
2828
use bigdecimal::{
2929
BigDecimal, ToPrimitive,
@@ -34,8 +34,8 @@ use datafusion_common::{
3434
DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, plan_err,
3535
};
3636
use datafusion_expr::{
37-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
38-
Volatility,
37+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
38+
TypeSignature, Volatility,
3939
};
4040

4141
/// Spark-compatible `format_string` expression
@@ -78,14 +78,23 @@ impl ScalarUDFImpl for FormatStringFunc {
7878
&self.signature
7979
}
8080

81-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
82-
match arg_types[0] {
83-
DataType::Null => Ok(DataType::Utf8),
81+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
82+
datafusion_common::internal_err!(
83+
"return_type should not be called, use return_field_from_args instead"
84+
)
85+
}
86+
87+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
88+
match args.arg_fields[0].data_type() {
89+
DataType::Null => {
90+
Ok(Arc::new(Field::new("format_string", DataType::Utf8, true)))
91+
}
8492
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
85-
Ok(arg_types[0].clone())
93+
Ok(Arc::clone(&args.arg_fields[0]))
8694
}
87-
_ => plan_err!(
88-
"The format_string function expects the first argument to be Utf8, LargeUtf8 or Utf8View"
95+
_ => exec_err!(
96+
"format_string expects the first argument to be Utf8, LargeUtf8 or Utf8View, got {} instead",
97+
args.arg_fields[0].data_type()
8998
),
9099
}
91100
}
@@ -2347,3 +2356,39 @@ fn trim_trailing_0s_hex(number: &str) -> &str {
23472356
}
23482357
number
23492358
}
2359+
2360+
#[cfg(test)]
2361+
mod tests {
2362+
use super::*;
2363+
use arrow::datatypes::DataType::Utf8;
2364+
use datafusion_common::Result;
2365+
2366+
#[test]
2367+
fn test_format_string_nullability() -> Result<()> {
2368+
let func = FormatStringFunc::new();
2369+
let nullable_format: FieldRef = Arc::new(Field::new("fmt", Utf8, true));
2370+
2371+
let out_nullable = func.return_field_from_args(ReturnFieldArgs {
2372+
arg_fields: &[nullable_format],
2373+
scalar_arguments: &[None],
2374+
})?;
2375+
2376+
assert!(
2377+
out_nullable.is_nullable(),
2378+
"format_string(fmt, ...) should be nullable when fmt is nullable"
2379+
);
2380+
let non_nullable_format: FieldRef = Arc::new(Field::new("fmt", Utf8, false));
2381+
2382+
let out_non_nullable = func.return_field_from_args(ReturnFieldArgs {
2383+
arg_fields: &[non_nullable_format],
2384+
scalar_arguments: &[None],
2385+
})?;
2386+
2387+
assert!(
2388+
!out_non_nullable.is_nullable(),
2389+
"format_string(fmt, ...) should NOT be nullable when fmt is NOT nullable"
2390+
);
2391+
2392+
Ok(())
2393+
}
2394+
}

0 commit comments

Comments
 (0)