diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs index b83fa3834aa4..bc425398ea89 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs @@ -142,9 +142,13 @@ where let out = if limit >= vals.len() { vals.as_mut_slice() } else { - let (lower, _el, _upper) = vals - .as_mut_slice() - .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1)); + let (lower, _el, _upper) = if options.descending { + vals.as_mut_slice() + .select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1)) + } else { + vals.as_mut_slice() + .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1)) + }; lower }; @@ -235,9 +239,13 @@ where let out = if limit >= vals.len() { vals.as_mut_slice() } else { - let (lower, _el, _upper) = vals - .as_mut_slice() - .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1)); + let (lower, _el, _upper) = if options.descending { + vals.as_mut_slice() + .select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1)) + } else { + vals.as_mut_slice() + .select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1)) + }; lower }; sort_impl(out, options); @@ -326,4 +334,52 @@ mod test { let idx = reverse_stable_no_nulls(&a, 0); assert_eq!(idx.len(), 0); } + + #[test] + fn test_arg_sort_descending_with_limit() { + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]); + let o = SortOptions { + descending: true, + nulls_last: false, + multithreaded: false, + limit: Some(3), + ..Default::default() + }; + let r = a.arg_sort(o); + let idx: Vec = r.into_no_null_iter().collect(); + assert_eq!(idx, vec![2, 0, 4]); + } + + #[test] + fn test_arg_sort_asc_with_limit() { + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]); + let o = SortOptions { + descending: false, + nulls_last: false, + multithreaded: false, + limit: Some(3), + ..Default::default() + }; + let r = a.arg_sort(o); + let idx: Vec = r.into_no_null_iter().collect(); + assert_eq!(idx, vec![3, 1, 4]); + } + + #[test] + fn test_arg_sort_desc_limit_nulls() { + let a = Int32Chunked::new( + PlSmallStr::from_static("a"), + &[Some(4), None, Some(5), Some(1), None, Some(3)], + ); + let o = SortOptions { + descending: true, + nulls_last: true, + multithreaded: false, + limit: Some(3), + ..Default::default() + }; + let r = a.arg_sort(o); + let idx: Vec = r.into_no_null_iter().collect(); + assert_eq!(idx, vec![2, 0, 5]); + } }