Skip to content

Commit efb5d8f

Browse files
authored
Merge pull request #10550 from ariesdevil/quantile
2 parents f35912a + 5150c71 commit efb5d8f

File tree

3 files changed

+67
-41
lines changed

3 files changed

+67
-41
lines changed

src/query/functions/src/aggregates/aggregate_quantile_cont.rs

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ use std::sync::Arc;
2121
use common_arrow::arrow::bitmap::Bitmap;
2222
use common_exception::ErrorCode;
2323
use common_exception::Result;
24+
use common_expression::type_check::check_number;
2425
use common_expression::types::decimal::*;
2526
use common_expression::types::number::*;
2627
use common_expression::types::*;
2728
use common_expression::with_number_mapped_type;
2829
use common_expression::Column;
2930
use common_expression::ColumnBuilder;
31+
use common_expression::Expr;
32+
use common_expression::FunctionContext;
3033
use common_expression::Scalar;
3134
use common_io::prelude::deserialize_from_slice;
3235
use common_io::prelude::serialize_into_buf;
@@ -36,11 +39,13 @@ use serde::Deserialize;
3639
use serde::Serialize;
3740

3841
use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription;
42+
use crate::aggregates::assert_params;
3943
use crate::aggregates::assert_unary_arguments;
4044
use crate::aggregates::assert_unary_params;
4145
use crate::aggregates::AggregateFunction;
4246
use crate::aggregates::AggregateFunctionRef;
4347
use crate::aggregates::StateAddr;
48+
use crate::scalars::BUILTIN_FUNCTIONS;
4449
use crate::with_simple_no_number_mapped_type;
4550

4651
const MEDIAN: u8 = 0;
@@ -259,9 +264,37 @@ where
259264
fn try_create(
260265
display_name: &str,
261266
return_type: DataType,
262-
level: f64,
267+
params: Vec<Scalar>,
263268
arguments: Vec<DataType>,
264269
) -> Result<Arc<dyn AggregateFunction>> {
270+
let level = if params.len() == 1 {
271+
let level: F64 = check_number(
272+
None,
273+
FunctionContext::default(),
274+
&Expr::<usize>::Cast {
275+
span: None,
276+
is_try: false,
277+
expr: Box::new(Expr::Constant {
278+
span: None,
279+
scalar: params[0].clone(),
280+
data_type: params[0].as_ref().infer_data_type(),
281+
}),
282+
dest_type: DataType::Number(NumberDataType::Float64),
283+
},
284+
&BUILTIN_FUNCTIONS,
285+
)?;
286+
level.0
287+
} else {
288+
0.5f64
289+
};
290+
291+
if !(0.0..=1.0).contains(&level) {
292+
return Err(ErrorCode::BadDataValueType(format!(
293+
"level range between [0, 1], got: {:?}",
294+
level
295+
)));
296+
}
297+
265298
let func = AggregateQuantileContFunction::<T, State> {
266299
display_name: display_name.to_string(),
267300
return_type,
@@ -280,44 +313,13 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
280313
params: Vec<Scalar>,
281314
arguments: Vec<DataType>,
282315
) -> Result<AggregateFunctionRef> {
283-
assert_unary_arguments(display_name, arguments.len())?;
284-
285-
let level = if TYPE == MEDIAN {
286-
0.5f64
287-
} else {
316+
if TYPE == QUANTILE {
288317
assert_unary_params(display_name, params.len())?;
289-
let param = params[0].clone();
290-
match param {
291-
Scalar::Decimal(d) => {
292-
let f = d.to_float64();
293-
if f <= 0.01 || f >= 0.99 {
294-
return Err(ErrorCode::BadDataValueType(format!(
295-
"level range between 0.01 to 0.99, got: {:?}",
296-
f
297-
)));
298-
}
299-
f
300-
}
301-
Scalar::Number(NumberScalar::UInt64(i)) => {
302-
if i == 0 {
303-
0.01f64
304-
} else if i == 1 {
305-
0.99f64
306-
} else {
307-
return Err(ErrorCode::BadDataValueType(format!(
308-
"level range between 0.01 to 0.99, got: {:?}",
309-
i
310-
)));
311-
}
312-
}
313-
_ => {
314-
return Err(ErrorCode::BadDataValueType(format!(
315-
"level param just support float type, got: {:?}",
316-
param
317-
)));
318-
}
319-
}
320-
};
318+
} else {
319+
assert_params(display_name, params.len(), 0)?;
320+
}
321+
322+
assert_unary_arguments(display_name, arguments.len())?;
321323

322324
let data_type = arguments[0].clone();
323325
with_simple_no_number_mapped_type!(|T| match data_type {
@@ -328,7 +330,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
328330
AggregateQuantileContFunction::<NumberType<NUM>, State>::try_create(
329331
display_name,
330332
data_type,
331-
level,
333+
params,
332334
arguments,
333335
)
334336
}
@@ -343,7 +345,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
343345
AggregateQuantileContFunction::<DecimalType<i128>, State>::try_create(
344346
display_name,
345347
DataType::Decimal(DecimalDataType::from_size(decimal_size)?),
346-
level,
348+
params,
347349
arguments,
348350
)
349351
}
@@ -356,7 +358,7 @@ pub fn try_create_aggregate_quantile_function<const TYPE: u8>(
356358
AggregateQuantileContFunction::<DecimalType<i256>, State>::try_create(
357359
display_name,
358360
DataType::Decimal(DecimalDataType::from_size(decimal_size)?),
359-
level,
361+
params,
360362
arguments,
361363
)
362364
}

src/query/functions/src/aggregates/aggregator_common.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ pub fn assert_unary_params<D: Display>(name: D, actual: usize) -> Result<()> {
3636
Ok(())
3737
}
3838

39+
pub fn assert_params<D: Display>(name: D, actual: usize, expected: usize) -> Result<()> {
40+
if actual != expected {
41+
return Err(ErrorCode::NumberArgumentsNotMatch(format!(
42+
"{} expect to have {} params, but got {}",
43+
name, expected, actual
44+
)));
45+
}
46+
Ok(())
47+
}
48+
3949
pub fn assert_unary_arguments<D: Display>(name: D, actual: usize) -> Result<()> {
4050
if actual != 1 {
4151
return Err(ErrorCode::NumberArgumentsNotMatch(format!(

tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,20 @@ SELECT quantile(0.6)(number) from numbers_mt(10000)
212212
----
213213
5999
214214

215+
query I
216+
SELECT quantile(0)(number) from numbers_mt(10000)
217+
----
218+
0
219+
220+
query I
221+
SELECT quantile(1)(number) from numbers_mt(10000)
222+
----
223+
9999
224+
225+
statement error 1010
226+
SELECT quantile(5)(number) from numbers_mt(10000)
227+
228+
215229
statement ok
216230
DROP DATABASE db1
217231

0 commit comments

Comments
 (0)