Skip to content

Commit 7b0ed2d

Browse files
authored
Refactor Spark date_add/date_sub/bitwise_not to remove unnecessary scalar arg check (#19473)
Same as #19466 but for `date_add`, `date_sub` and `bitwise_not`. > If we have a scalar argument that is null, that means the datatype it comes from is already nullable, so there's no need to check both; we only need to check the nullability of the datatype.
1 parent 1441269 commit 7b0ed2d

File tree

3 files changed

+7
-100
lines changed

3 files changed

+7
-100
lines changed

datafusion/spark/src/function/bitwise/bitwise_not.rs

Lines changed: 5 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -73,25 +73,11 @@ impl ScalarUDFImpl for SparkBitwiseNot {
7373
}
7474

7575
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
76-
if args.arg_fields.len() != 1 {
77-
return plan_err!("bitwise_not expects exactly 1 argument");
78-
}
79-
80-
let input_field = &args.arg_fields[0];
81-
82-
let out_dt = input_field.data_type().clone();
83-
let mut out_nullable = input_field.is_nullable();
84-
85-
let scalar_null_present = args
86-
.scalar_arguments
87-
.iter()
88-
.any(|opt_s| opt_s.is_some_and(|sv| sv.is_null()));
89-
90-
if scalar_null_present {
91-
out_nullable = true;
92-
}
93-
94-
Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable)))
76+
Ok(Arc::new(Field::new(
77+
self.name(),
78+
args.arg_fields[0].data_type().clone(),
79+
args.arg_fields[0].is_nullable(),
80+
)))
9581
}
9682

9783
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -196,32 +182,4 @@ mod tests {
196182
assert!(out_i64_null.is_nullable());
197183
assert_eq!(out_i64_null.data_type(), &DataType::Int64);
198184
}
199-
200-
#[test]
201-
fn test_bitwise_not_nullability_with_null_scalar() -> Result<()> {
202-
use arrow::datatypes::{DataType, Field};
203-
use datafusion_common::ScalarValue;
204-
use std::sync::Arc;
205-
206-
let func = SparkBitwiseNot::new();
207-
208-
let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Int32, false));
209-
210-
let out = func.return_field_from_args(ReturnFieldArgs {
211-
arg_fields: &[Arc::clone(&non_nullable)],
212-
scalar_arguments: &[None],
213-
})?;
214-
assert!(!out.is_nullable());
215-
assert_eq!(out.data_type(), &DataType::Int32);
216-
217-
let null_scalar = ScalarValue::Int32(None);
218-
let out_with_null_scalar = func.return_field_from_args(ReturnFieldArgs {
219-
arg_fields: &[Arc::clone(&non_nullable)],
220-
scalar_arguments: &[Some(&null_scalar)],
221-
})?;
222-
assert!(out_with_null_scalar.is_nullable());
223-
assert_eq!(out_with_null_scalar.data_type(), &DataType::Int32);
224-
225-
Ok(())
226-
}
227185
}

datafusion/spark/src/function/datetime/date_add.rs

Lines changed: 1 addition & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -82,12 +82,7 @@ impl ScalarUDFImpl for SparkDateAdd {
8282
}
8383

8484
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
85-
let nullable = args.arg_fields.iter().any(|f| f.is_nullable())
86-
|| args
87-
.scalar_arguments
88-
.iter()
89-
.any(|arg| matches!(arg, Some(sv) if sv.is_null()));
90-
85+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
9186
Ok(Arc::new(Field::new(
9287
self.name(),
9388
DataType::Date32,
@@ -142,7 +137,6 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> {
142137
mod tests {
143138
use super::*;
144139
use arrow::datatypes::Field;
145-
use datafusion_common::ScalarValue;
146140

147141
#[test]
148142
fn test_date_add_non_nullable_inputs() {
@@ -181,25 +175,4 @@ mod tests {
181175
assert_eq!(ret_field.data_type(), &DataType::Date32);
182176
assert!(ret_field.is_nullable());
183177
}
184-
185-
#[test]
186-
fn test_date_add_null_scalar() {
187-
let func = SparkDateAdd::new();
188-
let args = &[
189-
Arc::new(Field::new("date", DataType::Date32, false)),
190-
Arc::new(Field::new("num", DataType::Int32, false)),
191-
];
192-
193-
let null_scalar = ScalarValue::Int32(None);
194-
195-
let ret_field = func
196-
.return_field_from_args(ReturnFieldArgs {
197-
arg_fields: args,
198-
scalar_arguments: &[None, Some(&null_scalar)],
199-
})
200-
.unwrap();
201-
202-
assert_eq!(ret_field.data_type(), &DataType::Date32);
203-
assert!(ret_field.is_nullable());
204-
}
205178
}

datafusion/spark/src/function/datetime/date_sub.rs

Lines changed: 1 addition & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -75,12 +75,7 @@ impl ScalarUDFImpl for SparkDateSub {
7575
}
7676

7777
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
78-
let nullable = args.arg_fields.iter().any(|f| f.is_nullable())
79-
|| args
80-
.scalar_arguments
81-
.iter()
82-
.any(|arg| matches!(arg, Some(sv) if sv.is_null()));
83-
78+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
8479
Ok(Arc::new(Field::new(
8580
self.name(),
8681
DataType::Date32,
@@ -139,7 +134,6 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result<ArrayRef> {
139134
#[cfg(test)]
140135
mod tests {
141136
use super::*;
142-
use datafusion_common::ScalarValue;
143137

144138
#[test]
145139
fn test_date_sub_nullability_non_nullable_args() {
@@ -174,22 +168,4 @@ mod tests {
174168
assert!(result.is_nullable());
175169
assert_eq!(result.data_type(), &DataType::Date32);
176170
}
177-
178-
#[test]
179-
fn test_date_sub_nullability_scalar_null_argument() {
180-
let udf = SparkDateSub::new();
181-
let date_field = Arc::new(Field::new("d", DataType::Date32, false));
182-
let days_field = Arc::new(Field::new("n", DataType::Int32, false));
183-
let null_scalar = ScalarValue::Int32(None);
184-
185-
let result = udf
186-
.return_field_from_args(ReturnFieldArgs {
187-
arg_fields: &[date_field, days_field],
188-
scalar_arguments: &[None, Some(&null_scalar)],
189-
})
190-
.unwrap();
191-
192-
assert!(result.is_nullable());
193-
assert_eq!(result.data_type(), &DataType::Date32);
194-
}
195171
}

0 commit comments

Comments
 (0)