Skip to content

Commit 183ff66

Browse files
authored
Support centroids config for approx_percentile_cont_with_weight (#17003)
* Support centroids config for `approx_percentile_cont_with_weight` * Match two functions' signature * Update docs * Address comments and unify centroids config
1 parent 79c4c05 commit 183ff66

File tree

6 files changed

+151
-57
lines changed

6 files changed

+151
-57
lines changed

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,14 @@ pub fn approx_percentile_cont(
7777
#[user_doc(
7878
doc_section(label = "Approximate Functions"),
7979
description = "Returns the approximate percentile of input values using the t-digest algorithm.",
80-
syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)",
80+
syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
8181
sql_example = r#"```sql
82+
> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+
+------------------------------------------------------------------+
84+
| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
85+
+------------------------------------------------------------------+
86+
| 65.0 |
87+
+------------------------------------------------------------------+
8288
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
8389
+-----------------------------------------------------------------------+
8490
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
@@ -313,7 +319,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
313319
}
314320
if arg_types.len() == 3 && !arg_types[2].is_integer() {
315321
return plan_err!(
316-
"approx_percentile_cont requires integer max_size input types"
322+
"approx_percentile_cont requires integer centroids input types"
317323
);
318324
}
319325
Ok(arg_types[0].clone())
@@ -360,6 +366,11 @@ impl ApproxPercentileAccumulator {
360366
}
361367
}
362368

369+
// public for approx_percentile_cont_with_weight
370+
pub(crate) fn max_size(&self) -> usize {
371+
self.digest.max_size()
372+
}
373+
363374
// public for approx_percentile_cont_with_weight
364375
pub fn merge_digests(&mut self, digests: &[TDigest]) {
365376
let digests = digests.iter().chain(std::iter::once(&self.digest));

datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,39 +25,66 @@ use arrow::datatypes::FieldRef;
2525
use arrow::{array::ArrayRef, datatypes::DataType};
2626
use datafusion_common::ScalarValue;
2727
use datafusion_common::{not_impl_err, plan_err, Result};
28+
use datafusion_expr::expr::{AggregateFunction, Sort};
2829
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
29-
use datafusion_expr::type_coercion::aggregates::NUMERICS;
30+
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
3031
use datafusion_expr::Volatility::Immutable;
3132
use datafusion_expr::{
32-
Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature,
33-
};
34-
use datafusion_functions_aggregate_common::tdigest::{
35-
Centroid, TDigest, DEFAULT_MAX_SIZE,
33+
Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature,
3634
};
35+
use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
3736
use datafusion_macros::user_doc;
3837

3938
use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
4039

41-
make_udaf_expr_and_func!(
40+
create_func!(
4241
ApproxPercentileContWithWeight,
43-
approx_percentile_cont_with_weight,
44-
expression weight percentile,
45-
"Computes the approximate percentile continuous with weight of a set of numbers",
4642
approx_percentile_cont_with_weight_udaf
4743
);
4844

45+
/// Computes the approximate percentile continuous with weight of a set of numbers
46+
pub fn approx_percentile_cont_with_weight(
47+
order_by: Sort,
48+
weight: Expr,
49+
percentile: Expr,
50+
centroids: Option<Expr>,
51+
) -> Expr {
52+
let expr = order_by.expr.clone();
53+
54+
let args = if let Some(centroids) = centroids {
55+
vec![expr, weight, percentile, centroids]
56+
} else {
57+
vec![expr, weight, percentile]
58+
};
59+
60+
Expr::AggregateFunction(AggregateFunction::new_udf(
61+
approx_percentile_cont_with_weight_udaf(),
62+
args,
63+
false,
64+
None,
65+
vec![order_by],
66+
None,
67+
))
68+
}
69+
4970
/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
5071
#[user_doc(
5172
doc_section(label = "Approximate Functions"),
5273
description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
53-
syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)",
74+
syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)",
5475
sql_example = r#"```sql
5576
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
5677
+---------------------------------------------------------------------------------------------+
5778
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
5879
+---------------------------------------------------------------------------------------------+
5980
| 78.5 |
6081
+---------------------------------------------------------------------------------------------+
82+
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
83+
+--------------------------------------------------------------------------------------------------+
84+
| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) |
85+
+--------------------------------------------------------------------------------------------------+
86+
| 78.5 |
87+
+--------------------------------------------------------------------------------------------------+
6188
```"#,
6289
standard_argument(name = "expression", prefix = "The"),
6390
argument(
@@ -67,6 +94,10 @@ make_udaf_expr_and_func!(
6794
argument(
6895
name = "percentile",
6996
description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)."
97+
),
98+
argument(
99+
name = "centroids",
100+
description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory."
70101
)
71102
)]
72103
pub struct ApproxPercentileContWithWeight {
@@ -91,21 +122,26 @@ impl Default for ApproxPercentileContWithWeight {
91122
impl ApproxPercentileContWithWeight {
92123
/// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
93124
pub fn new() -> Self {
125+
let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
126+
// Accept any numeric value paired with weight and float64 percentile
127+
for num in NUMERICS {
128+
variants.push(TypeSignature::Exact(vec![
129+
num.clone(),
130+
num.clone(),
131+
DataType::Float64,
132+
]));
133+
// Additionally accept an integer number of centroids for T-Digest
134+
for int in INTEGERS {
135+
variants.push(TypeSignature::Exact(vec![
136+
num.clone(),
137+
num.clone(),
138+
DataType::Float64,
139+
int.clone(),
140+
]));
141+
}
142+
}
94143
Self {
95-
signature: Signature::one_of(
96-
// Accept any numeric value paired with a float64 percentile
97-
NUMERICS
98-
.iter()
99-
.map(|t| {
100-
TypeSignature::Exact(vec![
101-
t.clone(),
102-
t.clone(),
103-
DataType::Float64,
104-
])
105-
})
106-
.collect(),
107-
Immutable,
108-
),
144+
signature: Signature::one_of(variants, Immutable),
109145
approx_percentile_cont: ApproxPercentileCont::new(),
110146
}
111147
}
@@ -138,6 +174,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
138174
if arg_types[2] != DataType::Float64 {
139175
return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types");
140176
}
177+
if arg_types.len() == 4 && !arg_types[3].is_integer() {
178+
return plan_err!(
179+
"approx_percentile_cont_with_weight requires integer centroids input types"
180+
);
181+
}
141182
Ok(arg_types[0].clone())
142183
}
143184

@@ -148,17 +189,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
148189
);
149190
}
150191

151-
if acc_args.exprs.len() != 3 {
192+
if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
152193
return plan_err!(
153-
"approx_percentile_cont_with_weight requires three arguments: value, weight, percentile"
194+
"approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]"
154195
);
155196
}
156197

157198
let sub_args = AccumulatorArgs {
158-
exprs: &[
159-
Arc::clone(&acc_args.exprs[0]),
160-
Arc::clone(&acc_args.exprs[2]),
161-
],
199+
exprs: if acc_args.exprs.len() == 4 {
200+
&[
201+
Arc::clone(&acc_args.exprs[0]), // value
202+
Arc::clone(&acc_args.exprs[2]), // percentile
203+
Arc::clone(&acc_args.exprs[3]), // centroids
204+
]
205+
} else {
206+
&[
207+
Arc::clone(&acc_args.exprs[0]), // value
208+
Arc::clone(&acc_args.exprs[2]), // percentile
209+
]
210+
},
162211
..acc_args
163212
};
164213
let approx_percentile_cont_accumulator =
@@ -244,7 +293,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
244293
let mut digests: Vec<TDigest> = vec![];
245294
for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
246295
digests.push(TDigest::new_with_centroid(
247-
DEFAULT_MAX_SIZE,
296+
self.approx_percentile_cont_accumulator.max_size(),
248297
Centroid::new(*mean, *weight),
249298
))
250299
}

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,18 @@ async fn roundtrip_expr_api() -> Result<()> {
982982
approx_median(lit(2)),
983983
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None),
984984
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))),
985-
approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)),
985+
approx_percentile_cont_with_weight(
986+
lit(2).sort(true, false),
987+
lit(1),
988+
lit(0.5),
989+
None,
990+
),
991+
approx_percentile_cont_with_weight(
992+
lit(2).sort(true, false),
993+
lit(1),
994+
lit(0.5),
995+
Some(lit(50)),
996+
),
986997
grouping(lit(1)),
987998
bit_and(lit(2)),
988999
bit_or(lit(2)),

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,6 +1840,16 @@ c 123
18401840
d 124
18411841
e 115
18421842

1843+
# approx_percentile_cont_with_weight with centroids
1844+
query TI
1845+
SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
1846+
----
1847+
a 74
1848+
b 68
1849+
c 123
1850+
d 124
1851+
e 115
1852+
18431853
# csv_query_sum_crossjoin
18441854
query TTI
18451855
SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1

docs/source/user-guide/expressions.md

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1);
285285

286286
## Aggregate Functions
287287

288-
| Syntax | Description |
289-
| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
290-
| avg(expr) | Сalculates the average value for `expr`. |
291-
| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. |
292-
| approx_median(expr) | Calculates an approximation of the median for `expr`. |
293-
| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. |
294-
| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. |
295-
| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. |
296-
| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. |
297-
| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. |
298-
| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. |
299-
| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. |
300-
| count(expr) | Returns the number of rows for `expr`. |
301-
| count_distinct | Creates an expression to represent the count(distinct) aggregate function |
302-
| cube(exprs) | Creates a grouping set for all combination of `exprs` |
303-
| grouping_set(exprs) | Create a grouping set. |
304-
| max(expr) | Finds the maximum value of `expr`. |
305-
| median(expr) | Сalculates the median of `expr`. |
306-
| min(expr) | Finds the minimum value of `expr`. |
307-
| rollup(exprs) | Creates a grouping set for rollup sets. |
308-
| sum(expr) | Сalculates the sum of `expr`. |
288+
| Syntax | Description |
289+
| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
290+
| avg(expr) | Сalculates the average value for `expr`. |
291+
| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. |
292+
| approx_median(expr) | Calculates an approximation of the median for `expr`. |
293+
| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). |
294+
| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). |
295+
| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. |
296+
| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. |
297+
| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. |
298+
| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. |
299+
| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. |
300+
| count(expr) | Returns the number of rows for `expr`. |
301+
| count_distinct | Creates an expression to represent the count(distinct) aggregate function |
302+
| cube(exprs) | Creates a grouping set for all combination of `exprs` |
303+
| grouping_set(exprs) | Create a grouping set. |
304+
| max(expr) | Finds the maximum value of `expr`. |
305+
| median(expr) | Сalculates the median of `expr`. |
306+
| min(expr) | Finds the minimum value of `expr`. |
307+
| rollup(exprs) | Creates a grouping set for rollup sets. |
308+
| sum(expr) | Сalculates the sum of `expr`. |
309309

310310
## Aggregate Function Builder
311311

0 commit comments

Comments
 (0)