|
17 | 17 |
|
18 | 18 | use arrow::array::ArrayRef; |
19 | 19 | use arrow::array::GenericStringBuilder; |
20 | | -use arrow::datatypes::DataType; |
21 | 20 | use arrow::datatypes::DataType::Int64; |
22 | 21 | use arrow::datatypes::DataType::Utf8; |
| 22 | +use arrow::datatypes::{DataType, Field, FieldRef}; |
23 | 23 | use std::{any::Any, sync::Arc}; |
24 | 24 |
|
25 | 25 | use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; |
26 | 26 | use datafusion_expr::{ |
27 | | - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, |
| 27 | + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, |
| 28 | + Volatility, |
28 | 29 | }; |
29 | 30 |
|
30 | 31 | /// Spark-compatible `char` expression |
@@ -62,12 +63,19 @@ impl ScalarUDFImpl for CharFunc { |
62 | 63 | } |
63 | 64 |
|
64 | 65 | fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { |
65 | | - Ok(Utf8) |
| 66 | + datafusion_common::internal_err!( |
| 67 | + "return_type should not be called, use return_field_from_args instead" |
| 68 | + ) |
66 | 69 | } |
67 | 70 |
|
68 | 71 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
69 | 72 | spark_chr(&args.args) |
70 | 73 | } |
| 74 | + |
| 75 | + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { |
| 76 | + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); |
| 77 | + Ok(Arc::new(Field::new(self.name(), Utf8, nullable))) |
| 78 | + } |
71 | 79 | } |
72 | 80 |
|
73 | 81 | /// Returns the ASCII character having the binary equivalent to the input expression. |
@@ -130,3 +138,48 @@ fn chr(args: &[ArrayRef]) -> Result<ArrayRef> { |
130 | 138 |
|
131 | 139 | Ok(Arc::new(builder.finish()) as ArrayRef) |
132 | 140 | } |
| 141 | + |
| 142 | +#[test] |
| 143 | +fn test_char_nullability() -> Result<()> { |
| 144 | + use arrow::datatypes::{DataType::Utf8, Field, FieldRef}; |
| 145 | + use datafusion_expr::ReturnFieldArgs; |
| 146 | + use std::sync::Arc; |
| 147 | + |
| 148 | + let func = CharFunc::new(); |
| 149 | + |
| 150 | + let nullable_field: FieldRef = Arc::new(Field::new("col", Int64, true)); |
| 151 | + |
| 152 | + let out_nullable = func.return_field_from_args(ReturnFieldArgs { |
| 153 | + arg_fields: &[nullable_field], |
| 154 | + scalar_arguments: &[None], |
| 155 | + })?; |
| 156 | + |
| 157 | + assert!( |
| 158 | + out_nullable.is_nullable(), |
| 159 | + "char(col) should be nullable when input column is nullable" |
| 160 | + ); |
| 161 | + assert_eq!( |
| 162 | + out_nullable.data_type(), |
| 163 | + &Utf8, |
| 164 | + "char always returns Utf8 regardless of input type" |
| 165 | + ); |
| 166 | + |
| 167 | + let non_nullable_field: FieldRef = Arc::new(Field::new("col", Int64, false)); |
| 168 | + |
| 169 | + let out_non_nullable = func.return_field_from_args(ReturnFieldArgs { |
| 170 | + arg_fields: &[non_nullable_field], |
| 171 | + scalar_arguments: &[None], |
| 172 | + })?; |
| 173 | + |
| 174 | + assert!( |
| 175 | + !out_non_nullable.is_nullable(), |
| 176 | + "char(col) should NOT be nullable when input column is NOT nullable" |
| 177 | + ); |
| 178 | + assert_eq!( |
| 179 | + out_non_nullable.data_type(), |
| 180 | + &Utf8, |
| 181 | + "char always returns Utf8 regardless of input type" |
| 182 | + ); |
| 183 | + |
| 184 | + Ok(()) |
| 185 | +} |
0 commit comments