diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ff7762e816ad..6fd90130e674 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -111,20 +111,12 @@ An alternative syntax is also supported: description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, } -impl Debug for ApproxPercentileContWithWeight { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ApproxPercentileContWithWeight") - .field("signature", &self.signature) - .finish() - } -} - impl Default for ApproxPercentileContWithWeight { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index a107024e2fb4..77b99cd1ae99 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -114,11 +114,7 @@ pub struct BoolAnd { impl BoolAnd { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } @@ -251,11 +247,7 @@ pub struct BoolOr { impl BoolOr { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 538311dfa263..9c66d714386a 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -367,7 +367,7 @@ fn accumulate_correlation_states( /// where: /// n = number of observations /// sum_x = sum of x values -/// sum_y = sum of y values +/// sum_y = sum of y values /// sum_xy = sum of (x * y) /// sum_xx = sum of x^2 values /// sum_yy = sum of y^2 values diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 10cc2ad33f56..376cf3974590 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -147,20 +147,11 @@ pub fn count_all_window() -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Count { signature: Signature, } -impl Debug for Count { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("Count") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Count { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index e86d742db3d4..8252cf1b19c4 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,19 +17,13 @@ //! [`CovarianceSample`]: covariance sample aggregations. -use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, Float64Array, UInt64Array}, - compute::kernels::cast, - datatypes::{DataType, Field}, -}; -use datafusion_common::{ - Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err, -}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, }; use datafusion_functions_aggregate_common::stats::StatsType; @@ -69,21 +63,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovarianceSample { signature: Signature, aliases: Vec, } -impl Debug for CovarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovarianceSample { fn default() -> Self { Self::new() @@ -94,7 +79,10 @@ impl CovarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("covar")], - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } @@ -112,11 +100,7 @@ impl AggregateUDFImpl for CovarianceSample { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -165,20 +149,11 @@ impl AggregateUDFImpl for CovarianceSample { standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovariancePopulation { signature: Signature, } -impl Debug for CovariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovariancePopulation { fn default() -> Self { Self::new() @@ -188,7 +163,10 @@ impl Default for CovariancePopulation { impl CovariancePopulation { pub fn new() -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } @@ -206,11 +184,7 @@ impl AggregateUDFImpl for CovariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -304,30 +278,15 @@ impl Accumulator for CovarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, }; - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); let new_count = self.count + 1; let delta1 = value1 - self.mean1; let new_mean1 = delta1 / new_count as f64 + self.mean1; @@ -345,29 +304,14 @@ impl Accumulator for CovarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, + }; let new_count = self.count - 1; let delta1 = self.mean1 - value1; @@ -386,10 +330,10 @@ impl Accumulator for CovarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means1 = as_float64_array(&states[1])?; + let means2 = as_float64_array(&states[2])?; + let cs = as_float64_array(&states[3])?; for i in 0..counts.len() { let c = counts.value(i); diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 5f3490f535a4..b339479b35e9 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -90,22 +90,12 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct FirstValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for FirstValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for FirstValue { fn default() -> Self { Self::new() @@ -1040,22 +1030,12 @@ impl Accumulator for FirstValueAccumulator { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct LastValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for LastValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("LastValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for LastValue { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 43218b1147d3..c7af2df4b10f 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -18,7 +18,6 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::fmt; use arrow::datatypes::Field; use arrow::datatypes::{DataType, FieldRef}; @@ -60,20 +59,11 @@ make_udaf_expr_and_func!( description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Grouping { signature: Signature, } -impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Grouping") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Grouping { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f137ae0801f0..e986c9a612c4 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -85,20 +85,11 @@ make_udaf_expr_and_func!( /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Median { signature: Signature, } -impl Debug for Median { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("Median") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Median { fn default() -> Self { Self::new() diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index bbc5567dab9d..066fa3c5f32e 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,20 +17,12 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::Float64Array; use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, - datatypes::Field, -}; -use datafusion_common::{ - HashMap, Result, ScalarValue, downcast_value, plan_err, unwrap_or_internal_err, -}; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{HashMap, Result, ScalarValue}; use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, @@ -58,26 +50,20 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Regr { signature: Signature, regr_type: RegrType, func_name: &'static str, } -impl Debug for Regr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("regr") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Regr { pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), regr_type, func_name, } @@ -468,11 +454,7 @@ impl AggregateUDFImpl for Regr { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { if matches!(self.regr_type, RegrType::Count) { Ok(DataType::UInt64) } else { @@ -606,32 +588,18 @@ impl Accumulator for RegrAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // regr_slope(Y, X) calculates k in y = k*x + b - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None - }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - self.count += 1; let delta_x = value_x - self.mean_x; let delta_y = value_y - self.mean_y; @@ -652,32 +620,18 @@ impl Accumulator for RegrAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None - }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - if self.count > 1 { self.count -= 1; let delta_x = value_x - self.mean_x; @@ -703,12 +657,12 @@ impl Accumulator for RegrAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let count_arr = downcast_value!(states[0], UInt64Array); - let mean_x_arr = downcast_value!(states[1], Float64Array); - let mean_y_arr = downcast_value!(states[2], Float64Array); - let m2_x_arr = downcast_value!(states[3], Float64Array); - let m2_y_arr = downcast_value!(states[4], Float64Array); - let algo_const_arr = downcast_value!(states[5], Float64Array); + let count_arr = as_uint64_array(&states[0])?; + let mean_x_arr = as_float64_array(&states[1])?; + let mean_y_arr = as_float64_array(&states[2])?; + let m2_x_arr = as_float64_array(&states[3])?; + let m2_y_arr = as_float64_array(&states[4])?; + let algo_const_arr = as_float64_array(&states[5])?; for i in 0..count_arr.len() { let count_b = count_arr.value(i); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 13eb5e1660b5..6f77e7df9254 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -18,7 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; @@ -26,8 +26,8 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::ScalarValue; use datafusion_common::{Result, internal_err, not_impl_err}; -use datafusion_common::{ScalarValue, plan_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -62,21 +62,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Stddev { signature: Signature, alias: Vec, } -impl Debug for Stddev { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Stddev") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Stddev { fn default() -> Self { Self::new() @@ -87,7 +78,7 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), alias: vec!["stddev_samp".to_string()], } } @@ -180,20 +171,11 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV_POP population aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct StddevPop { signature: Signature, } -impl Debug for StddevPop { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StddevPop") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for StddevPop { fn default() -> Self { Self::new() @@ -204,7 +186,7 @@ impl StddevPop { /// Create a new STDDEV_POP aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -249,11 +231,7 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("StddevPop requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -318,13 +296,8 @@ impl Accumulator for StddevAccumulator { fn evaluate(&mut self) -> Result { let variance = self.variance.evaluate()?; match variance { - ScalarValue::Float64(e) => { - if e.is_none() { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) - } - } + ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)), + ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))), _ => internal_err!("Variance should be f64"), } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 9e35bf0a2bea..fb089ba4f9ce 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -22,10 +22,10 @@ use arrow::datatypes::{FieldRef, Float64Type}; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, - compute::kernels::cast, datatypes::{DataType, Field}, }; -use datafusion_common::{Result, ScalarValue, downcast_value, plan_err}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility, @@ -62,21 +62,12 @@ make_udaf_expr_and_func!( syntax_example = "var(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VarianceSample { signature: Signature, aliases: Vec, } -impl Debug for VarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VarianceSample { fn default() -> Self { Self::new() @@ -87,7 +78,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -171,21 +162,12 @@ impl AggregateUDFImpl for VarianceSample { syntax_example = "var_pop(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VariancePopulation { signature: Signature, aliases: Vec, } -impl Debug for VariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VariancePopulation { fn default() -> Self { Self::new() @@ -196,7 +178,7 @@ impl VariancePopulation { pub fn new() -> Self { Self { aliases: vec![String::from("var_population")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } @@ -214,11 +196,7 @@ impl AggregateUDFImpl for VariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Variance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -278,6 +256,7 @@ impl AggregateUDFImpl for VariancePopulation { StatsType::Population, ))) } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -365,10 +344,8 @@ impl Accumulator for VarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { (self.count, self.mean, self.m2) = update(self.count, self.mean, self.m2, value) } @@ -377,10 +354,8 @@ impl Accumulator for VarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { let new_count = self.count - 1; let delta1 = self.mean - value; let new_mean = delta1 / new_count as f64 + self.mean; @@ -396,9 +371,9 @@ impl Accumulator for VarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means = downcast_value!(states[1], Float64Array); - let m2s = downcast_value!(states[2], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means = as_float64_array(&states[1])?; + let m2s = as_float64_array(&states[2])?; for i in 0..counts.len() { let c = counts.value(i); @@ -533,8 +508,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &cast(&values[0], &DataType::Float64)?; - let values = downcast_value!(values, Float64Array); + let values = as_float64_array(&values[0])?; self.resize(total_num_groups); accumulate(group_indices, values, opt_filter, |group_index, value| { @@ -561,9 +535,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); // first batch is counts, second is partial means, third is partial m2s - let partial_counts = downcast_value!(values[0], UInt64Array); - let partial_means = downcast_value!(values[1], Float64Array); - let partial_m2s = downcast_value!(values[2], Float64Array); + let partial_counts = as_uint64_array(&values[0])?; + let partial_means = as_float64_array(&values[1])?; + let partial_m2s = as_float64_array(&values[2])?; self.resize(total_num_groups); Self::merge( @@ -633,9 +607,7 @@ impl DistinctVarianceAccumulator { impl Accumulator for DistinctVarianceAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let cast_values = cast(&values[0], &DataType::Float64)?; - self.distinct_values - .update_batch(vec![cast_values].as_ref()) + self.distinct_values.update_batch(values) } fn evaluate(&mut self) -> Result {