Skip to content

Commit 56a2348

Browse files
committed
feat: enhance min/max by functions
1 parent e244b18 commit 56a2348

File tree

1 file changed

+37
-86
lines changed

1 file changed

+37
-86
lines changed

src/max_min_by.rs

Lines changed: 37 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ make_udaf_expr_and_func!(
1313

1414
#[derive(Eq, Hash, PartialEq)]
1515
pub struct MaxByFunction {
16+
null_first: bool,
1617
signature: logical_expr::Signature,
1718
}
1819

@@ -27,13 +28,14 @@ impl fmt::Debug for MaxByFunction {
2728
}
2829
impl Default for MaxByFunction {
2930
fn default() -> Self {
30-
Self::new()
31+
Self::new(true)
3132
}
3233
}
3334

3435
impl MaxByFunction {
35-
pub fn new() -> Self {
36+
pub fn new(null_first: bool) -> Self {
3637
Self {
38+
null_first,
3739
signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
3840
}
3941
}
@@ -80,6 +82,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8082
) -> error::Result<Box<dyn logical_expr::Accumulator>> {
8183
common::exec_err!("should not reach here")
8284
}
85+
8386
fn coerce_types(
8487
&self,
8588
arg_types: &[arrow::datatypes::DataType],
@@ -88,25 +91,25 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8891
}
8992

9093
fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
91-
let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
92-
_: &dyn logical_expr::simplify::SimplifyInfo| {
94+
let null_first = self.null_first;
95+
let simplify = move |mut aggr_func: logical_expr::expr::AggregateFunction,
96+
_: &dyn logical_expr::simplify::SimplifyInfo| {
9397
let mut order_by = aggr_func.params.order_by;
9498
let (second_arg, first_arg) = (
9599
aggr_func.params.args.remove(1),
96100
aggr_func.params.args.remove(0),
97101
);
98-
let sort = logical_expr::expr::Sort::new(second_arg, true, true);
102+
let sort = logical_expr::expr::Sort::new(second_arg, true, null_first);
99103
order_by.push(sort);
100-
let func = logical_expr::expr::Expr::AggregateFunction(
101-
logical_expr::expr::AggregateFunction::new_udf(
102-
functions_aggregate::first_last::last_value_udaf(),
103-
vec![first_arg],
104-
aggr_func.params.distinct,
105-
aggr_func.params.filter,
106-
order_by,
107-
aggr_func.params.null_treatment,
108-
),
104+
let func = logical_expr::expr::AggregateFunction::new_udf(
105+
functions_aggregate::first_last::last_value_udaf(),
106+
vec![first_arg],
107+
aggr_func.params.distinct,
108+
aggr_func.params.filter,
109+
order_by,
110+
aggr_func.params.null_treatment,
109111
);
112+
let func = logical_expr::expr::Expr::AggregateFunction(func);
110113
Ok(func)
111114
};
112115
Some(Box::new(simplify))
@@ -123,6 +126,7 @@ make_udaf_expr_and_func!(
123126

124127
#[derive(Eq, Hash, PartialEq)]
125128
pub struct MinByFunction {
129+
null_first: bool,
126130
signature: logical_expr::Signature,
127131
}
128132

@@ -138,13 +142,14 @@ impl fmt::Debug for MinByFunction {
138142

139143
impl Default for MinByFunction {
140144
fn default() -> Self {
141-
Self::new()
145+
Self::new(true)
142146
}
143147
}
144148

145149
impl MinByFunction {
146-
pub fn new() -> Self {
150+
pub fn new(null_first: bool) -> Self {
147151
Self {
152+
null_first,
148153
signature: logical_expr::Signature::user_defined(logical_expr::Volatility::Immutable),
149154
}
150155
}
@@ -185,26 +190,26 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
185190
}
186191

187192
fn simplify(&self) -> Option<logical_expr::function::AggregateFunctionSimplification> {
188-
let simplify = |mut aggr_func: logical_expr::expr::AggregateFunction,
189-
_: &dyn logical_expr::simplify::SimplifyInfo| {
193+
let null_first = self.null_first;
194+
let simplify = move |mut aggr_func: logical_expr::expr::AggregateFunction,
195+
_: &dyn logical_expr::simplify::SimplifyInfo| {
190196
let mut order_by = aggr_func.params.order_by;
191197
let (second_arg, first_arg) = (
192198
aggr_func.params.args.remove(1),
193199
aggr_func.params.args.remove(0),
194200
);
195201

196-
let sort = logical_expr::expr::Sort::new(second_arg, false, true);
202+
let sort = logical_expr::expr::Sort::new(second_arg, false, null_first);
197203
order_by.push(sort); // asc = false: descending sort, so last_value yields the value at the minimum key
198-
let func = logical_expr::expr::Expr::AggregateFunction(
199-
logical_expr::expr::AggregateFunction::new_udf(
200-
functions_aggregate::first_last::last_value_udaf(),
201-
vec![first_arg],
202-
aggr_func.params.distinct,
203-
aggr_func.params.filter,
204-
order_by,
205-
aggr_func.params.null_treatment,
206-
),
204+
let func = logical_expr::expr::AggregateFunction::new_udf(
205+
functions_aggregate::first_last::last_value_udaf(),
206+
vec![first_arg],
207+
aggr_func.params.distinct,
208+
aggr_func.params.filter,
209+
order_by,
210+
aggr_func.params.null_treatment,
207211
);
212+
let func = logical_expr::expr::Expr::AggregateFunction(func);
208213
Ok(func)
209214
};
210215
Some(Box::new(simplify))
@@ -399,33 +404,15 @@ mod tests {
399404
('c', 2)
400405
) AS t(v, k)
401406
"#;
402-
let df = ctx()?.sql(&query).await?;
407+
let df = ctx()?.sql(query).await?;
403408
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
404409
assert_eq!(result, "c", "max_by should ignore NULLs");
405410
Ok(())
406411
}
407412

408-
#[tokio::test]
409-
async fn test_max_like_main_test() -> error::Result<()> {
410-
let query = r#"
411-
SELECT max_by(v, k)
412-
FROM (
413-
VALUES
414-
(1, 10),
415-
(2, 5),
416-
(3, 15),
417-
(4, 8)
418-
) AS t(v, k)
419-
"#;
420-
let df = ctx()?.sql(&query).await?;
421-
let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
422-
assert_eq!(result, 3);
423-
Ok(())
424-
}
425-
426413
fn ctx() -> error::Result<prelude::SessionContext> {
427414
let ctx = test_ctx()?;
428-
let max_by_udaf = MaxByFunction::new();
415+
let max_by_udaf = MaxByFunction::default();
429416
ctx.register_udaf(max_by_udaf.into());
430417
Ok(ctx)
431418
}
@@ -507,51 +494,15 @@ mod tests {
507494
('c', 2)
508495
) AS t(v, k)
509496
"#;
510-
let df = ctx()?.sql(&query).await?;
497+
let df = ctx()?.sql(query).await?;
511498
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
512499
assert_eq!(result, "a", "min_by should ignore NULLs");
513500
Ok(())
514501
}
515502

516-
#[tokio::test]
517-
async fn test_min_like_main_test_str() -> error::Result<()> {
518-
let query = r#"
519-
SELECT min_by(v, k)
520-
FROM (
521-
VALUES
522-
('a', 10),
523-
('b', 5),
524-
('c', 15),
525-
('d', 8)
526-
) AS t(v, k)
527-
"#;
528-
let df = ctx()?.sql(&query).await?;
529-
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
530-
assert_eq!(result, "b");
531-
Ok(())
532-
}
533-
534-
#[tokio::test]
535-
async fn test_min_like_main_test_int() -> error::Result<()> {
536-
let query = r#"
537-
SELECT min_by(v, k)
538-
FROM (
539-
VALUES
540-
(1, 10),
541-
(2, 5),
542-
(3, 15),
543-
(4, 8)
544-
) AS t(v, k)
545-
"#;
546-
let df = ctx()?.sql(&query).await?;
547-
let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
548-
assert_eq!(result, 2);
549-
Ok(())
550-
}
551-
552503
fn ctx() -> error::Result<prelude::SessionContext> {
553504
let ctx = test_ctx()?;
554-
let min_by_udaf = MinByFunction::new();
505+
let min_by_udaf = MinByFunction::default();
555506
ctx.register_udaf(min_by_udaf.into());
556507
Ok(ctx)
557508
}

0 commit comments

Comments
 (0)