Skip to content

Commit 22a1eab

Browse files
authored
Introduce avg_distinct() and sum_distinct() functions to DataFrame API (#17536)
* Introduce `avg_distinct()` and `sum_distinct()` functions to DataFrame API * Add to roundtrip proto tests
1 parent 46a47a9 commit 22a1eab

File tree

6 files changed

+59
-23
lines changed

6 files changed

+59
-23
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ use arrow_schema::{SortOptions, TimeUnit};
3535
use datafusion::{assert_batches_eq, dataframe};
3636
use datafusion_functions_aggregate::count::{count_all, count_all_window};
3737
use datafusion_functions_aggregate::expr_fn::{
38-
array_agg, avg, count, count_distinct, max, median, min, sum,
38+
array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum,
39+
sum_distinct,
3940
};
4041
use datafusion_functions_nested::make_array::make_array_udf;
4142
use datafusion_functions_window::expr_fn::{first_value, row_number};
@@ -502,32 +503,35 @@ async fn drop_with_periods() -> Result<()> {
502503
#[tokio::test]
503504
async fn aggregate() -> Result<()> {
504505
// build plan using DataFrame API
505-
let df = test_table().await?;
506+
// union so some of the distincts have a clearly distinct result
507+
let df = test_table().await?.union(test_table().await?)?;
506508
let group_expr = vec![col("c1")];
507509
let aggr_expr = vec![
508-
min(col("c12")),
509-
max(col("c12")),
510-
avg(col("c12")),
511-
sum(col("c12")),
512-
count(col("c12")),
513-
count_distinct(col("c12")),
510+
min(col("c4")).alias("min(c4)"),
511+
max(col("c4")).alias("max(c4)"),
512+
avg(col("c4")).alias("avg(c4)"),
513+
avg_distinct(col("c4")).alias("avg_distinct(c4)"),
514+
sum(col("c4")).alias("sum(c4)"),
515+
sum_distinct(col("c4")).alias("sum_distinct(c4)"),
516+
count(col("c4")).alias("count(c4)"),
517+
count_distinct(col("c4")).alias("count_distinct(c4)"),
514518
];
515519

516520
let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
517521

518522
assert_snapshot!(
519523
batches_to_sort_string(&df),
520-
@r###"
521-
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
522-
| c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |
523-
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
524-
| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |
525-
| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |
526-
| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |
527-
| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |
528-
| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |
529-
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
530-
"###
524+
@r"
525+
+----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+
526+
| c1 | min(c4) | max(c4) | avg(c4) | avg_distinct(c4) | sum(c4) | sum_distinct(c4) | count(c4) | count_distinct(c4) |
527+
+----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+
528+
| a | -28462 | 32064 | 306.04761904761904 | 306.04761904761904 | 12854 | 6427 | 42 | 21 |
529+
| b | -28070 | 25286 | 7732.315789473684 | 7732.315789473684 | 293828 | 146914 | 38 | 19 |
530+
| c | -30508 | 29106 | -1320.5238095238096 | -1320.5238095238096 | -55462 | -27731 | 42 | 21 |
531+
| d | -24558 | 31106 | 10890.111111111111 | 10890.111111111111 | 392044 | 196022 | 36 | 18 |
532+
| e | -31500 | 32514 | -4268.333333333333 | -4268.333333333333 | -179270 | -89635 | 42 | 21 |
533+
+----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+
534+
"
531535
);
532536

533537
Ok(())
@@ -542,7 +546,9 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> {
542546
min(col("c12")),
543547
max(col("c12")),
544548
avg(col("c12")),
549+
avg_distinct(col("c12")),
545550
sum(col("c12")),
551+
sum_distinct(col("c12")),
546552
count(col("c12")),
547553
count_distinct(col("c12")),
548554
median(col("c12")),

datafusion/functions-aggregate/src/average.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_typ
3838
use datafusion_expr::utils::format_state_name;
3939
use datafusion_expr::Volatility::Immutable;
4040
use datafusion_expr::{
41-
Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
41+
Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
4242
ReversedUDAF, Signature,
4343
};
4444

@@ -66,6 +66,17 @@ make_udaf_expr_and_func!(
6666
avg_udaf
6767
);
6868

69+
pub fn avg_distinct(expr: Expr) -> Expr {
70+
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
71+
avg_udaf(),
72+
vec![expr],
73+
true,
74+
None,
75+
vec![],
76+
None,
77+
))
78+
}
79+
6980
#[user_doc(
7081
doc_section(label = "General Functions"),
7182
description = "Returns the average of numeric values in the specified column.",

datafusion/functions-aggregate/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ pub mod expr_fn {
105105
pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight;
106106
pub use super::array_agg::array_agg;
107107
pub use super::average::avg;
108+
pub use super::average::avg_distinct;
108109
pub use super::bit_and_or_xor::bit_and;
109110
pub use super::bit_and_or_xor::bit_or;
110111
pub use super::bit_and_or_xor::bit_xor;
@@ -134,6 +135,7 @@ pub mod expr_fn {
134135
pub use super::stddev::stddev;
135136
pub use super::stddev::stddev_pop;
136137
pub use super::sum::sum;
138+
pub use super::sum::sum_distinct;
137139
pub use super::variance::var_pop;
138140
pub use super::variance::var_sample;
139141
}

datafusion/functions-aggregate/src/sum.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use ahash::RandomState;
2121
use arrow::datatypes::DECIMAL32_MAX_PRECISION;
2222
use arrow::datatypes::DECIMAL64_MAX_PRECISION;
2323
use datafusion_expr::utils::AggregateOrderSensitivity;
24+
use datafusion_expr::Expr;
2425
use std::any::Any;
2526
use std::mem::size_of_val;
2627

@@ -55,6 +56,17 @@ make_udaf_expr_and_func!(
5556
sum_udaf
5657
);
5758

59+
pub fn sum_distinct(expr: Expr) -> Expr {
60+
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
61+
sum_udaf(),
62+
vec![expr],
63+
true,
64+
None,
65+
vec![],
66+
None,
67+
))
68+
}
69+
5870
/// Sum only supports a subset of numeric types, instead relying on type coercion
5971
///
6072
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion::execution::options::ArrowReadOptions;
3232
use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion;
3333
use datafusion::optimizer::Optimizer;
3434
use datafusion_common::parsers::CompressionTypeVariant;
35+
use datafusion_functions_aggregate::sum::sum_distinct;
3536
use prost::Message;
3637
use std::any::Any;
3738
use std::collections::HashMap;
@@ -82,8 +83,8 @@ use datafusion_expr::{
8283
};
8384
use datafusion_functions_aggregate::average::avg_udaf;
8485
use datafusion_functions_aggregate::expr_fn::{
85-
approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr,
86-
nth_value,
86+
approx_distinct, array_agg, avg, avg_distinct, bit_and, bit_or, bit_xor, bool_and,
87+
bool_or, corr, nth_value,
8788
};
8889
use datafusion_functions_aggregate::string_agg::string_agg;
8990
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
@@ -967,10 +968,12 @@ async fn roundtrip_expr_api() -> Result<()> {
967968
functions_window::nth_value::last_value(lit(1)),
968969
functions_window::nth_value::nth_value(lit(1), 1),
969970
avg(lit(1.5)),
971+
avg_distinct(lit(1.5)),
970972
covar_samp(lit(1.5), lit(2.2)),
971973
covar_pop(lit(1.5), lit(2.2)),
972974
corr(lit(1.5), lit(2.2)),
973975
sum(lit(1)),
976+
sum_distinct(lit(1)),
974977
max(lit(1)),
975978
median(lit(2)),
976979
min(lit(2)),

docs/source/user-guide/expressions.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ select log(-1), log(0), sqrt(-1);
288288
| Syntax | Description |
289289
| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
290290
| avg(expr) | Сalculates the average value for `expr`. |
291+
| avg_distinct(expr) | Creates an expression to represent the avg(distinct) aggregate function |
291292
| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. |
292293
| approx_median(expr) | Calculates an approximation of the median for `expr`. |
293294
| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). |
@@ -298,14 +299,15 @@ select log(-1), log(0), sqrt(-1);
298299
| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. |
299300
| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. |
300301
| count(expr) | Returns the number of rows for `expr`. |
301-
| count_distinct | Creates an expression to represent the count(distinct) aggregate function |
302+
| count_distinct(expr) | Creates an expression to represent the count(distinct) aggregate function |
302303
| cube(exprs) | Creates a grouping set for all combination of `exprs` |
303304
| grouping_set(exprs) | Create a grouping set. |
304305
| max(expr) | Finds the maximum value of `expr`. |
305306
| median(expr) | Сalculates the median of `expr`. |
306307
| min(expr) | Finds the minimum value of `expr`. |
307308
| rollup(exprs) | Creates a grouping set for rollup sets. |
308309
| sum(expr) | Сalculates the sum of `expr`. |
310+
| sum_distinct(expr) | Creates an expression to represent the sum(distinct) aggregate function |
309311

310312
## Aggregate Function Builder
311313

0 commit comments

Comments
 (0)