Skip to content

Commit f110a45

Browse files
boris324claudenameexhaustion
authored
fix(rust): Incorrect arg_sort with descending+limit (#26839)
Co-authored-by: boris324 <boris324@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: nameexhaustion <simonlin.rqmmw@slmail.me>
1 parent 7a2a218 commit f110a45

File tree

1 file changed

+62
-6
lines changed
  • crates/polars-core/src/chunked_array/ops/sort

1 file changed

+62
-6
lines changed

crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,13 @@ where
142142
let out = if limit >= vals.len() {
143143
vals.as_mut_slice()
144144
} else {
145-
let (lower, _el, _upper) = vals
146-
.as_mut_slice()
147-
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1));
145+
let (lower, _el, _upper) = if options.descending {
146+
vals.as_mut_slice()
147+
.select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
148+
} else {
149+
vals.as_mut_slice()
150+
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
151+
};
148152
lower
149153
};
150154

@@ -235,9 +239,13 @@ where
235239
let out = if limit >= vals.len() {
236240
vals.as_mut_slice()
237241
} else {
238-
let (lower, _el, _upper) = vals
239-
.as_mut_slice()
240-
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1));
242+
let (lower, _el, _upper) = if options.descending {
243+
vals.as_mut_slice()
244+
.select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
245+
} else {
246+
vals.as_mut_slice()
247+
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
248+
};
241249
lower
242250
};
243251
sort_impl(out, options);
@@ -326,4 +334,52 @@ mod test {
326334
let idx = reverse_stable_no_nulls(&a, 0);
327335
assert_eq!(idx.len(), 0);
328336
}
337+
338+
#[test]
339+
fn test_arg_sort_descending_with_limit() {
340+
let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]);
341+
let o = SortOptions {
342+
descending: true,
343+
nulls_last: false,
344+
multithreaded: false,
345+
limit: Some(3),
346+
..Default::default()
347+
};
348+
let r = a.arg_sort(o);
349+
let idx: Vec<IdxSize> = r.into_no_null_iter().collect();
350+
assert_eq!(idx, vec![2, 0, 4]);
351+
}
352+
353+
#[test]
354+
fn test_arg_sort_asc_with_limit() {
355+
let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 2, 5, 1, 3]);
356+
let o = SortOptions {
357+
descending: false,
358+
nulls_last: false,
359+
multithreaded: false,
360+
limit: Some(3),
361+
..Default::default()
362+
};
363+
let r = a.arg_sort(o);
364+
let idx: Vec<IdxSize> = r.into_no_null_iter().collect();
365+
assert_eq!(idx, vec![3, 1, 4]);
366+
}
367+
368+
#[test]
369+
fn test_arg_sort_desc_limit_nulls() {
370+
let a = Int32Chunked::new(
371+
PlSmallStr::from_static("a"),
372+
&[Some(4), None, Some(5), Some(1), None, Some(3)],
373+
);
374+
let o = SortOptions {
375+
descending: true,
376+
nulls_last: true,
377+
multithreaded: false,
378+
limit: Some(3),
379+
..Default::default()
380+
};
381+
let r = a.arg_sort(o);
382+
let idx: Vec<IdxSize> = r.into_no_null_iter().collect();
383+
assert_eq!(idx, vec![2, 0, 5]);
384+
}
329385
}

0 commit comments

Comments
 (0)