Skip to content

Commit 8a713ba

Browse files
author
root
committed
fix: use correct comparator for arg_sort with descending + limit
select_nth_unstable_by always used ascending comparison to partition elements, which meant that with descending=true and a limit, the wrong N elements were selected (smallest instead of largest). Reverse the comparator when descending is set in both arg_sort and arg_sort_no_nulls. Closes #26833
1 parent 737ff7c commit 8a713ba

File tree

2 files changed

+157
-6
lines changed

2 files changed

+157
-6
lines changed

crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs

Lines changed: 121 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,13 @@ where
142142
let out = if limit >= vals.len() {
143143
vals.as_mut_slice()
144144
} else {
145-
let (lower, _el, _upper) = vals
146-
.as_mut_slice()
147-
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1));
145+
let (lower, _el, _upper) = if options.descending {
146+
vals.as_mut_slice()
147+
.select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
148+
} else {
149+
vals.as_mut_slice()
150+
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
151+
};
148152
lower
149153
};
150154

@@ -235,9 +239,13 @@ where
235239
let out = if limit >= vals.len() {
236240
vals.as_mut_slice()
237241
} else {
238-
let (lower, _el, _upper) = vals
239-
.as_mut_slice()
240-
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1));
242+
let (lower, _el, _upper) = if options.descending {
243+
vals.as_mut_slice()
244+
.select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1))
245+
} else {
246+
vals.as_mut_slice()
247+
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1))
248+
};
241249
lower
242250
};
243251
sort_impl(out, options);
@@ -326,4 +334,111 @@ mod test {
326334
let idx = reverse_stable_no_nulls(&a, 0);
327335
assert_eq!(idx.len(), 0);
328336
}
337+
338+
#[test]
339+
fn test_arg_sort_descending_with_limit() {
340+
let a = Int32Chunked::new(
341+
PlSmallStr::from_static("a"),
342+
&[4, 2, 5, 1, 3],
343+
);
344+
345+
let options = SortOptions {
346+
descending: true,
347+
nulls_last: false,
348+
multithreaded: false,
349+
limit: Some(3),
350+
..Default::default()
351+
};
352+
let result = a.arg_sort(options);
353+
let idx: Vec<IdxSize> = result.into_no_null_iter().collect();
354+
// descending top-3: values 5(idx=2), 4(idx=0), 3(idx=4)
355+
assert_eq!(idx, vec![2, 0, 4]);
356+
}
357+
358+
#[test]
359+
fn test_arg_sort_ascending_with_limit() {
360+
let a = Int32Chunked::new(
361+
PlSmallStr::from_static("a"),
362+
&[4, 2, 5, 1, 3],
363+
);
364+
365+
let options = SortOptions {
366+
descending: false,
367+
nulls_last: false,
368+
multithreaded: false,
369+
limit: Some(3),
370+
..Default::default()
371+
};
372+
let result = a.arg_sort(options);
373+
let idx: Vec<IdxSize> = result.into_no_null_iter().collect();
374+
// ascending top-3: values 1(idx=3), 2(idx=1), 3(idx=4)
375+
assert_eq!(idx, vec![3, 1, 4]);
376+
}
377+
378+
#[test]
379+
fn test_arg_sort_descending_limit_with_nulls() {
380+
let a = Int32Chunked::new(
381+
PlSmallStr::from_static("a"),
382+
&[
383+
Some(4),
384+
None,
385+
Some(5),
386+
Some(1),
387+
None,
388+
Some(3),
389+
],
390+
);
391+
392+
let options = SortOptions {
393+
descending: true,
394+
nulls_last: true,
395+
multithreaded: false,
396+
limit: Some(3),
397+
..Default::default()
398+
};
399+
let result = a.arg_sort(options);
400+
let idx: Vec<IdxSize> = result.into_no_null_iter().collect();
401+
// descending, nulls last, top-3: values 5(idx=2), 4(idx=0), 3(idx=5)
402+
assert_eq!(idx, vec![2, 0, 5]);
403+
}
404+
405+
#[test]
406+
fn test_arg_sort_descending_limit_larger_than_len() {
407+
let a = Int32Chunked::new(
408+
PlSmallStr::from_static("a"),
409+
&[3, 1, 2],
410+
);
411+
412+
let options = SortOptions {
413+
descending: true,
414+
nulls_last: false,
415+
multithreaded: false,
416+
limit: Some(10),
417+
..Default::default()
418+
};
419+
let result = a.arg_sort(options);
420+
let idx: Vec<IdxSize> = result.into_no_null_iter().collect();
421+
assert_eq!(idx, vec![0, 2, 1]);
422+
}
423+
424+
#[test]
425+
fn test_arg_sort_descending_limit_with_duplicates() {
426+
let a = Int32Chunked::new(
427+
PlSmallStr::from_static("a"),
428+
&[3, 5, 5, 1, 3, 2],
429+
);
430+
431+
let options = SortOptions {
432+
descending: true,
433+
nulls_last: false,
434+
multithreaded: false,
435+
limit: Some(4),
436+
..Default::default()
437+
};
438+
let result = a.arg_sort(options);
439+
let idx: Vec<IdxSize> = result.into_no_null_iter().collect();
440+
// descending top-4: 5(idx=1), 5(idx=2), 3(idx=0), 3(idx=4)
441+
let vals: Vec<i32> = idx.iter().map(|&i| a.get(i as usize).unwrap()).collect();
442+
assert_eq!(vals, vec![5, 5, 3, 3]);
443+
}
329444
}

py-polars/tests/unit/operations/test_sort.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,3 +1289,39 @@ def test_sort_by_empty_list_eval_25433() -> None:
12891289
out = df.select(pl.col.a.list.eval(pl.element().sort_by(pl.element())))
12901290
expected = pl.DataFrame({"a": [sorted(some_list), []]})
12911291
assert_frame_equal(out, expected)
1292+
1293+
1294+
def test_top_k_bottom_k_correctness_26833() -> None:
1295+
df = pl.DataFrame({"a": [4, 2, 5, 1, 3]})
1296+
1297+
top3 = df.select(pl.col("a").top_k(3))["a"].to_list()
1298+
assert sorted(top3, reverse=True) == [5, 4, 3]
1299+
1300+
bottom3 = df.select(pl.col("a").bottom_k(3))["a"].to_list()
1301+
assert sorted(bottom3) == [1, 2, 3]
1302+
1303+
1304+
def test_top_k_with_nulls_26833() -> None:
1305+
df = pl.DataFrame({"a": [4, None, 5, 1, None, 3]})
1306+
1307+
top3 = df.select(pl.col("a").top_k(3))["a"].to_list()
1308+
assert sorted([v for v in top3 if v is not None], reverse=True) == [5, 4, 3]
1309+
1310+
bottom2 = df.select(pl.col("a").bottom_k(2))["a"].to_list()
1311+
assert sorted([v for v in bottom2 if v is not None]) == [1, 3]
1312+
1313+
1314+
def test_top_k_with_duplicates_26833() -> None:
1315+
df = pl.DataFrame({"a": [3, 5, 5, 1, 3, 2]})
1316+
1317+
top4 = df.select(pl.col("a").top_k(4))["a"].to_list()
1318+
assert sorted(top4, reverse=True) == [5, 5, 3, 3]
1319+
1320+
bottom4 = df.select(pl.col("a").bottom_k(4))["a"].to_list()
1321+
assert sorted(bottom4) == [1, 2, 3, 3]
1322+
1323+
1324+
def test_top_k_larger_than_len_26833() -> None:
1325+
df = pl.DataFrame({"a": [3, 1, 2]})
1326+
result = df.select(pl.col("a").top_k(10))["a"].to_list()
1327+
assert sorted(result, reverse=True) == [3, 2, 1]

0 commit comments

Comments
 (0)