@@ -41,7 +41,6 @@ use serde::Serialize;
4141use crate :: aggregates:: aggregate_function_factory:: AggregateFunctionDescription ;
4242use crate :: aggregates:: assert_params;
4343use crate :: aggregates:: assert_unary_arguments;
44- use crate :: aggregates:: assert_unary_params;
4544use crate :: aggregates:: AggregateFunction ;
4645use crate :: aggregates:: AggregateFunctionRef ;
4746use crate :: aggregates:: StateAddr ;
@@ -56,7 +55,7 @@ pub trait QuantileStateFunc<T: ValueType>: Send + Sync + 'static {
5655 fn add ( & mut self , other : T :: ScalarRef < ' _ > ) ;
5756 fn add_batch ( & mut self , column : & T :: Column , validity : Option < & Bitmap > ) -> Result < ( ) > ;
5857 fn merge ( & mut self , rhs : & Self ) -> Result < ( ) > ;
59- fn merge_result ( & mut self , builder : & mut ColumnBuilder , level : f64 ) -> Result < ( ) > ;
58+ fn merge_result ( & mut self , builder : & mut ColumnBuilder , levels : Vec < f64 > ) -> Result < ( ) > ;
6059 fn serialize ( & self , writer : & mut Vec < u8 > ) -> Result < ( ) > ;
6160 fn deserialize ( & mut self , reader : & mut & [ u8 ] ) -> Result < ( ) > ;
6261}
@@ -129,16 +128,36 @@ where
129128 Ok ( ( ) )
130129 }
131130
132- fn merge_result ( & mut self , builder : & mut ColumnBuilder , level : f64 ) -> Result < ( ) > {
133- let builder = T :: try_downcast_builder ( builder) . unwrap ( ) ;
131+ fn merge_result ( & mut self , builder : & mut ColumnBuilder , levels : Vec < f64 > ) -> Result < ( ) > {
134132 let value_len = self . value . len ( ) ;
135- let idx = ( ( value_len - 1 ) as f64 * level) . floor ( ) as usize ;
136- if idx >= value_len {
137- T :: push_default ( builder) ;
133+ if levels. len ( ) > 1 {
134+ let builder = match builder {
135+ ColumnBuilder :: Array ( box b) => b,
136+ _ => unreachable ! ( ) ,
137+ } ;
138+ let indices = levels
139+ . iter ( )
140+ . map ( |level| ( ( value_len - 1 ) as f64 * ( * level) ) . floor ( ) as usize )
141+ . collect :: < Vec < usize > > ( ) ;
142+ for idx in indices {
143+ if idx < value_len {
144+ self . value . as_mut_slice ( ) . select_nth_unstable ( idx) ;
145+ let value = self . value . get ( idx) . unwrap ( ) ;
146+ builder. put_item ( T :: to_scalar_ref ( value) ) ;
147+ } else {
148+ builder. push_default ( ) ;
149+ }
150+ }
138151 } else {
139- self . value . as_mut_slice ( ) . select_nth_unstable ( idx) ;
140- let value = self . value . get ( idx) . unwrap ( ) ;
141- T :: push_item ( builder, T :: to_scalar_ref ( value) ) ;
152+ let builder = T :: try_downcast_builder ( builder) . unwrap ( ) ;
153+ let idx = ( ( value_len - 1 ) as f64 * levels[ 0 ] ) . floor ( ) as usize ;
154+ if idx >= value_len {
155+ T :: push_default ( builder) ;
156+ } else {
157+ self . value . as_mut_slice ( ) . select_nth_unstable ( idx) ;
158+ let value = self . value . get ( idx) . unwrap ( ) ;
159+ T :: push_item ( builder, T :: to_scalar_ref ( value) ) ;
160+ }
142161 }
143162 Ok ( ( ) )
144163 }
@@ -157,7 +176,7 @@ where
157176pub struct AggregateQuantileContFunction < T , State > {
158177 display_name : String ,
159178 return_type : DataType ,
160- level : f64 ,
179+ levels : Vec < f64 > ,
161180 _arguments : Vec < DataType > ,
162181 _t : PhantomData < T > ,
163182 _state : PhantomData < State > ,
@@ -252,7 +271,7 @@ where
252271
253272 fn merge_result ( & self , place : StateAddr , builder : & mut ColumnBuilder ) -> Result < ( ) > {
254273 let state = place. get :: < State > ( ) ;
255- state. merge_result ( builder, self . level )
274+ state. merge_result ( builder, self . levels . clone ( ) )
256275 }
257276}
258277
@@ -267,7 +286,7 @@ where
267286 params : Vec < Scalar > ,
268287 arguments : Vec < DataType > ,
269288 ) -> Result < Arc < dyn AggregateFunction > > {
270- let level = if params. len ( ) == 1 {
289+ let levels = if params. len ( ) == 1 {
271290 let level: F64 = check_number (
272291 None ,
273292 FunctionContext :: default ( ) ,
@@ -283,22 +302,50 @@ where
283302 } ,
284303 & BUILTIN_FUNCTIONS ,
285304 ) ?;
286- level. 0
305+ let level = level. 0 ;
306+ if !( 0.0 ..=1.0 ) . contains ( & level) {
307+ return Err ( ErrorCode :: BadDataValueType ( format ! (
308+ "level range between [0, 1], got: {:?}" ,
309+ level
310+ ) ) ) ;
311+ }
312+ vec ! [ level]
313+ } else if params. len ( ) == 0 {
314+ vec ! [ 0.5f64 ]
287315 } else {
288- 0.5f64
316+ let mut levels = Vec :: with_capacity ( params. len ( ) ) ;
317+ for param in params {
318+ let level: F64 = check_number (
319+ None ,
320+ FunctionContext :: default ( ) ,
321+ & Expr :: < usize > :: Cast {
322+ span : None ,
323+ is_try : false ,
324+ expr : Box :: new ( Expr :: Constant {
325+ span : None ,
326+ scalar : param. clone ( ) ,
327+ data_type : param. as_ref ( ) . infer_data_type ( ) ,
328+ } ) ,
329+ dest_type : DataType :: Number ( NumberDataType :: Float64 ) ,
330+ } ,
331+ & BUILTIN_FUNCTIONS ,
332+ ) ?;
333+ let level = level. 0 ;
334+ if !( 0.0 ..=1.0 ) . contains ( & level) {
335+ return Err ( ErrorCode :: BadDataValueType ( format ! (
336+ "level range between [0, 1], got: {:?} in levels" ,
337+ level
338+ ) ) ) ;
339+ }
340+ levels. push ( level) ;
341+ }
342+ levels
289343 } ;
290344
291- if !( 0.0 ..=1.0 ) . contains ( & level) {
292- return Err ( ErrorCode :: BadDataValueType ( format ! (
293- "level range between [0, 1], got: {:?}" ,
294- level
295- ) ) ) ;
296- }
297-
298345 let func = AggregateQuantileContFunction :: < T , State > {
299346 display_name : display_name. to_string ( ) ,
300347 return_type,
301- level ,
348+ levels ,
302349 _arguments : arguments,
303350 _t : PhantomData ,
304351 _state : PhantomData ,
@@ -313,9 +360,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
313360 params : Vec < Scalar > ,
314361 arguments : Vec < DataType > ,
315362) -> Result < AggregateFunctionRef > {
316- if TYPE == QUANTILE {
317- assert_unary_params ( display_name, params. len ( ) ) ?;
318- } else {
363+ if TYPE == MEDIAN {
319364 assert_params ( display_name, params. len ( ) , 0 ) ?;
320365 }
321366
@@ -327,9 +372,14 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
327372 with_number_mapped_type!( |NUM | match num_type {
328373 NumberDataType :: NUM => {
329374 type State = QuantileState <NumberType <NUM >>;
375+ let return_type = if params. len( ) > 1 {
376+ DataType :: Array ( Box :: new( data_type) )
377+ } else {
378+ data_type
379+ } ;
330380 AggregateQuantileContFunction :: <NumberType <NUM >, State >:: try_create(
331381 display_name,
332- data_type ,
382+ return_type ,
333383 params,
334384 arguments,
335385 )
@@ -341,10 +391,16 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
341391 precision: s. precision,
342392 scale: s. scale,
343393 } ;
394+ let data_type = DataType :: Decimal ( DecimalDataType :: from_size( decimal_size) ?) ;
395+ let return_type = if params. len( ) > 1 {
396+ DataType :: Array ( Box :: new( data_type) )
397+ } else {
398+ data_type
399+ } ;
344400 type State = QuantileState <DecimalType <i128 >>;
345401 AggregateQuantileContFunction :: <DecimalType <i128 >, State >:: try_create(
346402 display_name,
347- DataType :: Decimal ( DecimalDataType :: from_size ( decimal_size ) ? ) ,
403+ return_type ,
348404 params,
349405 arguments,
350406 )
@@ -354,10 +410,16 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
354410 precision: s. precision,
355411 scale: s. scale,
356412 } ;
413+ let data_type = DataType :: Decimal ( DecimalDataType :: from_size( decimal_size) ?) ;
414+ let return_type = if params. len( ) > 1 {
415+ DataType :: Array ( Box :: new( data_type) )
416+ } else {
417+ data_type
418+ } ;
357419 type State = QuantileState <DecimalType <i256>>;
358420 AggregateQuantileContFunction :: <DecimalType <i256>, State >:: try_create(
359421 display_name,
360- DataType :: Decimal ( DecimalDataType :: from_size ( decimal_size ) ? ) ,
422+ return_type ,
361423 params,
362424 arguments,
363425 )
0 commit comments