Skip to content

Commit 88e32b3

Browse files
committed
fix: max_by min_by
1 parent 92edc25 commit 88e32b3

File tree

1 file changed

+48
-2
lines changed

1 file changed

+48
-2
lines changed

src/max_min_by.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
9595
aggr_func.params.args.remove(1),
9696
aggr_func.params.args.remove(0),
9797
);
98-
let sort = logical_expr::expr::Sort::new(second_arg, true, false);
98+
let sort = logical_expr::expr::Sort::new(second_arg, true, true);
9999
order_by.push(sort);
100100
let func = logical_expr::expr::Expr::AggregateFunction(
101101
logical_expr::expr::AggregateFunction::new_udf(
@@ -193,7 +193,7 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
193193
aggr_func.params.args.remove(0),
194194
);
195195

196-
let sort = logical_expr::expr::Sort::new(second_arg, false, false);
196+
let sort = logical_expr::expr::Sort::new(second_arg, false, true);
197197
order_by.push(sort); // false for ascending sort
198198
let func = logical_expr::expr::Expr::AggregateFunction(
199199
logical_expr::expr::AggregateFunction::new_udf(
@@ -326,7 +326,18 @@ mod tests {
326326
#[cfg(test)]
327327
mod max_by {
328328
use super::*;
329+
async fn extract_string(df: prelude::DataFrame) -> error::Result<String> {
330+
let results = df.collect().await?;
331+
let col = results[0].column(0);
332+
let arr = col.as_any().downcast_ref::<arrow::array::StringArray>().unwrap();
333+
Ok(arr.value(0).to_string())
334+
}
329335

336+
fn ctx_max() -> error::Result<prelude::SessionContext> {
337+
let ctx = prelude::SessionContext::new();
338+
ctx.register_udaf(MaxByFunction::new().into());
339+
Ok(ctx)
340+
}
330341
#[tokio::test]
331342
async fn test_max_by_string_int() -> error::Result<()> {
332343
let query = format!(
@@ -339,6 +350,41 @@ mod tests {
339350
Ok(())
340351
}
341352

353+
#[tokio::test]
354+
async fn test_max_by_ignores_nulls_in_ok() -> error::Result<()> {
355+
let ctx = ctx_max()?;
356+
let sql = r#"
357+
SELECT max_by(v, k)
358+
FROM (
359+
VALUES
360+
('a', 1),
361+
('b', CAST(NULL AS INT)),
362+
('c', 2)
363+
) AS t(v, k)
364+
"#;
365+
let df = ctx.sql(sql).await?;
366+
let got = extract_string(df).await?;
367+
assert_eq!(got, "c", "max_by should ignore NULLs");
368+
Ok(())
369+
}
370+
#[tokio::test]
371+
async fn test_max_by_ignores_nulls_in_ko() -> error::Result<()> {
372+
let ctx = ctx_max()?;
373+
let sql = r#"
374+
SELECT max_by(v, k)
375+
FROM (
376+
VALUES
377+
('a', 1),
378+
('b', CAST(NULL AS INT)),
379+
('c', 2)
380+
) AS t(v, k)
381+
"#;
382+
let df = ctx.sql(sql).await?;
383+
let got = extract_string(df).await?;
384+
assert_eq!(got, "b", "max_by should ignore NULLs");
385+
Ok(())
386+
}
387+
342388
#[tokio::test]
343389
async fn test_max_by_string_float() -> error::Result<()> {
344390
let query = format!(

0 commit comments

Comments
 (0)