Skip to content

Commit ce08307

Browse files
refactor: Use Signature::coercible for isnan/iszero (#19604)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Part of #14763.

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

## What changes are included in this PR?

Replace `TypeSignature::Exact` patterns with cleaner APIs where possible, and add `Float16` support:

- `isnan/iszero`: `Signature::coercible` with `TypeSignatureClass::Float`; other numeric inputs are implicitly coerced (default `Float64`). Both functions also gain a `Float16` branch and return a NULL boolean for NULL-typed input.
- `nanvl`: keeps `Signature::one_of` with `Exact` signatures, adding a `Float16` pair alongside the existing `Float32` and `Float64` pairs.

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

## Are these changes tested?

Yes

<!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

## Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->

---------

Co-authored-by: Jeffrey Vo <[email protected]>
1 parent 1f654bb commit ce08307

File tree

3 files changed

+65
-26
lines changed

3 files changed

+65
-26
lines changed

datafusion/functions/src/math/iszero.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, AsArray, BooleanArray};
22-
use arrow::datatypes::DataType::{Boolean, Float32, Float64};
23-
use arrow::datatypes::{DataType, Float32Type, Float64Type};
21+
use arrow::array::{ArrayRef, ArrowNativeTypeOp, AsArray, BooleanArray};
22+
use arrow::datatypes::DataType::{Boolean, Float16, Float32, Float64};
23+
use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type};
2424

25-
use datafusion_common::{Result, exec_err};
26-
use datafusion_expr::TypeSignature::Exact;
25+
use datafusion_common::types::NativeType;
26+
use datafusion_common::{Result, ScalarValue, exec_err};
27+
use datafusion_expr::{Coercion, TypeSignatureClass};
2728
use datafusion_expr::{
2829
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
2930
Volatility,
@@ -59,12 +60,14 @@ impl Default for IsZeroFunc {
5960

6061
impl IsZeroFunc {
6162
pub fn new() -> Self {
62-
use DataType::*;
63+
// Accept any numeric type and coerce to float
64+
let float = Coercion::new_implicit(
65+
TypeSignatureClass::Float,
66+
vec![TypeSignatureClass::Numeric],
67+
NativeType::Float64,
68+
);
6369
Self {
64-
signature: Signature::one_of(
65-
vec![Exact(vec![Float32]), Exact(vec![Float64])],
66-
Volatility::Immutable,
67-
),
70+
signature: Signature::coercible(vec![float], Volatility::Immutable),
6871
}
6972
}
7073
}
@@ -87,6 +90,10 @@ impl ScalarUDFImpl for IsZeroFunc {
8790
}
8891

8992
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93+
// Handle NULL input
94+
if args.args[0].data_type().is_null() {
95+
return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
96+
}
9097
make_scalar_function(iszero, vec![])(&args.args)
9198
}
9299

@@ -108,6 +115,11 @@ fn iszero(args: &[ArrayRef]) -> Result<ArrayRef> {
108115
|x| x == 0.0,
109116
)) as ArrayRef),
110117

118+
Float16 => Ok(Arc::new(BooleanArray::from_unary(
119+
args[0].as_primitive::<Float16Type>(),
120+
|x| x.is_zero(),
121+
)) as ArrayRef),
122+
111123
other => exec_err!("Unsupported data type {other:?} for function iszero"),
112124
}
113125
}

datafusion/functions/src/math/nans.rs

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
//! Math function: `isnan()`.
1919
20-
use arrow::datatypes::{DataType, Float32Type, Float64Type};
21-
use datafusion_common::{Result, exec_err};
22-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature};
20+
use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type};
21+
use datafusion_common::types::NativeType;
22+
use datafusion_common::{Result, ScalarValue, exec_err};
23+
use datafusion_expr::{Coercion, ColumnarValue, ScalarFunctionArgs, TypeSignatureClass};
2324

2425
use arrow::array::{ArrayRef, AsArray, BooleanArray};
2526
use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility};
@@ -54,15 +55,14 @@ impl Default for IsNanFunc {
5455

5556
impl IsNanFunc {
5657
pub fn new() -> Self {
57-
use DataType::*;
58+
// Accept any numeric type and coerce to float
59+
let float = Coercion::new_implicit(
60+
TypeSignatureClass::Float,
61+
vec![TypeSignatureClass::Numeric],
62+
NativeType::Float64,
63+
);
5864
Self {
59-
signature: Signature::one_of(
60-
vec![
61-
TypeSignature::Exact(vec![Float32]),
62-
TypeSignature::Exact(vec![Float64]),
63-
],
64-
Volatility::Immutable,
65-
),
65+
signature: Signature::coercible(vec![float], Volatility::Immutable),
6666
}
6767
}
6868
}
@@ -84,6 +84,11 @@ impl ScalarUDFImpl for IsNanFunc {
8484
}
8585

8686
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
87+
// Handle NULL input
88+
if args.args[0].data_type().is_null() {
89+
return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)));
90+
}
91+
8792
let args = ColumnarValue::values_to_arrays(&args.args)?;
8893

8994
let arr: ArrayRef = match args[0].data_type() {
@@ -96,6 +101,11 @@ impl ScalarUDFImpl for IsNanFunc {
96101
args[0].as_primitive::<Float32Type>(),
97102
f32::is_nan,
98103
)) as ArrayRef,
104+
105+
DataType::Float16 => Arc::new(BooleanArray::from_unary(
106+
args[0].as_primitive::<Float16Type>(),
107+
|x| x.is_nan(),
108+
)) as ArrayRef,
99109
other => {
100110
return exec_err!(
101111
"Unsupported data type {other:?} for function {}",

datafusion/functions/src/math/nanvl.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ use std::sync::Arc;
2020

2121
use crate::utils::make_scalar_function;
2222

23-
use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
24-
use arrow::datatypes::DataType::{Float32, Float64};
25-
use arrow::datatypes::{DataType, Float32Type, Float64Type};
23+
use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array};
24+
use arrow::datatypes::DataType::{Float16, Float32, Float64};
25+
use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type};
2626
use datafusion_common::{DataFusionError, Result, exec_err};
2727
use datafusion_expr::TypeSignature::Exact;
2828
use datafusion_expr::{
@@ -66,10 +66,13 @@ impl Default for NanvlFunc {
6666

6767
impl NanvlFunc {
6868
pub fn new() -> Self {
69-
use DataType::*;
7069
Self {
7170
signature: Signature::one_of(
72-
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
71+
vec![
72+
Exact(vec![Float16, Float16]),
73+
Exact(vec![Float32, Float32]),
74+
Exact(vec![Float64, Float64]),
75+
],
7376
Volatility::Immutable,
7477
),
7578
}
@@ -91,6 +94,7 @@ impl ScalarUDFImpl for NanvlFunc {
9194

9295
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
9396
match &arg_types[0] {
97+
Float16 => Ok(Float16),
9498
Float32 => Ok(Float32),
9599
_ => Ok(Float64),
96100
}
@@ -130,6 +134,19 @@ fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
130134
.map(|res| Arc::new(res) as _)
131135
.map_err(DataFusionError::from)
132136
}
137+
Float16 => {
138+
let compute_nanvl =
139+
|x: <Float16Type as arrow::datatypes::ArrowPrimitiveType>::Native,
140+
y: <Float16Type as arrow::datatypes::ArrowPrimitiveType>::Native| {
141+
if x.is_nan() { y } else { x }
142+
};
143+
144+
let x = args[0].as_primitive() as &Float16Array;
145+
let y = args[1].as_primitive() as &Float16Array;
146+
arrow::compute::binary::<_, _, _, Float16Type>(x, y, compute_nanvl)
147+
.map(|res| Arc::new(res) as _)
148+
.map_err(DataFusionError::from)
149+
}
133150
other => exec_err!("Unsupported data type {other:?} for function nanvl"),
134151
}
135152
}

0 commit comments

Comments
 (0)