@@ -142,9 +142,13 @@ where
142142 let out = if limit >= vals. len ( ) {
143143 vals. as_mut_slice ( )
144144 } else {
145- let ( lower, _el, _upper) = vals
146- . as_mut_slice ( )
147- . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) ) ;
145+ let ( lower, _el, _upper) = if options. descending {
146+ vals. as_mut_slice ( )
147+ . select_nth_unstable_by ( limit, |a, b| b. 1 . tot_cmp ( & a. 1 ) )
148+ } else {
149+ vals. as_mut_slice ( )
150+ . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) )
151+ } ;
148152 lower
149153 } ;
150154
@@ -235,9 +239,13 @@ where
235239 let out = if limit >= vals. len ( ) {
236240 vals. as_mut_slice ( )
237241 } else {
238- let ( lower, _el, _upper) = vals
239- . as_mut_slice ( )
240- . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) ) ;
242+ let ( lower, _el, _upper) = if options. descending {
243+ vals. as_mut_slice ( )
244+ . select_nth_unstable_by ( limit, |a, b| b. 1 . tot_cmp ( & a. 1 ) )
245+ } else {
246+ vals. as_mut_slice ( )
247+ . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) )
248+ } ;
241249 lower
242250 } ;
243251 sort_impl ( out, options) ;
@@ -326,4 +334,52 @@ mod test {
326334 let idx = reverse_stable_no_nulls ( & a, 0 ) ;
327335 assert_eq ! ( idx. len( ) , 0 ) ;
328336 }
337+
338+ #[ test]
339+ fn test_arg_sort_descending_with_limit ( ) {
340+ let a = Int32Chunked :: new ( PlSmallStr :: from_static ( "a" ) , & [ 4 , 2 , 5 , 1 , 3 ] ) ;
341+ let o = SortOptions {
342+ descending : true ,
343+ nulls_last : false ,
344+ multithreaded : false ,
345+ limit : Some ( 3 ) ,
346+ ..Default :: default ( )
347+ } ;
348+ let r = a. arg_sort ( o) ;
349+ let idx: Vec < IdxSize > = r. into_no_null_iter ( ) . collect ( ) ;
350+ assert_eq ! ( idx, vec![ 2 , 0 , 4 ] ) ;
351+ }
352+
353+ #[ test]
354+ fn test_arg_sort_asc_with_limit ( ) {
355+ let a = Int32Chunked :: new ( PlSmallStr :: from_static ( "a" ) , & [ 4 , 2 , 5 , 1 , 3 ] ) ;
356+ let o = SortOptions {
357+ descending : false ,
358+ nulls_last : false ,
359+ multithreaded : false ,
360+ limit : Some ( 3 ) ,
361+ ..Default :: default ( )
362+ } ;
363+ let r = a. arg_sort ( o) ;
364+ let idx: Vec < IdxSize > = r. into_no_null_iter ( ) . collect ( ) ;
365+ assert_eq ! ( idx, vec![ 3 , 1 , 4 ] ) ;
366+ }
367+
368+ #[ test]
369+ fn test_arg_sort_desc_limit_nulls ( ) {
370+ let a = Int32Chunked :: new (
371+ PlSmallStr :: from_static ( "a" ) ,
372+ & [ Some ( 4 ) , None , Some ( 5 ) , Some ( 1 ) , None , Some ( 3 ) ] ,
373+ ) ;
374+ let o = SortOptions {
375+ descending : true ,
376+ nulls_last : true ,
377+ multithreaded : false ,
378+ limit : Some ( 3 ) ,
379+ ..Default :: default ( )
380+ } ;
381+ let r = a. arg_sort ( o) ;
382+ let idx: Vec < IdxSize > = r. into_no_null_iter ( ) . collect ( ) ;
383+ assert_eq ! ( idx, vec![ 2 , 0 , 5 ] ) ;
384+ }
329385}
0 commit comments