Skip to content

Commit 50a3b11

Browse files
committed
feat: Support DISTINCT clause in UDAFs
Signed-off-by: Alex Qyoun-ae <[email protected]>
1 parent 6dcbac9 commit 50a3b11

File tree

14 files changed

+104
-31
lines changed

14 files changed

+104
-31
lines changed

datafusion/core/src/logical_plan/expr.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,9 @@ pub fn create_udaf(
237237
return_type: Arc<DataType>,
238238
volatility: Volatility,
239239
accumulator: AccumulatorFunctionImplementation,
240-
state_type: Arc<Vec<DataType>>,
240+
state_type: StateTypeFunction,
241241
) -> AggregateUDF {
242242
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
243-
let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
244243
AggregateUDF::new(
245244
name,
246245
&Signature::exact(vec![input_type], volatility),

datafusion/core/src/logical_plan/expr_rewriter.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,14 @@ impl ExprRewritable for Expr {
281281
))
282282
}
283283
},
284-
Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
285-
args: rewrite_vec(args, rewriter)?,
284+
Expr::AggregateUDF {
285+
args,
286+
fun,
287+
distinct,
288+
} => Expr::AggregateUDF {
286289
fun,
290+
args: rewrite_vec(args, rewriter)?,
291+
distinct,
287292
},
288293
Expr::InList {
289294
expr,

datafusion/core/src/optimizer/common_subexpr_eliminate.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,10 @@ impl ExprIdentifierVisitor<'_> {
501501
desc.push_str(&fun.to_string());
502502
desc.push_str(&distinct.to_string());
503503
}
504-
Expr::AggregateUDF { fun, .. } => {
504+
Expr::AggregateUDF { fun, distinct, .. } => {
505505
desc.push_str("AggregateUDF-");
506506
desc.push_str(&fun.name);
507+
desc.push_str(&distinct.to_string());
507508
}
508509
Expr::InList { negated, .. } => {
509510
desc.push_str("InList-");

datafusion/core/src/optimizer/utils.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
568568
within_group,
569569
})
570570
}
571-
Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF {
571+
Expr::AggregateUDF { fun, distinct, .. } => Ok(Expr::AggregateUDF {
572572
fun: fun.clone(),
573573
args: expressions.to_vec(),
574+
distinct: *distinct,
574575
}),
575576
Expr::GroupingSet(grouping_set) => match grouping_set {
576577
GroupingSet::Rollup(_exprs) => {

datafusion/core/src/physical_plan/planner.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,17 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
217217
args,
218218
within_group.as_ref().map(|exprs| exprs.as_slice()),
219219
),
220-
Expr::AggregateUDF { fun, args } => {
220+
Expr::AggregateUDF {
221+
fun,
222+
args,
223+
distinct,
224+
} => {
221225
let mut names = Vec::with_capacity(args.len());
222226
for e in args {
223227
names.push(create_physical_name(e, false)?);
224228
}
225-
Ok(format!("{}({})", fun.name, names.join(",")))
229+
let distinct_str = if *distinct { "DISTINCT " } else { "" };
230+
Ok(format!("{}({}{})", fun.name, distinct_str, names.join(",")))
226231
}
227232
Expr::GroupingSet(grouping_set) => match grouping_set {
228233
GroupingSet::Rollup(exprs) => Ok(format!(
@@ -1639,7 +1644,12 @@ pub fn create_aggregate_expr_with_name(
16391644
within_group,
16401645
)
16411646
}
1642-
Expr::AggregateUDF { fun, args, .. } => {
1647+
Expr::AggregateUDF {
1648+
fun,
1649+
args,
1650+
distinct,
1651+
..
1652+
} => {
16431653
let args = args
16441654
.iter()
16451655
.map(|e| {
@@ -1652,7 +1662,13 @@ pub fn create_aggregate_expr_with_name(
16521662
})
16531663
.collect::<Result<Vec<_>>>()?;
16541664

1655-
udaf::create_aggregate_expr(fun, &args, physical_input_schema, name)
1665+
udaf::create_aggregate_expr(
1666+
fun,
1667+
&args,
1668+
*distinct,
1669+
physical_input_schema,
1670+
name,
1671+
)
16561672
}
16571673
other => Err(DataFusionError::Internal(format!(
16581674
"Invalid aggregate expression '{:?}'",

datafusion/core/src/physical_plan/udaf.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use std::sync::Arc;
4040
pub fn create_aggregate_expr(
4141
fun: &AggregateUDF,
4242
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
43+
distinct: bool,
4344
input_schema: &Schema,
4445
name: impl Into<String>,
4546
) -> Result<Arc<dyn AggregateExpr>> {
@@ -54,6 +55,7 @@ pub fn create_aggregate_expr(
5455
Ok(Arc::new(AggregateFunctionExpr {
5556
fun: fun.clone(),
5657
args: coerced_phy_exprs.clone(),
58+
distinct,
5759
data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
5860
name: name.into(),
5961
}))
@@ -64,6 +66,7 @@ pub fn create_aggregate_expr(
6466
pub struct AggregateFunctionExpr {
6567
fun: AggregateUDF,
6668
args: Vec<Arc<dyn PhysicalExpr>>,
69+
distinct: bool,
6770
data_type: DataType,
6871
name: String,
6972
}
@@ -79,7 +82,7 @@ impl AggregateExpr for AggregateFunctionExpr {
7982
}
8083

8184
fn state_fields(&self) -> Result<Vec<Field>> {
82-
let fields = (self.fun.state_type)(&self.data_type)?
85+
let fields = (self.fun.state_type)(&self.data_type, self.distinct)?
8386
.iter()
8487
.enumerate()
8588
.map(|(i, data_type)| {
@@ -99,7 +102,7 @@ impl AggregateExpr for AggregateFunctionExpr {
99102
}
100103

101104
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
102-
(self.fun.accumulator)()
105+
(self.fun.accumulator)(self.distinct)
103106
}
104107

105108
fn name(&self) -> &str {

datafusion/core/src/sql/planner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2409,7 +2409,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
24092409
None => match self.schema_provider.get_aggregate_meta(&name) {
24102410
Some(fm) => {
24112411
let args = self.function_args_to_expr(function.args, schema)?;
2412-
Ok(Expr::AggregateUDF { fun: fm, args })
2412+
Ok(Expr::AggregateUDF { fun: fm, args, distinct: function.distinct })
24132413
}
24142414
None => match self.schema_provider.get_table_function_meta(&name) {
24152415
Some(fm) => {

datafusion/core/src/sql/utils.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,17 @@ where
337337
.collect::<Result<Vec<_>>>()?,
338338
window_frame: *window_frame,
339339
}),
340-
Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
340+
Expr::AggregateUDF {
341+
fun,
342+
args,
343+
distinct,
344+
} => Ok(Expr::AggregateUDF {
341345
fun: fun.clone(),
342346
args: args
343347
.iter()
344348
.map(|e| clone_with_replacement(e, replacement_fn))
345349
.collect::<Result<Vec<Expr>>>()?,
350+
distinct: *distinct,
346351
}),
347352
Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias(
348353
Box::new(clone_with_replacement(nested_expr, replacement_fn)?),

datafusion/core/tests/sql/udf.rs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ use super::*;
1919
use arrow::compute::add;
2020
use datafusion::{
2121
logical_plan::{create_udaf, FunctionRegistry, LogicalPlanBuilder},
22-
physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function},
22+
physical_plan::{
23+
expressions::{AvgAccumulator, MaxAccumulator},
24+
functions::make_scalar_function,
25+
},
2326
};
2427

2528
/// test that casting happens on udfs.
@@ -144,15 +147,15 @@ async fn scalar_udf() -> Result<()> {
144147
/// tests the creation, registration and usage of a UDAF
145148
#[tokio::test]
146149
async fn simple_udaf() -> Result<()> {
147-
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
150+
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
148151

149152
let batch1 = RecordBatch::try_new(
150153
Arc::new(schema.clone()),
151-
vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
154+
vec![Arc::new(Float64Array::from_slice([5.0, 10.0, 15.0]))],
152155
)?;
153156
let batch2 = RecordBatch::try_new(
154157
Arc::new(schema.clone()),
155-
vec![Arc::new(Int32Array::from_slice([4, 5]))],
158+
vec![Arc::new(Float64Array::from_slice([10.0, 15.0]))],
156159
)?;
157160

158161
let mut ctx = SessionContext::new();
@@ -166,8 +169,22 @@ async fn simple_udaf() -> Result<()> {
166169
DataType::Float64,
167170
Arc::new(DataType::Float64),
168171
Volatility::Immutable,
169-
Arc::new(|| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
170-
Arc::new(vec![DataType::UInt64, DataType::Float64]),
172+
Arc::new(|distinct| {
173+
if distinct {
174+
// As an example, use the MAX accumulator when DISTINCT is specified
175+
Ok(Box::new(MaxAccumulator::try_new(&DataType::Float64)?))
176+
} else {
177+
Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))
178+
}
179+
}),
180+
Arc::new(|data_type, distinct| {
181+
if distinct {
182+
// When DISTINCT is specified, use the state type of the MAX function
183+
Ok(Arc::new(vec![data_type.clone()]))
184+
} else {
185+
Ok(Arc::new(vec![DataType::UInt64, data_type.clone()]))
186+
}
187+
}),
171188
);
172189

173190
ctx.register_udaf(my_avg);
@@ -178,10 +195,22 @@ async fn simple_udaf() -> Result<()> {
178195
"+-------------+",
179196
"| my_avg(t.a) |",
180197
"+-------------+",
181-
"| 3 |",
198+
"| 11 |",
182199
"+-------------+",
183200
];
184201
assert_batches_eq!(expected, &result);
185202

203+
// Also test DISTINCT: in this case it makes MY_AVG act like the MAX function
204+
let result = plan_and_collect(&ctx, "SELECT MY_AVG(DISTINCT a) FROM t").await?;
205+
206+
let expected = vec![
207+
"+----------------------+",
208+
"| my_avg(DISTINCT t.a) |",
209+
"+----------------------+",
210+
"| 15 |",
211+
"+----------------------+",
212+
];
213+
assert_batches_eq!(expected, &result);
214+
186215
Ok(())
187216
}

datafusion/expr/src/expr.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ pub enum Expr {
245245
fun: Arc<AggregateUDF>,
246246
/// List of expressions to feed to the functions as arguments
247247
args: Vec<Expr>,
248+
/// Whether this is a DISTINCT aggregation or not
249+
distinct: bool,
248250
},
249251
/// Returns whether the list contains the expr value.
250252
InList {
@@ -635,9 +637,12 @@ impl fmt::Debug for Expr {
635637
}
636638
Ok(())
637639
}
638-
Expr::AggregateUDF { fun, ref args, .. } => {
639-
fmt_function(f, &fun.name, false, args, false)
640-
}
640+
Expr::AggregateUDF {
641+
fun,
642+
ref args,
643+
distinct,
644+
..
645+
} => fmt_function(f, &fun.name, *distinct, args, false),
641646
Expr::Between {
642647
expr,
643648
negated,
@@ -998,12 +1003,17 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
9981003
}
9991004
Ok(parts.join(" "))
10001005
}
1001-
Expr::AggregateUDF { fun, args } => {
1006+
Expr::AggregateUDF {
1007+
fun,
1008+
args,
1009+
distinct,
1010+
} => {
10021011
let mut names = Vec::with_capacity(args.len());
10031012
for e in args {
10041013
names.push(create_name(e, input_schema)?);
10051014
}
1006-
Ok(format!("{}({})", fun.name, names.join(",")))
1015+
let distinct_str = if *distinct { "DISTINCT " } else { "" };
1016+
Ok(format!("{}({}{})", fun.name, distinct_str, names.join(",")))
10071017
}
10081018
Expr::GroupingSet(grouping_set) => match grouping_set {
10091019
GroupingSet::Rollup(exprs) => Ok(format!(

0 commit comments

Comments
 (0)