Skip to content

Commit 1837cd8

Browse files
committed
supports multiple levels
1 parent 7ad8c21 commit 1837cd8

File tree

1 file changed

+91
-29
lines changed

1 file changed

+91
-29
lines changed

src/query/functions/src/aggregates/aggregate_quantile_cont.rs

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ use serde::Serialize;
4141
use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription;
4242
use crate::aggregates::assert_params;
4343
use crate::aggregates::assert_unary_arguments;
44-
use crate::aggregates::assert_unary_params;
4544
use crate::aggregates::AggregateFunction;
4645
use crate::aggregates::AggregateFunctionRef;
4746
use crate::aggregates::StateAddr;
@@ -56,7 +55,7 @@ pub trait QuantileStateFunc<T: ValueType>: Send + Sync + 'static {
5655
fn add(&mut self, other: T::ScalarRef<'_>);
5756
fn add_batch(&mut self, column: &T::Column, validity: Option<&Bitmap>) -> Result<()>;
5857
fn merge(&mut self, rhs: &Self) -> Result<()>;
59-
fn merge_result(&mut self, builder: &mut ColumnBuilder, level: f64) -> Result<()>;
58+
fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec<f64>) -> Result<()>;
6059
fn serialize(&self, writer: &mut Vec<u8>) -> Result<()>;
6160
fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()>;
6261
}
@@ -129,16 +128,36 @@ where
129128
Ok(())
130129
}
131130

132-
fn merge_result(&mut self, builder: &mut ColumnBuilder, level: f64) -> Result<()> {
133-
let builder = T::try_downcast_builder(builder).unwrap();
131+
fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec<f64>) -> Result<()> {
134132
let value_len = self.value.len();
135-
let idx = ((value_len - 1) as f64 * level).floor() as usize;
136-
if idx >= value_len {
137-
T::push_default(builder);
133+
if levels.len() > 1 {
134+
let builder = match builder {
135+
ColumnBuilder::Array(box b) => b,
136+
_ => unreachable!(),
137+
};
138+
let indices = levels
139+
.iter()
140+
.map(|level| ((value_len - 1) as f64 * (*level)).floor() as usize)
141+
.collect::<Vec<usize>>();
142+
for idx in indices {
143+
if idx < value_len {
144+
self.value.as_mut_slice().select_nth_unstable(idx);
145+
let value = self.value.get(idx).unwrap();
146+
builder.put_item(T::to_scalar_ref(value));
147+
} else {
148+
builder.push_default();
149+
}
150+
}
138151
} else {
139-
self.value.as_mut_slice().select_nth_unstable(idx);
140-
let value = self.value.get(idx).unwrap();
141-
T::push_item(builder, T::to_scalar_ref(value));
152+
let builder = T::try_downcast_builder(builder).unwrap();
153+
let idx = ((value_len - 1) as f64 * levels[0]).floor() as usize;
154+
if idx >= value_len {
155+
T::push_default(builder);
156+
} else {
157+
self.value.as_mut_slice().select_nth_unstable(idx);
158+
let value = self.value.get(idx).unwrap();
159+
T::push_item(builder, T::to_scalar_ref(value));
160+
}
142161
}
143162
Ok(())
144163
}
@@ -157,7 +176,7 @@ where
157176
pub struct AggregateQuantileContFunction<T, State> {
158177
display_name: String,
159178
return_type: DataType,
160-
level: f64,
179+
levels: Vec<f64>,
161180
_arguments: Vec<DataType>,
162181
_t: PhantomData<T>,
163182
_state: PhantomData<State>,
@@ -252,7 +271,7 @@ where
252271

253272
fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> {
254273
let state = place.get::<State>();
255-
state.merge_result(builder, self.level)
274+
state.merge_result(builder, self.levels.clone())
256275
}
257276
}
258277

@@ -267,7 +286,7 @@ where
267286
params: Vec<Scalar>,
268287
arguments: Vec<DataType>,
269288
) -> Result<Arc<dyn AggregateFunction>> {
270-
let level = if params.len() == 1 {
289+
let levels = if params.len() == 1 {
271290
let level: F64 = check_number(
272291
None,
273292
FunctionContext::default(),
@@ -283,22 +302,50 @@ where
283302
},
284303
&BUILTIN_FUNCTIONS,
285304
)?;
286-
level.0
305+
let level = level.0;
306+
if !(0.0..=1.0).contains(&level) {
307+
return Err(ErrorCode::BadDataValueType(format!(
308+
"level range between [0, 1], got: {:?}",
309+
level
310+
)));
311+
}
312+
vec![level]
313+
} else if params.len() == 0 {
314+
vec![0.5f64]
287315
} else {
288-
0.5f64
316+
let mut levels = Vec::with_capacity(params.len());
317+
for param in params {
318+
let level: F64 = check_number(
319+
None,
320+
FunctionContext::default(),
321+
&Expr::<usize>::Cast {
322+
span: None,
323+
is_try: false,
324+
expr: Box::new(Expr::Constant {
325+
span: None,
326+
scalar: param.clone(),
327+
data_type: param.as_ref().infer_data_type(),
328+
}),
329+
dest_type: DataType::Number(NumberDataType::Float64),
330+
},
331+
&BUILTIN_FUNCTIONS,
332+
)?;
333+
let level = level.0;
334+
if !(0.0..=1.0).contains(&level) {
335+
return Err(ErrorCode::BadDataValueType(format!(
336+
"level range between [0, 1], got: {:?} in levels",
337+
level
338+
)));
339+
}
340+
levels.push(level);
341+
}
342+
levels
289343
};
290344

291-
if !(0.0..=1.0).contains(&level) {
292-
return Err(ErrorCode::BadDataValueType(format!(
293-
"level range between [0, 1], got: {:?}",
294-
level
295-
)));
296-
}
297-
298345
let func = AggregateQuantileContFunction::<T, State> {
299346
display_name: display_name.to_string(),
300347
return_type,
301-
level,
348+
levels,
302349
_arguments: arguments,
303350
_t: PhantomData,
304351
_state: PhantomData,
@@ -313,9 +360,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
313360
params: Vec<Scalar>,
314361
arguments: Vec<DataType>,
315362
) -> Result<AggregateFunctionRef> {
316-
if TYPE == QUANTILE {
317-
assert_unary_params(display_name, params.len())?;
318-
} else {
363+
if TYPE == MEDIAN {
319364
assert_params(display_name, params.len(), 0)?;
320365
}
321366

@@ -327,9 +372,14 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
327372
with_number_mapped_type!(|NUM| match num_type {
328373
NumberDataType::NUM => {
329374
type State = QuantileState<NumberType<NUM>>;
375+
let return_type = if params.len() > 1 {
376+
DataType::Array(Box::new(data_type))
377+
} else {
378+
data_type
379+
};
330380
AggregateQuantileContFunction::<NumberType<NUM>, State>::try_create(
331381
display_name,
332-
data_type,
382+
return_type,
333383
params,
334384
arguments,
335385
)
@@ -341,10 +391,16 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
341391
precision: s.precision,
342392
scale: s.scale,
343393
};
394+
let data_type = DataType::Decimal(DecimalDataType::from_size(decimal_size)?);
395+
let return_type = if params.len() > 1 {
396+
DataType::Array(Box::new(data_type))
397+
} else {
398+
data_type
399+
};
344400
type State = QuantileState<DecimalType<i128>>;
345401
AggregateQuantileContFunction::<DecimalType<i128>, State>::try_create(
346402
display_name,
347-
DataType::Decimal(DecimalDataType::from_size(decimal_size)?),
403+
return_type,
348404
params,
349405
arguments,
350406
)
@@ -354,10 +410,16 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
354410
precision: s.precision,
355411
scale: s.scale,
356412
};
413+
let data_type = DataType::Decimal(DecimalDataType::from_size(decimal_size)?);
414+
let return_type = if params.len() > 1 {
415+
DataType::Array(Box::new(data_type))
416+
} else {
417+
data_type
418+
};
357419
type State = QuantileState<DecimalType<i256>>;
358420
AggregateQuantileContFunction::<DecimalType<i256>, State>::try_create(
359421
display_name,
360-
DataType::Decimal(DecimalDataType::from_size(decimal_size)?),
422+
return_type,
361423
params,
362424
arguments,
363425
)

0 commit comments

Comments
 (0)