Skip to content

Commit 211877f

Browse files
Jefffrey and alamb authored
Refactor approx_median signature & support f16 (#18647)
## Which issue does this PR close?

<!--
We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123.
-->

Part of #18092

## Rationale for this change

<!--
Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes.
-->

Making more use of the coercible signature API for consistency.

## What changes are included in this PR?

<!--
There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR.
-->

Refactor the `approx_median` signature to use the coercible API. Also implement support for Float16 inputs, as that is now permitted in the new coercible API. Various other refactors too.

## Are these changes tested?

<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)?
-->

Added SLT tests.

## Are there any user-facing changes?

<!--
If there are user-facing changes then we may require documentation to be updated before approving the PR.
-->

No.

<!--
If there are any breaking changes to public APIs, please add the `api change` label.
-->

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 54a7868 commit 211877f

File tree

4 files changed

+82
-135
lines changed

4 files changed

+82
-135
lines changed

datafusion/functions-aggregate-common/src/tdigest.rs

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
use arrow::datatypes::DataType;
3333
use arrow::datatypes::Float64Type;
3434
use datafusion_common::cast::as_primitive_array;
35-
use datafusion_common::Result;
3635
use datafusion_common::ScalarValue;
3736
use std::cmp::Ordering;
3837
use std::mem::{size_of, size_of_val};
@@ -61,41 +60,6 @@ macro_rules! cast_scalar_u64 {
6160
};
6261
}
6362

64-
/// This trait is implemented for each type a [`TDigest`] can operate on,
65-
/// allowing it to support both numerical rust types (obtained from
66-
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
67-
pub trait TryIntoF64 {
68-
/// A fallible conversion of a possibly null `self` into a [`f64`].
69-
///
70-
/// If `self` is null, this method must return `Ok(None)`.
71-
///
72-
/// If `self` cannot be coerced to the desired type, this method must return
73-
/// an `Err` variant.
74-
fn try_as_f64(&self) -> Result<Option<f64>>;
75-
}
76-
77-
/// Generate an infallible conversion from `type` to an [`f64`].
78-
macro_rules! impl_try_ordered_f64 {
79-
($type:ty) => {
80-
impl TryIntoF64 for $type {
81-
fn try_as_f64(&self) -> Result<Option<f64>> {
82-
Ok(Some(*self as f64))
83-
}
84-
}
85-
};
86-
}
87-
88-
impl_try_ordered_f64!(f64);
89-
impl_try_ordered_f64!(f32);
90-
impl_try_ordered_f64!(i64);
91-
impl_try_ordered_f64!(i32);
92-
impl_try_ordered_f64!(i16);
93-
impl_try_ordered_f64!(i8);
94-
impl_try_ordered_f64!(u64);
95-
impl_try_ordered_f64!(u32);
96-
impl_try_ordered_f64!(u16);
97-
impl_try_ordered_f64!(u8);
98-
9963
/// Centroid implementation to the cluster mentioned in the paper.
10064
#[derive(Debug, PartialEq, Clone)]
10165
pub struct Centroid {

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,18 @@
1919
2020
use arrow::datatypes::DataType::{Float64, UInt64};
2121
use arrow::datatypes::{DataType, Field, FieldRef};
22+
use datafusion_common::types::NativeType;
23+
use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator;
2224
use std::any::Any;
2325
use std::fmt::Debug;
2426
use std::sync::Arc;
2527

26-
use datafusion_common::{not_impl_err, plan_err, Result};
28+
use datafusion_common::{not_impl_err, Result};
2729
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
28-
use datafusion_expr::type_coercion::aggregates::NUMERICS;
2930
use datafusion_expr::utils::format_state_name;
3031
use datafusion_expr::{
31-
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
32+
Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature,
33+
TypeSignatureClass, Volatility,
3234
};
3335
use datafusion_macros::user_doc;
3436

@@ -57,20 +59,11 @@ make_udaf_expr_and_func!(
5759
```"#,
5860
standard_argument(name = "expression",)
5961
)]
60-
#[derive(PartialEq, Eq, Hash)]
62+
#[derive(Debug, PartialEq, Eq, Hash)]
6163
pub struct ApproxMedian {
6264
signature: Signature,
6365
}
6466

65-
impl Debug for ApproxMedian {
66-
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
67-
f.debug_struct("ApproxMedian")
68-
.field("name", &self.name())
69-
.field("signature", &self.signature)
70-
.finish()
71-
}
72-
}
73-
7467
impl Default for ApproxMedian {
7568
fn default() -> Self {
7669
Self::new()
@@ -81,33 +74,53 @@ impl ApproxMedian {
8174
/// Create a new APPROX_MEDIAN aggregate function
8275
pub fn new() -> Self {
8376
Self {
84-
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
77+
signature: Signature::one_of(
78+
vec![
79+
TypeSignature::Coercible(vec![Coercion::new_exact(
80+
TypeSignatureClass::Integer,
81+
)]),
82+
TypeSignature::Coercible(vec![Coercion::new_implicit(
83+
TypeSignatureClass::Float,
84+
vec![TypeSignatureClass::Decimal],
85+
NativeType::Float64,
86+
)]),
87+
],
88+
Volatility::Immutable,
89+
),
8590
}
8691
}
8792
}
8893

8994
impl AggregateUDFImpl for ApproxMedian {
90-
/// Return a reference to Any that can be used for downcasting
9195
fn as_any(&self) -> &dyn Any {
9296
self
9397
}
9498

9599
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
96-
Ok(vec![
97-
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
98-
Field::new(format_state_name(args.name, "sum"), Float64, false),
99-
Field::new(format_state_name(args.name, "count"), UInt64, false),
100-
Field::new(format_state_name(args.name, "max"), Float64, false),
101-
Field::new(format_state_name(args.name, "min"), Float64, false),
102-
Field::new_list(
103-
format_state_name(args.name, "centroids"),
104-
Field::new_list_field(Float64, true),
105-
false,
106-
),
107-
]
108-
.into_iter()
109-
.map(Arc::new)
110-
.collect())
100+
if args.input_fields[0].data_type().is_null() {
101+
Ok(vec![Field::new(
102+
format_state_name(args.name, self.name()),
103+
DataType::Null,
104+
true,
105+
)
106+
.into()])
107+
} else {
108+
Ok(vec![
109+
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
110+
Field::new(format_state_name(args.name, "sum"), Float64, false),
111+
Field::new(format_state_name(args.name, "count"), UInt64, false),
112+
Field::new(format_state_name(args.name, "max"), Float64, false),
113+
Field::new(format_state_name(args.name, "min"), Float64, false),
114+
Field::new_list(
115+
format_state_name(args.name, "centroids"),
116+
Field::new_list_field(Float64, true),
117+
false,
118+
),
119+
]
120+
.into_iter()
121+
.map(Arc::new)
122+
.collect())
123+
}
111124
}
112125

113126
fn name(&self) -> &str {
@@ -119,9 +132,6 @@ impl AggregateUDFImpl for ApproxMedian {
119132
}
120133

121134
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
122-
if !arg_types[0].is_numeric() {
123-
return plan_err!("ApproxMedian requires numeric input types");
124-
}
125135
Ok(arg_types[0].clone())
126136
}
127137

@@ -132,10 +142,14 @@ impl AggregateUDFImpl for ApproxMedian {
132142
);
133143
}
134144

135-
Ok(Box::new(ApproxPercentileAccumulator::new(
136-
0.5_f64,
137-
acc_args.expr_fields[0].data_type().clone(),
138-
)))
145+
if acc_args.expr_fields[0].data_type().is_null() {
146+
Ok(Box::new(NoopAccumulator::default()))
147+
} else {
148+
Ok(Box::new(ApproxPercentileAccumulator::new(
149+
0.5_f64,
150+
acc_args.expr_fields[0].data_type().clone(),
151+
)))
152+
}
139153
}
140154

141155
fn documentation(&self) -> Option<&Documentation> {

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 21 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
// under the License.
1717

1818
use std::any::Any;
19-
use std::fmt::{Debug, Formatter};
19+
use std::fmt::Debug;
2020
use std::mem::size_of_val;
2121
use std::sync::Arc;
2222

23-
use arrow::array::Array;
23+
use arrow::array::{Array, Float16Array};
2424
use arrow::compute::{filter, is_not_null};
2525
use arrow::datatypes::FieldRef;
2626
use arrow::{
@@ -42,9 +42,7 @@ use datafusion_expr::{
4242
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
4343
Volatility,
4444
};
45-
use datafusion_functions_aggregate_common::tdigest::{
46-
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
47-
};
45+
use datafusion_functions_aggregate_common::tdigest::{TDigest, DEFAULT_MAX_SIZE};
4846
use datafusion_macros::user_doc;
4947
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
5048

@@ -121,20 +119,11 @@ An alternate syntax is also supported:
121119
description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
122120
)
123121
)]
124-
#[derive(PartialEq, Eq, Hash)]
122+
#[derive(Debug, PartialEq, Eq, Hash)]
125123
pub struct ApproxPercentileCont {
126124
signature: Signature,
127125
}
128126

129-
impl Debug for ApproxPercentileCont {
130-
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
131-
f.debug_struct("ApproxPercentileCont")
132-
.field("name", &self.name())
133-
.field("signature", &self.signature)
134-
.finish()
135-
}
136-
}
137-
138127
impl Default for ApproxPercentileCont {
139128
fn default() -> Self {
140129
Self::new()
@@ -197,6 +186,7 @@ impl ApproxPercentileCont {
197186
| DataType::Int16
198187
| DataType::Int32
199188
| DataType::Int64
189+
| DataType::Float16
200190
| DataType::Float32
201191
| DataType::Float64 => {
202192
if let Some(max_size) = tdigest_max_size {
@@ -372,83 +362,51 @@ impl ApproxPercentileAccumulator {
372362
match values.data_type() {
373363
DataType::Float64 => {
374364
let array = downcast_value!(values, Float64Array);
375-
Ok(array
376-
.values()
377-
.iter()
378-
.filter_map(|v| v.try_as_f64().transpose())
379-
.collect::<Result<Vec<_>>>()?)
365+
Ok(array.values().iter().copied().collect::<Vec<_>>())
380366
}
381367
DataType::Float32 => {
382368
let array = downcast_value!(values, Float32Array);
369+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
370+
}
371+
DataType::Float16 => {
372+
let array = downcast_value!(values, Float16Array);
383373
Ok(array
384374
.values()
385375
.iter()
386-
.filter_map(|v| v.try_as_f64().transpose())
387-
.collect::<Result<Vec<_>>>()?)
376+
.map(|v| v.to_f64())
377+
.collect::<Vec<_>>())
388378
}
389379
DataType::Int64 => {
390380
let array = downcast_value!(values, Int64Array);
391-
Ok(array
392-
.values()
393-
.iter()
394-
.filter_map(|v| v.try_as_f64().transpose())
395-
.collect::<Result<Vec<_>>>()?)
381+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
396382
}
397383
DataType::Int32 => {
398384
let array = downcast_value!(values, Int32Array);
399-
Ok(array
400-
.values()
401-
.iter()
402-
.filter_map(|v| v.try_as_f64().transpose())
403-
.collect::<Result<Vec<_>>>()?)
385+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
404386
}
405387
DataType::Int16 => {
406388
let array = downcast_value!(values, Int16Array);
407-
Ok(array
408-
.values()
409-
.iter()
410-
.filter_map(|v| v.try_as_f64().transpose())
411-
.collect::<Result<Vec<_>>>()?)
389+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
412390
}
413391
DataType::Int8 => {
414392
let array = downcast_value!(values, Int8Array);
415-
Ok(array
416-
.values()
417-
.iter()
418-
.filter_map(|v| v.try_as_f64().transpose())
419-
.collect::<Result<Vec<_>>>()?)
393+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
420394
}
421395
DataType::UInt64 => {
422396
let array = downcast_value!(values, UInt64Array);
423-
Ok(array
424-
.values()
425-
.iter()
426-
.filter_map(|v| v.try_as_f64().transpose())
427-
.collect::<Result<Vec<_>>>()?)
397+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
428398
}
429399
DataType::UInt32 => {
430400
let array = downcast_value!(values, UInt32Array);
431-
Ok(array
432-
.values()
433-
.iter()
434-
.filter_map(|v| v.try_as_f64().transpose())
435-
.collect::<Result<Vec<_>>>()?)
401+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
436402
}
437403
DataType::UInt16 => {
438404
let array = downcast_value!(values, UInt16Array);
439-
Ok(array
440-
.values()
441-
.iter()
442-
.filter_map(|v| v.try_as_f64().transpose())
443-
.collect::<Result<Vec<_>>>()?)
405+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
444406
}
445407
DataType::UInt8 => {
446408
let array = downcast_value!(values, UInt8Array);
447-
Ok(array
448-
.values()
449-
.iter()
450-
.filter_map(|v| v.try_as_f64().transpose())
451-
.collect::<Result<Vec<_>>>()?)
409+
Ok(array.values().iter().map(|v| *v as f64).collect::<Vec<_>>())
452410
}
453411
e => internal_err!(
454412
"APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
@@ -491,6 +449,7 @@ impl Accumulator for ApproxPercentileAccumulator {
491449
DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
492450
DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
493451
DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
452+
DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))),
494453
DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
495454
DataType::Float64 => ScalarValue::Float64(Some(q)),
496455
v => unreachable!("unexpected return type {}", v),

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,16 @@ SELECT approx_median(col_f64_nan) FROM median_table
991991
----
992992
NaN
993993

994+
query RT
995+
select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table;
996+
----
997+
2.75 Float16
998+
999+
query ?T
1000+
select approx_median(NULL), arrow_typeof(approx_median(NULL)) from median_table;
1001+
----
1002+
NULL Null
1003+
9941004
# median decimal
9951005
statement ok
9961006
create table t(c decimal(10, 4)) as values (0.0001), (0.0002), (0.0003), (0.0004), (0.0005), (0.0006);

0 commit comments

Comments
 (0)