@@ -23,7 +23,7 @@ use core::num::FpCategory;
2323
2424use arrow:: {
2525 array:: { Array , ArrayRef , LargeStringArray , StringArray , StringViewArray } ,
26- datatypes:: DataType ,
26+ datatypes:: { DataType , Field , FieldRef } ,
2727} ;
2828use bigdecimal:: {
2929 BigDecimal , ToPrimitive ,
@@ -34,8 +34,8 @@ use datafusion_common::{
3434 DataFusionError , Result , ScalarValue , exec_datafusion_err, exec_err, plan_err,
3535} ;
3636use datafusion_expr:: {
37- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature ,
38- Volatility ,
37+ ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
38+ TypeSignature , Volatility ,
3939} ;
4040
4141/// Spark-compatible `format_string` expression
@@ -78,14 +78,23 @@ impl ScalarUDFImpl for FormatStringFunc {
7878 & self . signature
7979 }
8080
81- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
82- match arg_types[ 0 ] {
83- DataType :: Null => Ok ( DataType :: Utf8 ) ,
81+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
82+ datafusion_common:: internal_err!(
83+ "return_type should not be called, use return_field_from_args instead"
84+ )
85+ }
86+
87+ fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
88+ match args. arg_fields [ 0 ] . data_type ( ) {
89+ DataType :: Null => {
90+ Ok ( Arc :: new ( Field :: new ( "format_string" , DataType :: Utf8 , true ) ) )
91+ }
8492 DataType :: Utf8 | DataType :: LargeUtf8 | DataType :: Utf8View => {
85- Ok ( arg_types [ 0 ] . clone ( ) )
93+ Ok ( Arc :: clone ( & args . arg_fields [ 0 ] ) )
8694 }
87- _ => plan_err ! (
88- "The format_string function expects the first argument to be Utf8, LargeUtf8 or Utf8View"
95+ _ => exec_err ! (
96+ "format_string expects the first argument to be Utf8, LargeUtf8 or Utf8View, got {} instead" ,
97+ args. arg_fields[ 0 ] . data_type( )
8998 ) ,
9099 }
91100 }
@@ -2347,3 +2356,39 @@ fn trim_trailing_0s_hex(number: &str) -> &str {
23472356 }
23482357 number
23492358}
2359+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType::Utf8;
    use datafusion_common::Result;

    /// Asks `FormatStringFunc` for its return field, given a single `Utf8`
    /// format argument with the requested nullability, and reports whether
    /// the returned field is nullable.
    fn output_is_nullable(fmt_nullable: bool) -> Result<bool> {
        let fmt_field: FieldRef = Arc::new(Field::new("fmt", Utf8, fmt_nullable));
        let returned = FormatStringFunc::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: &[fmt_field],
            scalar_arguments: &[None],
        })?;
        Ok(returned.is_nullable())
    }

    /// The output field's nullability must follow the format argument's.
    #[test]
    fn test_format_string_nullability() -> Result<()> {
        assert!(
            output_is_nullable(true)?,
            "format_string(fmt, ...) should be nullable when fmt is nullable"
        );
        assert!(
            !output_is_nullable(false)?,
            "format_string(fmt, ...) should NOT be nullable when fmt is NOT nullable"
        );

        Ok(())
    }
}
0 commit comments