Skip to content

Commit 86bbfc1

Browse files
davidlghellindariocurr
authored andcommitted
feat: enhance min/max by functions
1 parent c385088 commit 86bbfc1

File tree

3 files changed

+73
-33
lines changed

3 files changed

+73
-33
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
[package]
1919
name = "datafusion-extra-functions"
20-
version = "0.4.0"
20+
version = "0.5.0"
2121
edition = "2024"
2222
description = "Extra Functions for DataFusion"
2323
readme = "README.md"

src/max_min_by.rs

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ make_udaf_expr_and_func!(
1313

1414
#[derive(Eq, Hash, PartialEq)]
1515
pub struct MaxByFunction {
16+
null_first: bool,
1617
signature: logical_expr::Signature,
1718
}
1819

@@ -27,13 +28,14 @@ impl fmt::Debug for MaxByFunction {
2728
}
2829
impl Default for MaxByFunction {
2930
fn default() -> Self {
30-
Self::new()
31+
Self::new(true)
3132
}
3233
}
3334

3435
impl MaxByFunction {
35-
pub fn new() -> Self {
36+
pub fn new(null_first: bool) -> Self {
3637
Self {
38+
null_first,
3739
signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
3840
}
3941
}
@@ -80,6 +82,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8082
) -> error::Result<Box<dyn logical_expr::Accumulator>> {
8183
common::exec_err!("should not reach here")
8284
}
85+
8386
fn coerce_types(
8487
&self,
8588
arg_types: &[arrow::datatypes::DataType],
@@ -88,25 +91,25 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8891
}
8992

9093
fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
91-
let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
92-
_: &dyn logical_expr::simplify::SimplifyInfo| {
94+
let null_first = self.null_first;
95+
let simplify = move |mut aggr_func: logical_expr::expr::AggregateFunction,
96+
_: &dyn logical_expr::simplify::SimplifyInfo| {
9397
let mut order_by = aggr_func.params.order_by;
9498
let (second_arg, first_arg) = (
9599
aggr_func.params.args.remove(1),
96100
aggr_func.params.args.remove(0),
97101
);
98-
let sort = logical_expr::expr::Sort::new(second_arg, true, false);
102+
let sort = logical_expr::expr::Sort::new(second_arg, true, null_first);
99103
order_by.push(sort);
100-
let func = logical_expr::expr::Expr::AggregateFunction(
101-
logical_expr::expr::AggregateFunction::new_udf(
102-
functions_aggregate::first_last::last_value_udaf(),
103-
vec![first_arg],
104-
aggr_func.params.distinct,
105-
aggr_func.params.filter,
106-
order_by,
107-
aggr_func.params.null_treatment,
108-
),
104+
let func = logical_expr::expr::AggregateFunction::new_udf(
105+
functions_aggregate::first_last::last_value_udaf(),
106+
vec![first_arg],
107+
aggr_func.params.distinct,
108+
aggr_func.params.filter,
109+
order_by,
110+
aggr_func.params.null_treatment,
109111
);
112+
let func = logical_expr::expr::Expr::AggregateFunction(func);
110113
Ok(func)
111114
};
112115
Some(Box::new(simplify))
@@ -123,6 +126,7 @@ make_udaf_expr_and_func!(
123126

124127
#[derive(Eq, Hash, PartialEq)]
125128
pub struct MinByFunction {
129+
null_first: bool,
126130
signature: logical_expr::Signature,
127131
}
128132

@@ -138,13 +142,14 @@ impl fmt::Debug for MinByFunction {
138142

139143
impl Default for MinByFunction {
140144
fn default() -> Self {
141-
Self::new()
145+
Self::new(true)
142146
}
143147
}
144148

145149
impl MinByFunction {
146-
pub fn new() -> Self {
150+
pub fn new(null_first: bool) -> Self {
147151
Self {
152+
null_first,
148153
signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
149154
}
150155
}
@@ -185,26 +190,26 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
185190
}
186191

187192
fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
188-
let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
189-
_: &dyn logical_expr::simplify::SimplifyInfo| {
193+
let null_first = self.null_first;
194+
let simplify = move |mut aggr_func: logical_expr::expr::AggregateFunction,
195+
_: &dyn logical_expr::simplify::SimplifyInfo| {
190196
let mut order_by = aggr_func.params.order_by;
191197
let (second_arg, first_arg) = (
192198
aggr_func.params.args.remove(1),
193199
aggr_func.params.args.remove(0),
194200
);
195201

196-
let sort = logical_expr::expr::Sort::new(second_arg, false, false);
202+
let sort = logical_expr::expr::Sort::new(second_arg, false, null_first);
197203
order_by.push(sort); // false for ascending sort
198-
let func = logical_expr::expr::Expr::AggregateFunction(
199-
logical_expr::expr::AggregateFunction::new_udf(
200-
functions_aggregate::first_last::last_value_udaf(),
201-
vec![first_arg],
202-
aggr_func.params.distinct,
203-
aggr_func.params.filter,
204-
order_by,
205-
aggr_func.params.null_treatment,
206-
),
204+
let func = logical_expr::expr::AggregateFunction::new_udf(
205+
functions_aggregate::first_last::last_value_udaf(),
206+
vec![first_arg],
207+
aggr_func.params.distinct,
208+
aggr_func.params.filter,
209+
order_by,
210+
aggr_func.params.null_treatment,
207211
);
212+
let func = logical_expr::expr::Expr::AggregateFunction(func);
208213
Ok(func)
209214
};
210215
Some(Box::new(simplify))
@@ -325,6 +330,7 @@ mod tests {
325330

326331
#[cfg(test)]
327332
mod max_by {
333+
328334
use super::*;
329335

330336
#[tokio::test]
@@ -387,9 +393,26 @@ mod tests {
387393
Ok(())
388394
}
389395

396+
#[tokio::test]
397+
async fn test_max_by_ignores_nulls() -> error::Result<()> {
398+
let query = r#"
399+
SELECT max_by(v, k)
400+
FROM (
401+
VALUES
402+
('a', 1),
403+
('b', CAST(NULL AS INT)),
404+
('c', 2)
405+
) AS t(v, k)
406+
"#;
407+
let df = ctx()?.sql(query).await?;
408+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
409+
assert_eq!(result, "c", "max_by should ignore NULLs");
410+
Ok(())
411+
}
412+
390413
fn ctx() -> error::Result<prelude::SessionContext> {
391414
let ctx = test_ctx()?;
392-
let max_by_udaf = MaxByFunction::new();
415+
let max_by_udaf = MaxByFunction::default();
393416
ctx.register_udaf(max_by_udaf.into());
394417
Ok(ctx)
395418
}
@@ -460,9 +483,26 @@ mod tests {
460483
Ok(())
461484
}
462485

486+
#[tokio::test]
487+
async fn test_min_by_ignores_nulls() -> error::Result<()> {
488+
let query = r#"
489+
SELECT min_by(v, k)
490+
FROM (
491+
VALUES
492+
('a', 1),
493+
('b', CAST(NULL AS INT)),
494+
('c', 2)
495+
) AS t(v, k)
496+
"#;
497+
let df = ctx()?.sql(query).await?;
498+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
499+
assert_eq!(result, "a", "min_by should ignore NULLs");
500+
Ok(())
501+
}
502+
463503
fn ctx() -> error::Result<prelude::SessionContext> {
464504
let ctx = test_ctx()?;
465-
let min_by_udaf = MinByFunction::new();
505+
let min_by_udaf = MinByFunction::default();
466506
ctx.register_udaf(min_by_udaf.into());
467507
Ok(ctx)
468508
}

tests/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ async fn test_max_by_and_min_by() {
185185
- +---------------------+
186186
- "| max_by(tab.x,tab.y) |"
187187
- +---------------------+
188-
- "| 2 |"
188+
- "| 3 |"
189189
- +---------------------+
190190
"###);
191191

@@ -200,7 +200,7 @@ async fn test_max_by_and_min_by() {
200200
- +---------------------+
201201
- "| min_by(tab.x,tab.y) |"
202202
- +---------------------+
203-
- "| 2 |"
203+
- "| |"
204204
- +---------------------+
205205
"###);
206206

0 commit comments

Comments
 (0)