Skip to content

Commit 4ddee14

Browse files
refactor: Refactor spark width bucket signature away from user defined (#19065)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Part of #12725 ## Rationale for this change - In line with the goal stated in the linked issue, we should avoid `Signature::user_defined` wherever a declarative signature can express the same coercion rules. <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? - Replace the `user_defined` signature (and its hand-written `coerce_types`) in Spark `width_bucket` with `TypeSignature::Coercible` signatures. - Add the `logical_interval_year_month` and `logical_duration_microsecond` logical type singletons used by the new signature. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Jeffrey Vo <[email protected]>
1 parent ad9b779 commit 4ddee14

File tree

2 files changed

+88
-76
lines changed

2 files changed

+88
-76
lines changed

datafusion/common/src/types/builtin.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::IntervalUnit::*;
19+
use arrow::datatypes::TimeUnit::*;
1920

2021
use crate::types::{LogicalTypeRef, NativeType};
2122
use std::sync::{Arc, LazyLock};
@@ -82,3 +83,17 @@ singleton_variant!(
8283
Interval,
8384
MonthDayNano
8485
);
86+
87+
singleton_variant!(
88+
LOGICAL_INTERVAL_YEAR_MONTH,
89+
logical_interval_year_month,
90+
Interval,
91+
YearMonth
92+
);
93+
94+
singleton_variant!(
95+
LOGICAL_DURATION_MICROSECOND,
96+
logical_duration_microsecond,
97+
Duration,
98+
Microsecond
99+
);

datafusion/spark/src/function/math/width_bucket.rs

Lines changed: 73 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use crate::function::error_utils::unsupported_data_types_exec_err;
2221
use arrow::array::{
2322
Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
2423
IntervalYearMonthArray,
@@ -30,14 +29,21 @@ use datafusion_common::cast::{
3029
as_duration_microsecond_array, as_float64_array, as_int32_array,
3130
as_interval_mdn_array, as_interval_ym_array,
3231
};
33-
use datafusion_common::{exec_err, Result};
32+
use datafusion_common::types::{
33+
logical_duration_microsecond, logical_float64, logical_int32, logical_interval_mdn,
34+
logical_interval_year_month, NativeType,
35+
};
36+
use datafusion_common::{exec_err, internal_err, Result};
3437
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
35-
use datafusion_expr::type_coercion::is_signed_numeric;
36-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
38+
use datafusion_expr::{
39+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
40+
TypeSignatureClass,
41+
};
3742
use datafusion_functions::utils::make_scalar_function;
3843

3944
use arrow::array::{Int32Array, Int32Builder};
4045
use arrow::datatypes::TimeUnit::Microsecond;
46+
use datafusion_expr::Coercion;
4147
use datafusion_expr::Volatility::Immutable;
4248

4349
#[derive(Debug, PartialEq, Eq, Hash)]
@@ -53,8 +59,59 @@ impl Default for SparkWidthBucket {
5359

5460
impl SparkWidthBucket {
5561
pub fn new() -> Self {
62+
let numeric = Coercion::new_implicit(
63+
TypeSignatureClass::Native(logical_float64()),
64+
vec![TypeSignatureClass::Numeric],
65+
NativeType::Float64,
66+
);
67+
let duration = Coercion::new_implicit(
68+
TypeSignatureClass::Native(logical_duration_microsecond()),
69+
vec![TypeSignatureClass::Duration],
70+
NativeType::Duration(Microsecond),
71+
);
72+
let interval_ym = Coercion::new_exact(TypeSignatureClass::Native(
73+
logical_interval_year_month(),
74+
));
75+
let interval_mdn =
76+
Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn()));
77+
let bucket = Coercion::new_implicit(
78+
TypeSignatureClass::Native(logical_int32()),
79+
vec![TypeSignatureClass::Integer],
80+
NativeType::Int32,
81+
);
82+
let type_signature = Signature::one_of(
83+
vec![
84+
TypeSignature::Coercible(vec![
85+
numeric.clone(),
86+
numeric.clone(),
87+
numeric.clone(),
88+
bucket.clone(),
89+
]),
90+
TypeSignature::Coercible(vec![
91+
duration.clone(),
92+
duration.clone(),
93+
duration.clone(),
94+
bucket.clone(),
95+
]),
96+
TypeSignature::Coercible(vec![
97+
interval_ym.clone(),
98+
interval_ym.clone(),
99+
interval_ym.clone(),
100+
bucket.clone(),
101+
]),
102+
TypeSignature::Coercible(vec![
103+
interval_mdn.clone(),
104+
interval_mdn.clone(),
105+
interval_mdn.clone(),
106+
bucket.clone(),
107+
]),
108+
],
109+
Immutable,
110+
)
111+
.with_parameter_names(vec!["expr", "min", "max", "num_buckets"])
112+
.expect("valid parameter names");
56113
Self {
57-
signature: Signature::user_defined(Immutable),
114+
signature: type_signature,
58115
}
59116
}
60117
}
@@ -88,63 +145,6 @@ impl ScalarUDFImpl for SparkWidthBucket {
88145
Ok(SortProperties::default())
89146
}
90147
}
91-
92-
fn coerce_types(&self, types: &[DataType]) -> Result<Vec<DataType>> {
93-
use DataType::*;
94-
95-
let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]);
96-
97-
match (v, lo, hi, n) {
98-
(a, b, c, &(Int8 | Int16 | Int32 | Int64))
99-
if is_signed_numeric(a)
100-
&& is_signed_numeric(b)
101-
&& is_signed_numeric(c) =>
102-
{
103-
Ok(vec![Float64, Float64, Float64, Int32])
104-
}
105-
(
106-
&Duration(_),
107-
&Duration(_),
108-
&Duration(_),
109-
&(Int8 | Int16 | Int32 | Int64),
110-
) => Ok(vec![
111-
Duration(Microsecond),
112-
Duration(Microsecond),
113-
Duration(Microsecond),
114-
Int32,
115-
]),
116-
(
117-
&Interval(MonthDayNano),
118-
&Interval(MonthDayNano),
119-
&Interval(MonthDayNano),
120-
&(Int8 | Int16 | Int32 | Int64),
121-
) => Ok(vec![
122-
Interval(MonthDayNano),
123-
Interval(MonthDayNano),
124-
Interval(MonthDayNano),
125-
Int32,
126-
]),
127-
(
128-
&Interval(YearMonth),
129-
&Interval(YearMonth),
130-
&Interval(YearMonth),
131-
&(Int8 | Int16 | Int32 | Int64),
132-
) => Ok(vec![
133-
Interval(YearMonth),
134-
Interval(YearMonth),
135-
Interval(YearMonth),
136-
Int32,
137-
]),
138-
139-
_ => exec_err!(
140-
"width_bucket expects a numeric argument, got {} {} {} {}",
141-
types[0],
142-
types[1],
143-
types[2],
144-
types[3]
145-
),
146-
}
147-
}
148148
}
149149

150150
fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -182,20 +182,18 @@ fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
182182
let min = as_interval_mdn_array(minv)?;
183183
let max = as_interval_mdn_array(maxv)?;
184184
let n_bucket = as_int32_array(nb)?;
185-
Ok(Arc::new(width_bucket_interval_mdn_exact(v, min, max, n_bucket)))
185+
Ok(Arc::new(width_bucket_interval_mdn_exact(
186+
v, min, max, n_bucket,
187+
)))
186188
}
187189

188-
189-
other => Err(unsupported_data_types_exec_err(
190-
"width_bucket",
191-
"Float/Decimal OR Duration OR Interval(YearMonth) for first 3 args; Int for 4th",
192-
&[
193-
other.clone(),
194-
minv.data_type().clone(),
195-
maxv.data_type().clone(),
196-
nb.data_type().clone(),
197-
],
198-
)),
190+
other => internal_err!(
191+
"width_bucket received unexpected data types: {:?}, {:?}, {:?}, {:?}",
192+
other,
193+
minv.data_type(),
194+
maxv.data_type(),
195+
nb.data_type()
196+
),
199197
}
200198
}
201199

@@ -780,8 +778,7 @@ mod tests {
780778
let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err();
781779
let msg = format!("{err}");
782780
assert!(
783-
msg.contains("unsupported data types")
784-
|| msg.contains("Float/Decimal OR Duration OR Interval(YearMonth)"),
781+
msg.contains("width_bucket received unexpected data types"),
785782
"unexpected error: {msg}"
786783
);
787784
}

0 commit comments

Comments
 (0)