@@ -13,6 +13,7 @@ make_udaf_expr_and_func!(
1313
1414#[ derive( Eq , Hash , PartialEq ) ]
1515pub struct MaxByFunction {
16+ null_first : bool ,
1617 signature : logical_expr:: Signature ,
1718}
1819
@@ -27,13 +28,14 @@ impl fmt::Debug for MaxByFunction {
2728}
2829impl Default for MaxByFunction {
2930 fn default ( ) -> Self {
30- Self :: new ( )
31+ Self :: new ( true )
3132 }
3233}
3334
3435impl MaxByFunction {
35- pub fn new ( ) -> Self {
36+ pub fn new ( null_first : bool ) -> Self {
3637 Self {
38+ null_first,
3739 signature : logical_expr:: Signature :: user_defined ( logical_expr:: Volatility :: Immutable ) ,
3840 }
3941 }
@@ -80,6 +82,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8082 ) -> error:: Result < Box < dyn logical_expr:: Accumulator > > {
8183 common:: exec_err!( "should not reach here" )
8284 }
85+
8386 fn coerce_types (
8487 & self ,
8588 arg_types : & [ arrow:: datatypes:: DataType ] ,
@@ -88,25 +91,25 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8891 }
8992
9093 fn simplify ( & self ) -> Option < logical_expr:: function:: AggregateFunctionSimplification > {
91- let simplify = |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
92- _: & dyn logical_expr:: simplify:: SimplifyInfo | {
94+ let null_first = self . null_first ;
95+ let simplify = move |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
96+ _: & dyn logical_expr:: simplify:: SimplifyInfo | {
9397 let mut order_by = aggr_func. params . order_by ;
9498 let ( second_arg, first_arg) = (
9599 aggr_func. params . args . remove ( 1 ) ,
96100 aggr_func. params . args . remove ( 0 ) ,
97101 ) ;
98- let sort = logical_expr:: expr:: Sort :: new ( second_arg, true , false ) ;
102+ let sort = logical_expr:: expr:: Sort :: new ( second_arg, true , null_first ) ;
99103 order_by. push ( sort) ;
100- let func = logical_expr:: expr:: Expr :: AggregateFunction (
101- logical_expr:: expr:: AggregateFunction :: new_udf (
102- functions_aggregate:: first_last:: last_value_udaf ( ) ,
103- vec ! [ first_arg] ,
104- aggr_func. params . distinct ,
105- aggr_func. params . filter ,
106- order_by,
107- aggr_func. params . null_treatment ,
108- ) ,
104+ let func = logical_expr:: expr:: AggregateFunction :: new_udf (
105+ functions_aggregate:: first_last:: last_value_udaf ( ) ,
106+ vec ! [ first_arg] ,
107+ aggr_func. params . distinct ,
108+ aggr_func. params . filter ,
109+ order_by,
110+ aggr_func. params . null_treatment ,
109111 ) ;
112+ let func = logical_expr:: expr:: Expr :: AggregateFunction ( func) ;
110113 Ok ( func)
111114 } ;
112115 Some ( Box :: new ( simplify) )
@@ -123,6 +126,7 @@ make_udaf_expr_and_func!(
123126
124127#[ derive( Eq , Hash , PartialEq ) ]
125128pub struct MinByFunction {
129+ null_first : bool ,
126130 signature : logical_expr:: Signature ,
127131}
128132
@@ -138,13 +142,14 @@ impl fmt::Debug for MinByFunction {
138142
139143impl Default for MinByFunction {
140144 fn default ( ) -> Self {
141- Self :: new ( )
145+ Self :: new ( true )
142146 }
143147}
144148
145149impl MinByFunction {
146- pub fn new ( ) -> Self {
150+ pub fn new ( null_first : bool ) -> Self {
147151 Self {
152+ null_first,
148153 signature : logical_expr:: Signature :: user_defined ( logical_expr:: Volatility :: Immutable ) ,
149154 }
150155 }
@@ -185,26 +190,26 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
185190 }
186191
187192 fn simplify ( & self ) -> Option < logical_expr:: function:: AggregateFunctionSimplification > {
188- let simplify = |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
189- _: & dyn logical_expr:: simplify:: SimplifyInfo | {
193+ let null_first = self . null_first ;
194+ let simplify = move |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
195+ _: & dyn logical_expr:: simplify:: SimplifyInfo | {
190196 let mut order_by = aggr_func. params . order_by ;
191197 let ( second_arg, first_arg) = (
192198 aggr_func. params . args . remove ( 1 ) ,
193199 aggr_func. params . args . remove ( 0 ) ,
194200 ) ;
195201
196- let sort = logical_expr:: expr:: Sort :: new ( second_arg, false , false ) ;
202+ let sort = logical_expr:: expr:: Sort :: new ( second_arg, false , null_first ) ;
197203 order_by. push ( sort) ; // false for ascending sort
198- let func = logical_expr:: expr:: Expr :: AggregateFunction (
199- logical_expr:: expr:: AggregateFunction :: new_udf (
200- functions_aggregate:: first_last:: last_value_udaf ( ) ,
201- vec ! [ first_arg] ,
202- aggr_func. params . distinct ,
203- aggr_func. params . filter ,
204- order_by,
205- aggr_func. params . null_treatment ,
206- ) ,
204+ let func = logical_expr:: expr:: AggregateFunction :: new_udf (
205+ functions_aggregate:: first_last:: last_value_udaf ( ) ,
206+ vec ! [ first_arg] ,
207+ aggr_func. params . distinct ,
208+ aggr_func. params . filter ,
209+ order_by,
210+ aggr_func. params . null_treatment ,
207211 ) ;
212+ let func = logical_expr:: expr:: Expr :: AggregateFunction ( func) ;
208213 Ok ( func)
209214 } ;
210215 Some ( Box :: new ( simplify) )
@@ -325,6 +330,7 @@ mod tests {
325330
326331 #[ cfg( test) ]
327332 mod max_by {
333+
328334 use super :: * ;
329335
330336 #[ tokio:: test]
@@ -387,9 +393,26 @@ mod tests {
387393 Ok ( ( ) )
388394 }
389395
396+ #[ tokio:: test]
397+ async fn test_max_by_ignores_nulls ( ) -> error:: Result < ( ) > {
398+ let query = r#"
399+ SELECT max_by(v, k)
400+ FROM (
401+ VALUES
402+ ('a', 1),
403+ ('b', CAST(NULL AS INT)),
404+ ('c', 2)
405+ ) AS t(v, k)
406+ "# ;
407+ let df = ctx ( ) ?. sql ( query) . await ?;
408+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
409+ assert_eq ! ( result, "c" , "max_by should ignore NULLs" ) ;
410+ Ok ( ( ) )
411+ }
412+
390413 fn ctx ( ) -> error:: Result < prelude:: SessionContext > {
391414 let ctx = test_ctx ( ) ?;
392- let max_by_udaf = MaxByFunction :: new ( ) ;
415+ let max_by_udaf = MaxByFunction :: default ( ) ;
393416 ctx. register_udaf ( max_by_udaf. into ( ) ) ;
394417 Ok ( ctx)
395418 }
@@ -460,9 +483,26 @@ mod tests {
460483 Ok ( ( ) )
461484 }
462485
486+ #[ tokio:: test]
487+ async fn test_min_by_ignores_nulls ( ) -> error:: Result < ( ) > {
488+ let query = r#"
489+ SELECT min_by(v, k)
490+ FROM (
491+ VALUES
492+ ('a', 1),
493+ ('b', CAST(NULL AS INT)),
494+ ('c', 2)
495+ ) AS t(v, k)
496+ "# ;
497+ let df = ctx ( ) ?. sql ( query) . await ?;
498+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
499+ assert_eq ! ( result, "a" , "min_by should ignore NULLs" ) ;
500+ Ok ( ( ) )
501+ }
502+
463503 fn ctx ( ) -> error:: Result < prelude:: SessionContext > {
464504 let ctx = test_ctx ( ) ?;
465- let min_by_udaf = MinByFunction :: new ( ) ;
505+ let min_by_udaf = MinByFunction :: default ( ) ;
466506 ctx. register_udaf ( min_by_udaf. into ( ) ) ;
467507 Ok ( ctx)
468508 }
0 commit comments