@@ -13,6 +13,7 @@ make_udaf_expr_and_func!(
1313
1414#[ derive( Eq , Hash , PartialEq ) ]
1515pub struct MaxByFunction {
16+ null_first : bool ,
1617 signature : logical_expr:: Signature ,
1718}
1819
@@ -27,13 +28,14 @@ impl fmt::Debug for MaxByFunction {
2728}
2829impl Default for MaxByFunction {
2930 fn default ( ) -> Self {
30- Self :: new ( )
31+ Self :: new ( true )
3132 }
3233}
3334
3435impl MaxByFunction {
35- pub fn new ( ) -> Self {
36+ pub fn new ( null_first : bool ) -> Self {
3637 Self {
38+ null_first,
3739 signature : logical_expr:: Signature :: user_defined ( logical_expr:: Volatility :: Immutable ) ,
3840 }
3941 }
@@ -80,6 +82,7 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8082 ) -> error:: Result < Box < dyn logical_expr:: Accumulator > > {
8183 common:: exec_err!( "should not reach here" )
8284 }
85+
8386 fn coerce_types (
8487 & self ,
8588 arg_types : & [ arrow:: datatypes:: DataType ] ,
@@ -88,25 +91,25 @@ impl logical_expr::AggregateUDFImpl for MaxByFunction {
8891 }
8992
9093 fn simplify ( & self ) -> Option < logical_expr:: function:: AggregateFunctionSimplification > {
91- let simplify = |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
92- _: & dyn logical_expr:: simplify:: SimplifyInfo | {
94+ let null_first = self . null_first ;
95+ let simplify = move |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
96+ _: & dyn logical_expr:: simplify:: SimplifyInfo | {
9397 let mut order_by = aggr_func. params . order_by ;
9498 let ( second_arg, first_arg) = (
9599 aggr_func. params . args . remove ( 1 ) ,
96100 aggr_func. params . args . remove ( 0 ) ,
97101 ) ;
98- let sort = logical_expr:: expr:: Sort :: new ( second_arg, true , true ) ;
102+ let sort = logical_expr:: expr:: Sort :: new ( second_arg, true , null_first ) ;
99103 order_by. push ( sort) ;
100- let func = logical_expr:: expr:: Expr :: AggregateFunction (
101- logical_expr:: expr:: AggregateFunction :: new_udf (
102- functions_aggregate:: first_last:: last_value_udaf ( ) ,
103- vec ! [ first_arg] ,
104- aggr_func. params . distinct ,
105- aggr_func. params . filter ,
106- order_by,
107- aggr_func. params . null_treatment ,
108- ) ,
104+ let func = logical_expr:: expr:: AggregateFunction :: new_udf (
105+ functions_aggregate:: first_last:: last_value_udaf ( ) ,
106+ vec ! [ first_arg] ,
107+ aggr_func. params . distinct ,
108+ aggr_func. params . filter ,
109+ order_by,
110+ aggr_func. params . null_treatment ,
109111 ) ;
112+ let func = logical_expr:: expr:: Expr :: AggregateFunction ( func) ;
110113 Ok ( func)
111114 } ;
112115 Some ( Box :: new ( simplify) )
@@ -123,6 +126,7 @@ make_udaf_expr_and_func!(
123126
124127#[ derive( Eq , Hash , PartialEq ) ]
125128pub struct MinByFunction {
129+ null_first : bool ,
126130 signature : logical_expr:: Signature ,
127131}
128132
@@ -138,13 +142,14 @@ impl fmt::Debug for MinByFunction {
138142
139143impl Default for MinByFunction {
140144 fn default ( ) -> Self {
141- Self :: new ( )
145+ Self :: new ( true )
142146 }
143147}
144148
145149impl MinByFunction {
146- pub fn new ( ) -> Self {
150+ pub fn new ( null_first : bool ) -> Self {
147151 Self {
152+ null_first,
148153 signature : logical_expr:: Signature :: user_defined ( logical_expr:: Volatility :: Immutable ) ,
149154 }
150155 }
@@ -185,26 +190,26 @@ impl logical_expr::AggregateUDFImpl for MinByFunction {
185190 }
186191
187192 fn simplify ( & self ) -> Option < logical_expr:: function:: AggregateFunctionSimplification > {
188- let simplify = |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
189- _: & dyn logical_expr:: simplify:: SimplifyInfo | {
193+ let null_first = self . null_first ;
194+ let simplify = move |mut aggr_func : logical_expr:: expr:: AggregateFunction ,
195+ _: & dyn logical_expr:: simplify:: SimplifyInfo | {
190196 let mut order_by = aggr_func. params . order_by ;
191197 let ( second_arg, first_arg) = (
192198 aggr_func. params . args . remove ( 1 ) ,
193199 aggr_func. params . args . remove ( 0 ) ,
194200 ) ;
195201
196- let sort = logical_expr:: expr:: Sort :: new ( second_arg, false , true ) ;
202+ let sort = logical_expr:: expr:: Sort :: new ( second_arg, false , null_first ) ;
197203 order_by. push ( sort) ; // false for ascending sort
198- let func = logical_expr:: expr:: Expr :: AggregateFunction (
199- logical_expr:: expr:: AggregateFunction :: new_udf (
200- functions_aggregate:: first_last:: last_value_udaf ( ) ,
201- vec ! [ first_arg] ,
202- aggr_func. params . distinct ,
203- aggr_func. params . filter ,
204- order_by,
205- aggr_func. params . null_treatment ,
206- ) ,
204+ let func = logical_expr:: expr:: AggregateFunction :: new_udf (
205+ functions_aggregate:: first_last:: last_value_udaf ( ) ,
206+ vec ! [ first_arg] ,
207+ aggr_func. params . distinct ,
208+ aggr_func. params . filter ,
209+ order_by,
210+ aggr_func. params . null_treatment ,
207211 ) ;
212+ let func = logical_expr:: expr:: Expr :: AggregateFunction ( func) ;
208213 Ok ( func)
209214 } ;
210215 Some ( Box :: new ( simplify) )
@@ -399,33 +404,15 @@ mod tests {
399404 ('c', 2)
400405 ) AS t(v, k)
401406 "# ;
402- let df = ctx ( ) ?. sql ( & query) . await ?;
407+ let df = ctx ( ) ?. sql ( query) . await ?;
403408 let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
404409 assert_eq ! ( result, "c" , "max_by should ignore NULLs" ) ;
405410 Ok ( ( ) )
406411 }
407412
408- #[ tokio:: test]
409- async fn test_max_like_main_test ( ) -> error:: Result < ( ) > {
410- let query = r#"
411- SELECT max_by(v, k)
412- FROM (
413- VALUES
414- (1, 10),
415- (2, 5),
416- (3, 15),
417- (4, 8)
418- ) AS t(v, k)
419- "# ;
420- let df = ctx ( ) ?. sql ( & query) . await ?;
421- let result = extract_single_value :: < i64 , arrow:: array:: Int64Array > ( df) . await ?;
422- assert_eq ! ( result, 3 ) ;
423- Ok ( ( ) )
424- }
425-
426413 fn ctx ( ) -> error:: Result < prelude:: SessionContext > {
427414 let ctx = test_ctx ( ) ?;
428- let max_by_udaf = MaxByFunction :: new ( ) ;
415+ let max_by_udaf = MaxByFunction :: default ( ) ;
429416 ctx. register_udaf ( max_by_udaf. into ( ) ) ;
430417 Ok ( ctx)
431418 }
@@ -507,51 +494,15 @@ mod tests {
507494 ('c', 2)
508495 ) AS t(v, k)
509496 "# ;
510- let df = ctx ( ) ?. sql ( & query) . await ?;
497+ let df = ctx ( ) ?. sql ( query) . await ?;
511498 let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
512499 assert_eq ! ( result, "a" , "min_by should ignore NULLs" ) ;
513500 Ok ( ( ) )
514501 }
515502
516- #[ tokio:: test]
517- async fn test_min_like_main_test_str ( ) -> error:: Result < ( ) > {
518- let query = r#"
519- SELECT min_by(v, k)
520- FROM (
521- VALUES
522- ('a', 10),
523- ('b', 5),
524- ('c', 15),
525- ('d', 8)
526- ) AS t(v, k)
527- "# ;
528- let df = ctx ( ) ?. sql ( & query) . await ?;
529- let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
530- assert_eq ! ( result, "b" ) ;
531- Ok ( ( ) )
532- }
533-
534- #[ tokio:: test]
535- async fn test_min_like_main_test_int ( ) -> error:: Result < ( ) > {
536- let query = r#"
537- SELECT min_by(v, k)
538- FROM (
539- VALUES
540- (1, 10),
541- (2, 5),
542- (3, 15),
543- (4, 8)
544- ) AS t(v, k)
545- "# ;
546- let df = ctx ( ) ?. sql ( & query) . await ?;
547- let result = extract_single_value :: < i64 , arrow:: array:: Int64Array > ( df) . await ?;
548- assert_eq ! ( result, 2 ) ;
549- Ok ( ( ) )
550- }
551-
552503 fn ctx ( ) -> error:: Result < prelude:: SessionContext > {
553504 let ctx = test_ctx ( ) ?;
554- let min_by_udaf = MinByFunction :: new ( ) ;
505+ let min_by_udaf = MinByFunction :: default ( ) ;
555506 ctx. register_udaf ( min_by_udaf. into ( ) ) ;
556507 Ok ( ctx)
557508 }
0 commit comments