@@ -95,7 +95,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
9595 aggr_func. params . args . remove ( 1 ) ,
9696 aggr_func. params . args . remove ( 0 ) ,
9797 ) ;
98- let sort = logical_expr:: expr:: Sort :: new ( second_arg, true , false ) ;
98+ let sort = logical_expr:: expr:: Sort :: new ( second_arg, true , true ) ;
9999 order_by. push ( sort) ;
100100 let func = logical_expr:: expr:: Expr :: AggregateFunction (
101101 logical_expr:: expr:: AggregateFunction :: new_udf (
@@ -193,7 +193,7 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
193193 aggr_func. params . args . remove ( 0 ) ,
194194 ) ;
195195
196- let sort = logical_expr:: expr:: Sort :: new ( second_arg, false , false ) ;
196+ let sort = logical_expr:: expr:: Sort :: new ( second_arg, false , true ) ;
197197 order_by. push ( sort) ; // false for ascending sort
198198 let func = logical_expr:: expr:: Expr :: AggregateFunction (
199199 logical_expr:: expr:: AggregateFunction :: new_udf (
@@ -326,7 +326,18 @@ mod tests {
326326 #[ cfg( test) ]
327327 mod max_by {
328328 use super :: * ;
329+ async fn extract_string ( df : prelude:: DataFrame ) -> error:: Result < String > {
330+ let results = df. collect ( ) . await ?;
331+ let col = results[ 0 ] . column ( 0 ) ;
332+ let arr = col. as_any ( ) . downcast_ref :: < arrow:: array:: StringArray > ( ) . unwrap ( ) ;
333+ Ok ( arr. value ( 0 ) . to_string ( ) )
334+ }
329335
336+ fn ctx_max ( ) -> error:: Result < prelude:: SessionContext > {
337+ let ctx = prelude:: SessionContext :: new ( ) ;
338+ ctx. register_udaf ( MaxByFunction :: new ( ) . into ( ) ) ;
339+ Ok ( ctx)
340+ }
330341 #[ tokio:: test]
331342 async fn test_max_by_string_int ( ) -> error:: Result < ( ) > {
332343 let query = format ! (
@@ -339,6 +350,41 @@ mod tests {
339350 Ok ( ( ) )
340351 }
341352
353+ #[ tokio:: test]
354+ async fn test_max_by_ignores_nulls_in_ok ( ) -> error:: Result < ( ) > {
355+ let ctx = ctx_max ( ) ?;
356+ let sql = r#"
357+ SELECT max_by(v, k)
358+ FROM (
359+ VALUES
360+ ('a', 1),
361+ ('b', CAST(NULL AS INT)),
362+ ('c', 2)
363+ ) AS t(v, k)
364+ "# ;
365+ let df = ctx. sql ( sql) . await ?;
366+ let got = extract_string ( df) . await ?;
367+ assert_eq ! ( got, "c" , "max_by should ignore NULLs" ) ;
368+ Ok ( ( ) )
369+ }
370+ #[ tokio:: test]
371+ async fn test_max_by_ignores_nulls_in_ko ( ) -> error:: Result < ( ) > {
372+ let ctx = ctx_max ( ) ?;
373+ let sql = r#"
374+ SELECT max_by(v, k)
375+ FROM (
376+ VALUES
377+ ('a', 1),
378+ ('b', CAST(NULL AS INT)),
379+ ('c', 2)
380+ ) AS t(v, k)
381+ "# ;
382+ let df = ctx. sql ( sql) . await ?;
383+ let got = extract_string ( df) . await ?;
384+ assert_eq ! ( got, "b" , "max_by should ignore NULLs" ) ;
385+ Ok ( ( ) )
386+ }
387+
342388 #[ tokio:: test]
343389 async fn test_max_by_string_float ( ) -> error:: Result < ( ) > {
344390 let query = format ! (
0 commit comments