Skip to content

Commit e586ff5

Browse files
batmnnn and martin-g authored
fix: implement custom nullability for spark abs function (#19395)
Fixes apache/datafusion#19162

The SparkAbs UDF was using the default is_nullable=true for all outputs, even when inputs were non-nullable. This commit implements return_field_from_args to properly propagate nullability from input arguments.

Changes:
- Add return_field_from_args implementation to SparkAbs
- Output nullability now matches input nullability
- Handle edge case where scalar argument is explicitly null
- Add tests for nullability behavior

## Which issue does this PR close?

Closes #19162

## Rationale for this change

`SparkAbs` was always returning `nullable=true` even for non-nullable inputs.

## What changes are included in this PR?

Implement `return_field_from_args` to propagate nullability from input arguments.

## Are these changes tested?

Yes, added 2 tests for nullability behavior.

## Are there any user-facing changes?

No.

---------

Co-authored-by: Martin Grigorov <[email protected]>
1 parent ea2e22c commit e586ff5

File tree

1 file changed

+74
-4
lines changed
  • datafusion/spark/src/function/math

1 file changed

+74
-4
lines changed

datafusion/spark/src/function/math/abs.rs

Lines changed: 74 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,11 @@
1616
// under the License.
1717

1818
use arrow::array::*;
19-
use arrow::datatypes::DataType;
19+
use arrow::datatypes::{DataType, Field, FieldRef};
2020
use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err};
2121
use datafusion_expr::{
22-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
22+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
23+
Volatility,
2324
};
2425
use datafusion_functions::{
2526
downcast_named_arg, make_abs_function, make_wrapping_abs_function,
@@ -69,8 +70,18 @@ impl ScalarUDFImpl for SparkAbs {
6970
&self.signature
7071
}
7172

72-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
73-
Ok(arg_types[0].clone())
73+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74+
internal_err!(
75+
"SparkAbs: return_type() is not used; return_field_from_args() is implemented"
76+
)
77+
}
78+
79+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
80+
let input_field = &args.arg_fields[0];
81+
let out_dt = input_field.data_type().clone();
82+
let out_nullable = input_field.is_nullable();
83+
84+
Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
7485
}
7586

7687
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -375,4 +386,63 @@ mod tests {
375386
as_decimal256_array
376387
);
377388
}
389+
390+
#[test]
391+
fn test_abs_nullability() {
392+
use arrow::datatypes::{DataType, Field};
393+
use datafusion_expr::ReturnFieldArgs;
394+
use std::sync::Arc;
395+
396+
let abs = SparkAbs::new();
397+
398+
// --- non-nullable Int32 input ---
399+
let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false));
400+
let out_non_null = abs
401+
.return_field_from_args(ReturnFieldArgs {
402+
arg_fields: &[Arc::clone(&non_nullable_i32)],
403+
scalar_arguments: &[None],
404+
})
405+
.unwrap();
406+
407+
// result should be non-nullable and the same DataType as input
408+
assert!(!out_non_null.is_nullable());
409+
assert_eq!(out_non_null.data_type(), &DataType::Int32);
410+
411+
// --- nullable Int32 input ---
412+
let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true));
413+
let out_nullable = abs
414+
.return_field_from_args(ReturnFieldArgs {
415+
arg_fields: &[Arc::clone(&nullable_i32)],
416+
scalar_arguments: &[None],
417+
})
418+
.unwrap();
419+
420+
// result should be nullable and the same DataType as input
421+
assert!(out_nullable.is_nullable());
422+
assert_eq!(out_nullable.data_type(), &DataType::Int32);
423+
424+
// --- non-nullable Float64 input ---
425+
let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false));
426+
let out_f64 = abs
427+
.return_field_from_args(ReturnFieldArgs {
428+
arg_fields: &[Arc::clone(&non_nullable_f64)],
429+
scalar_arguments: &[None],
430+
})
431+
.unwrap();
432+
433+
assert!(!out_f64.is_nullable());
434+
assert_eq!(out_f64.data_type(), &DataType::Float64);
435+
436+
// --- nullable Float64 input ---
437+
let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true));
438+
let out_f64_null = abs
439+
.return_field_from_args(ReturnFieldArgs {
440+
arg_fields: &[Arc::clone(&nullable_f64)],
441+
scalar_arguments: &[None],
442+
})
443+
.unwrap();
444+
445+
assert!(out_f64_null.is_nullable());
446+
assert_eq!(out_f64_null.data_type(), &DataType::Float64);
447+
}
378448
}

0 commit comments

Comments
 (0)