@@ -21,12 +21,15 @@ use std::sync::Arc;
2121use common_arrow:: arrow:: bitmap:: Bitmap ;
2222use common_exception:: ErrorCode ;
2323use common_exception:: Result ;
24+ use common_expression:: type_check:: check_number;
2425use common_expression:: types:: decimal:: * ;
2526use common_expression:: types:: number:: * ;
2627use common_expression:: types:: * ;
2728use common_expression:: with_number_mapped_type;
2829use common_expression:: Column ;
2930use common_expression:: ColumnBuilder ;
31+ use common_expression:: Expr ;
32+ use common_expression:: FunctionContext ;
3033use common_expression:: Scalar ;
3134use common_io:: prelude:: deserialize_from_slice;
3235use common_io:: prelude:: serialize_into_buf;
@@ -36,11 +39,13 @@ use serde::Deserialize;
3639use serde:: Serialize ;
3740
3841use crate :: aggregates:: aggregate_function_factory:: AggregateFunctionDescription ;
42+ use crate :: aggregates:: assert_params;
3943use crate :: aggregates:: assert_unary_arguments;
4044use crate :: aggregates:: assert_unary_params;
4145use crate :: aggregates:: AggregateFunction ;
4246use crate :: aggregates:: AggregateFunctionRef ;
4347use crate :: aggregates:: StateAddr ;
48+ use crate :: scalars:: BUILTIN_FUNCTIONS ;
4449use crate :: with_simple_no_number_mapped_type;
4550
4651const MEDIAN : u8 = 0 ;
@@ -259,9 +264,37 @@ where
259264 fn try_create (
260265 display_name : & str ,
261266 return_type : DataType ,
262- level : f64 ,
267+ params : Vec < Scalar > ,
263268 arguments : Vec < DataType > ,
264269 ) -> Result < Arc < dyn AggregateFunction > > {
270+ let level = if params. len ( ) == 1 {
271+ let level: F64 = check_number (
272+ None ,
273+ FunctionContext :: default ( ) ,
274+ & Expr :: < usize > :: Cast {
275+ span : None ,
276+ is_try : false ,
277+ expr : Box :: new ( Expr :: Constant {
278+ span : None ,
279+ scalar : params[ 0 ] . clone ( ) ,
280+ data_type : params[ 0 ] . as_ref ( ) . infer_data_type ( ) ,
281+ } ) ,
282+ dest_type : DataType :: Number ( NumberDataType :: Float64 ) ,
283+ } ,
284+ & BUILTIN_FUNCTIONS ,
285+ ) ?;
286+ level. 0
287+ } else {
288+ 0.5f64
289+ } ;
290+
291+ if !( 0.0 ..=1.0 ) . contains ( & level) {
292+ return Err ( ErrorCode :: BadDataValueType ( format ! (
293+ "level range between [0, 1], got: {:?}" ,
294+ level
295+ ) ) ) ;
296+ }
297+
265298 let func = AggregateQuantileContFunction :: < T , State > {
266299 display_name : display_name. to_string ( ) ,
267300 return_type,
@@ -280,44 +313,13 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
280313 params : Vec < Scalar > ,
281314 arguments : Vec < DataType > ,
282315) -> Result < AggregateFunctionRef > {
283- assert_unary_arguments ( display_name, arguments. len ( ) ) ?;
284-
285- let level = if TYPE == MEDIAN {
286- 0.5f64
287- } else {
316+ if TYPE == QUANTILE {
288317 assert_unary_params ( display_name, params. len ( ) ) ?;
289- let param = params[ 0 ] . clone ( ) ;
290- match param {
291- Scalar :: Decimal ( d) => {
292- let f = d. to_float64 ( ) ;
293- if f <= 0.01 || f >= 0.99 {
294- return Err ( ErrorCode :: BadDataValueType ( format ! (
295- "level range between 0.01 to 0.99, got: {:?}" ,
296- f
297- ) ) ) ;
298- }
299- f
300- }
301- Scalar :: Number ( NumberScalar :: UInt64 ( i) ) => {
302- if i == 0 {
303- 0.01f64
304- } else if i == 1 {
305- 0.99f64
306- } else {
307- return Err ( ErrorCode :: BadDataValueType ( format ! (
308- "level range between 0.01 to 0.99, got: {:?}" ,
309- i
310- ) ) ) ;
311- }
312- }
313- _ => {
314- return Err ( ErrorCode :: BadDataValueType ( format ! (
315- "level param just support float type, got: {:?}" ,
316- param
317- ) ) ) ;
318- }
319- }
320- } ;
318+ } else {
319+ assert_params ( display_name, params. len ( ) , 0 ) ?;
320+ }
321+
322+ assert_unary_arguments ( display_name, arguments. len ( ) ) ?;
321323
322324 let data_type = arguments[ 0 ] . clone ( ) ;
323325 with_simple_no_number_mapped_type ! ( |T | match data_type {
@@ -328,7 +330,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
328330 AggregateQuantileContFunction :: <NumberType <NUM >, State >:: try_create(
329331 display_name,
330332 data_type,
331- level ,
333+ params ,
332334 arguments,
333335 )
334336 }
@@ -343,7 +345,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
343345 AggregateQuantileContFunction :: <DecimalType <i128 >, State >:: try_create(
344346 display_name,
345347 DataType :: Decimal ( DecimalDataType :: from_size( decimal_size) ?) ,
346- level ,
348+ params ,
347349 arguments,
348350 )
349351 }
@@ -356,7 +358,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
356358 AggregateQuantileContFunction :: <DecimalType <i256>, State >:: try_create(
357359 display_name,
358360 DataType :: Decimal ( DecimalDataType :: from_size( decimal_size) ?) ,
359- level ,
361+ params ,
360362 arguments,
361363 )
362364 }
0 commit comments