@@ -142,9 +142,13 @@ where
142142 let out = if limit >= vals. len ( ) {
143143 vals. as_mut_slice ( )
144144 } else {
145- let ( lower, _el, _upper) = vals
146- . as_mut_slice ( )
147- . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) ) ;
145+ let ( lower, _el, _upper) = if options. descending {
146+ vals. as_mut_slice ( )
147+ . select_nth_unstable_by ( limit, |a, b| b. 1 . tot_cmp ( & a. 1 ) )
148+ } else {
149+ vals. as_mut_slice ( )
150+ . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) )
151+ } ;
148152 lower
149153 } ;
150154
@@ -235,9 +239,13 @@ where
235239 let out = if limit >= vals. len ( ) {
236240 vals. as_mut_slice ( )
237241 } else {
238- let ( lower, _el, _upper) = vals
239- . as_mut_slice ( )
240- . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) ) ;
242+ let ( lower, _el, _upper) = if options. descending {
243+ vals. as_mut_slice ( )
244+ . select_nth_unstable_by ( limit, |a, b| b. 1 . tot_cmp ( & a. 1 ) )
245+ } else {
246+ vals. as_mut_slice ( )
247+ . select_nth_unstable_by ( limit, |a, b| a. 1 . tot_cmp ( & b. 1 ) )
248+ } ;
241249 lower
242250 } ;
243251 sort_impl ( out, options) ;
@@ -326,4 +334,111 @@ mod test {
326334 let idx = reverse_stable_no_nulls ( & a, 0 ) ;
327335 assert_eq ! ( idx. len( ) , 0 ) ;
328336 }
337+
338+ #[ test]
339+ fn test_arg_sort_descending_with_limit ( ) {
340+ let a = Int32Chunked :: new (
341+ PlSmallStr :: from_static ( "a" ) ,
342+ & [ 4 , 2 , 5 , 1 , 3 ] ,
343+ ) ;
344+
345+ let options = SortOptions {
346+ descending : true ,
347+ nulls_last : false ,
348+ multithreaded : false ,
349+ limit : Some ( 3 ) ,
350+ ..Default :: default ( )
351+ } ;
352+ let result = a. arg_sort ( options) ;
353+ let idx: Vec < IdxSize > = result. into_no_null_iter ( ) . collect ( ) ;
354+ // descending top-3: values 5(idx=2), 4(idx=0), 3(idx=4)
355+ assert_eq ! ( idx, vec![ 2 , 0 , 4 ] ) ;
356+ }
357+
358+ #[ test]
359+ fn test_arg_sort_ascending_with_limit ( ) {
360+ let a = Int32Chunked :: new (
361+ PlSmallStr :: from_static ( "a" ) ,
362+ & [ 4 , 2 , 5 , 1 , 3 ] ,
363+ ) ;
364+
365+ let options = SortOptions {
366+ descending : false ,
367+ nulls_last : false ,
368+ multithreaded : false ,
369+ limit : Some ( 3 ) ,
370+ ..Default :: default ( )
371+ } ;
372+ let result = a. arg_sort ( options) ;
373+ let idx: Vec < IdxSize > = result. into_no_null_iter ( ) . collect ( ) ;
374+ // ascending top-3: values 1(idx=3), 2(idx=1), 3(idx=4)
375+ assert_eq ! ( idx, vec![ 3 , 1 , 4 ] ) ;
376+ }
377+
378+ #[ test]
379+ fn test_arg_sort_descending_limit_with_nulls ( ) {
380+ let a = Int32Chunked :: new (
381+ PlSmallStr :: from_static ( "a" ) ,
382+ & [
383+ Some ( 4 ) ,
384+ None ,
385+ Some ( 5 ) ,
386+ Some ( 1 ) ,
387+ None ,
388+ Some ( 3 ) ,
389+ ] ,
390+ ) ;
391+
392+ let options = SortOptions {
393+ descending : true ,
394+ nulls_last : true ,
395+ multithreaded : false ,
396+ limit : Some ( 3 ) ,
397+ ..Default :: default ( )
398+ } ;
399+ let result = a. arg_sort ( options) ;
400+ let idx: Vec < IdxSize > = result. into_no_null_iter ( ) . collect ( ) ;
401+ // descending, nulls last, top-3: values 5(idx=2), 4(idx=0), 3(idx=5)
402+ assert_eq ! ( idx, vec![ 2 , 0 , 5 ] ) ;
403+ }
404+
405+ #[ test]
406+ fn test_arg_sort_descending_limit_larger_than_len ( ) {
407+ let a = Int32Chunked :: new (
408+ PlSmallStr :: from_static ( "a" ) ,
409+ & [ 3 , 1 , 2 ] ,
410+ ) ;
411+
412+ let options = SortOptions {
413+ descending : true ,
414+ nulls_last : false ,
415+ multithreaded : false ,
416+ limit : Some ( 10 ) ,
417+ ..Default :: default ( )
418+ } ;
419+ let result = a. arg_sort ( options) ;
420+ let idx: Vec < IdxSize > = result. into_no_null_iter ( ) . collect ( ) ;
421+ assert_eq ! ( idx, vec![ 0 , 2 , 1 ] ) ;
422+ }
423+
424+ #[ test]
425+ fn test_arg_sort_descending_limit_with_duplicates ( ) {
426+ let a = Int32Chunked :: new (
427+ PlSmallStr :: from_static ( "a" ) ,
428+ & [ 3 , 5 , 5 , 1 , 3 , 2 ] ,
429+ ) ;
430+
431+ let options = SortOptions {
432+ descending : true ,
433+ nulls_last : false ,
434+ multithreaded : false ,
435+ limit : Some ( 4 ) ,
436+ ..Default :: default ( )
437+ } ;
438+ let result = a. arg_sort ( options) ;
439+ let idx: Vec < IdxSize > = result. into_no_null_iter ( ) . collect ( ) ;
440+ // descending top-4: 5(idx=1), 5(idx=2), 3(idx=0), 3(idx=4)
441+ let vals: Vec < i32 > = idx. iter ( ) . map ( |& i| a. get ( i as usize ) . unwrap ( ) ) . collect ( ) ;
442+ assert_eq ! ( vals, vec![ 5 , 5 , 3 , 3 ] ) ;
443+ }
329444}
0 commit comments