Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 19 additions & 46 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,36 @@ use std::fmt::{Debug, Formatter};
use std::mem::size_of_val;
use std::sync::Arc;

use arrow::array::{Array, RecordBatch};
use arrow::array::Array;
use arrow::compute::{filter, is_not_null};
use arrow::datatypes::FieldRef;
use arrow::{
array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
},
datatypes::{DataType, Field, Schema},
datatypes::{DataType, Field},
};
use datafusion_common::{
downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
Result, ScalarValue,
downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature,
TypeSignature, Volatility,
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
Volatility,
};
use datafusion_functions_aggregate_common::tdigest::{
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

use crate::utils::{get_percentile_scalar_value, validate_percentile_expr};

create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);

/// Computes the approximate percentile continuous of a set of numbers
Expand Down Expand Up @@ -164,7 +166,8 @@ impl ApproxPercentileCont {
&self,
args: AccumulatorArgs,
) -> Result<ApproxPercentileAccumulator> {
let percentile = validate_input_percentile_expr(&args.exprs[1])?;
let percentile =
validate_percentile_expr(&args.exprs[1], "APPROX_PERCENTILE_CONT")?;

let is_descending = args
.order_bys
Expand Down Expand Up @@ -214,45 +217,15 @@ impl ApproxPercentileCont {
}
}

fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
let empty_schema = Arc::new(Schema::empty());
let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
Ok(s)
} else {
internal_err!("Didn't expect ColumnarValue::Array")
}
}

fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
let percentile = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
ScalarValue::Float32(Some(value)) => {
value as f64
}
ScalarValue::Float64(Some(value)) => {
value
}
sv => {
return not_impl_err!(
"Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
sv.data_type()
)
}
};

// Ensure the percentile is between 0 and 1.
if !(0.0..=1.0).contains(&percentile) {
return plan_err!(
"Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
);
}
Ok(percentile)
}

fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
let max_size = match get_scalar_value(expr)
.map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? {
let scalar_value = get_percentile_scalar_value(expr).map_err(|_e| {
DataFusionError::Plan(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal"
.to_string(),
)
})?;

let max_size = match scalar_value {
ScalarValue::UInt8(Some(q)) => q as usize,
ScalarValue::UInt16(Some(q)) => q as usize,
ScalarValue::UInt32(Some(q)) => q as usize,
Expand All @@ -262,7 +235,7 @@ fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
ScalarValue::Int16(Some(q)) if q > 0 => q as usize,
ScalarValue::Int8(Some(q)) if q > 0 => q as usize,
sv => {
return not_impl_err!(
return plan_err!(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
sv.data_type()
)
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,15 @@ pub mod hyperloglog;
pub mod median;
pub mod min_max;
pub mod nth_value;
pub mod percentile_cont;
pub mod regr;
pub mod stddev;
pub mod string_agg;
pub mod sum;
pub mod variance;

pub mod planner;
mod utils;

use crate::approx_percentile_cont::approx_percentile_cont_udaf;
use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
Expand Down Expand Up @@ -123,6 +125,7 @@ pub mod expr_fn {
pub use super::min_max::max;
pub use super::min_max::min;
pub use super::nth_value::nth_value;
pub use super::percentile_cont::percentile_cont;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
Expand Down Expand Up @@ -171,6 +174,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
approx_distinct::approx_distinct_udaf(),
approx_percentile_cont_udaf(),
approx_percentile_cont_with_weight_udaf(),
percentile_cont::percentile_cont_udaf(),
string_agg::string_agg_udaf(),
bit_and_or_xor::bit_and_udaf(),
bit_and_or_xor::bit_or_udaf(),
Expand Down
Loading