Commit bf121d6

tim-one, AlexWaygood, and eendebakpt authored
GH-116554: Relax list.sort()'s notion of "descending" runs (#116578)
* GH-116554: Relax list.sort()'s notion of "descending" run

Rewrote `count_run()` so that sub-runs of equal elements no longer end a descending run. Both ascending and descending runs can have arbitrarily many sub-runs of arbitrarily many equal elements now. This is tricky, because we only use ``<`` comparisons, so checking for equality doesn't come "for free". Surprisingly, it turned out there's a very cheap (one comparison) way to determine whether an ascending run consisted of all-equal elements. That sealed the deal.

In addition, after a descending run is reversed in-place, we now go on to see whether it can be extended by an ascending run that just happens to be adjacent. This succeeds in finding at least one additional element to append about half the time, and so appears to more than repay its cost (the savings come from getting to skip a binary search, when a short run is artificially forced to length MINRUN later, for each new element `count_run()` can add to the initial run).

While these have been in the back of my mind for years, a question on StackOverflow pushed it to action:

https://stackoverflow.com/questions/78108792/

They were wondering why it took about 4x longer to sort a list like:

    [999_999, 999_999, ..., 2, 2, 1, 1, 0, 0]

than "similar" lists. Of course that runs very much faster after this patch.

Co-authored-by: Alex Waygood <[email protected]>
Co-authored-by: Pieter Eendebak <[email protected]>
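The "one comparison" observation in the message can be illustrated with a small Python sketch. This is not the C code from the patch; `ascending_prefix_len` and `prefix_is_all_equal` are hypothetical helper names chosen for illustration. The idea: once a prefix is known to be non-decreasing, it contains only equal elements exactly when its first element is not less than its last.

```python
def ascending_prefix_len(xs):
    """Length of the longest non-decreasing prefix of xs, using only `<`."""
    n = 1
    while n < len(xs) and not (xs[n] < xs[n - 1]):
        n += 1
    return n


def prefix_is_all_equal(xs, n):
    """True if the non-decreasing prefix xs[:n] holds only equal elements.

    Since xs[0] <= xs[1] <= ... <= xs[n-1] is already established,
    the prefix is all-equal exactly when xs[0] < xs[n-1] is false.
    That costs a single `<` comparison, whatever n is.
    """
    return not (xs[0] < xs[n - 1])


equal_then_drop = [5, 5, 5, 4]   # all-equal prefix, then a smaller element
rise_then_drop = [5, 5, 6, 4]    # the prefix increases somewhere

n1 = ascending_prefix_len(equal_then_drop)
n2 = ascending_prefix_len(rise_then_drop)
print(n1, prefix_is_all_equal(equal_then_drop, n1))
print(n2, prefix_is_all_equal(rise_then_drop, n2))
```

In the first list the prefix can be treated as the start of a descending run; in the second it is a genuine ascending run that simply ends.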
1 parent 7d1abe9 commit bf121d6

4 files changed, +156 -66 lines changed

Lib/test/test_sort.py

Lines changed: 21 additions & 0 deletions
@@ -128,6 +128,27 @@ def bad_key(x):
             x = [e for e, i in augmented] # a stable sort of s
             check("stability", x, s)

+    def test_small_stability(self):
+        from itertools import product
+        from operator import itemgetter
+
+        # Exhaustively test stability across all lists of small lengths
+        # and only a few distinct elements.
+        # This can provoke edge cases that randomization is unlikely to find.
+        # But it can grow very expensive quickly, so don't overdo it.
+        NELTS = 3
+        MAXSIZE = 9
+
+        pick0 = itemgetter(0)
+        for length in range(MAXSIZE + 1):
+            # There are NELTS ** length distinct lists.
+            for t in product(range(NELTS), repeat=length):
+                xs = list(zip(t, range(length)))
+                # Stability forced by index in each element.
+                forced = sorted(xs)
+                # Use key= to hide the index from compares.
+                native = sorted(xs, key=pick0)
+                self.assertEqual(forced, native)
 #==============================================================================

 class TestBugs(unittest.TestCase):
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+``list.sort()`` now exploits more cases of partial ordering, particularly those with long descending runs with sub-runs of equal values. Those are recognized as single runs now (previously, each block of repeated values caused a new run to be created).
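One rough way to observe the effect this entry describes is to count how many `<` comparisons a sort performs on descending data in which every value is repeated. The sketch below is only an illustration: `Counted` and `count_compares` are hypothetical helpers, not part of CPython, and the count printed depends on the interpreter version, so no specific number is claimed here.

```python
class Counted:
    """Wrap a value and count every `<` comparison made through it."""

    compares = 0  # shared counter, reset by count_compares()

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        Counted.compares += 1
        return self.value < other.value


def count_compares(values):
    """Sort a copy of values; return (sorted values, number of `<` calls)."""
    Counted.compares = 0
    result = sorted(Counted(v) for v in values)
    return [c.value for c in result], Counted.compares


# Descending data where each value appears twice, shaped like the list
# in the commit message but much shorter.
data = [v for v in range(999, -1, -1) for _ in range(2)]
ordered, ncmp = count_compares(data)
print(len(data), "elements sorted with", ncmp, "comparisons")
```

On an interpreter that includes this change the whole list should be recognized as one descending run, so the count is expected to grow roughly linearly with the list length; older interpreters are expected to need noticeably more comparisons on the same input.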

Objects/listobject.c

Lines changed: 103 additions & 54 deletions
@@ -1618,10 +1618,11 @@ struct s_MergeState {
 /* binarysort is the best method for sorting small arrays: it does
    few compares, but can do data movement quadratic in the number of
    elements.
-   [lo, hi) is a contiguous slice of a list, and is sorted via
+   [lo.keys, hi) is a contiguous slice of a list of keys, and is sorted via
    binary insertion. This sort is stable.
-   On entry, must have lo <= start <= hi, and that [lo, start) is already
-   sorted (pass start == lo if you don't know!).
+   On entry, must have lo.keys <= start <= hi, and that
+   [lo.keys, start) is already sorted (pass start == lo.keys if you don't
+   know!).
    If islt() complains return -1, else 0.
    Even in case of error, the output slice will be some permutation of
    the input (nothing is lost or duplicated).
@@ -1634,7 +1635,7 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
     PyObject *pivot;

     assert(lo.keys <= start && start <= hi);
-    /* assert [lo, start) is sorted */
+    /* assert [lo.keys, start) is sorted */
     if (lo.keys == start)
         ++start;
     for (; start < hi; ++start) {
@@ -1643,9 +1644,9 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
         r = start;
         pivot = *r;
         /* Invariants:
-         * pivot >= all in [lo, l).
+         * pivot >= all in [lo.keys, l).
          * pivot < all in [r, start).
-         * The second is vacuously true at the start.
+         * These are vacuously true at the start.
          */
         assert(l < r);
         do {
@@ -1656,7 +1657,7 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
                 l = p+1;
         } while (l < r);
         assert(l == r);
-        /* The invariants still hold, so pivot >= all in [lo, l) and
+        /* The invariants still hold, so pivot >= all in [lo.keys, l) and
            pivot < all in [l, start), so pivot belongs at l. Note
            that if there are elements equal to pivot, l points to the
            first slot after them -- that's why this sort is stable.
@@ -1671,7 +1672,7 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
             p = start + offset;
             pivot = *p;
             l += offset;
-            for (p = start + offset; p > l; --p)
+            for ( ; p > l; --p)
                 *p = *(p-1);
             *l = pivot;
         }
@@ -1682,56 +1683,115 @@ binarysort(MergeState *ms, sortslice lo, PyObject **hi, PyObject **start)
     return -1;
 }

-/*
-Return the length of the run beginning at lo, in the slice [lo, hi). lo < hi
-is required on entry. "A run" is the longest ascending sequence, with
-
-    lo[0] <= lo[1] <= lo[2] <= ...
-
-or the longest descending sequence, with
-
-    lo[0] > lo[1] > lo[2] > ...
+static void
+sortslice_reverse(sortslice *s, Py_ssize_t n)
+{
+    reverse_slice(s->keys, &s->keys[n]);
+    if (s->values != NULL)
+        reverse_slice(s->values, &s->values[n]);
+}

-Boolean *descending is set to 0 in the former case, or to 1 in the latter.
-For its intended use in a stable mergesort, the strictness of the defn of
-"descending" is needed so that the caller can safely reverse a descending
-sequence without violating stability (strict > ensures there are no equal
-elements to get out of order).
+/*
+Return the length of the run beginning at slo->keys, spanning no more than
+nremaining elements. The run beginning there may be ascending or descending,
+but the function permutes it in place, if needed, so that it's always ascending
+upon return.

 Returns -1 in case of error.
 */
 static Py_ssize_t
-count_run(MergeState *ms, PyObject **lo, PyObject **hi, int *descending)
+count_run(MergeState *ms, sortslice *slo, Py_ssize_t nremaining)
 {
-    Py_ssize_t k;
+    Py_ssize_t k; /* used by IFLT macro expansion */
     Py_ssize_t n;
+    PyObject ** const lo = slo->keys;

-    assert(lo < hi);
-    *descending = 0;
-    ++lo;
-    if (lo == hi)
-        return 1;
-
-    n = 2;
-    IFLT(*lo, *(lo-1)) {
-        *descending = 1;
-        for (lo = lo+1; lo < hi; ++lo, ++n) {
-            IFLT(*lo, *(lo-1))
-                ;
-            else
-                break;
-        }
+    /* In general, as things go on we've established that the slice starts
+       with a monotone run of n elements, starting at lo. */
+
+    /* We're n elements into the slice, and the most recent neq+1 elements are
+     * all equal. This reverses them in-place, and resets neq for reuse.
+     */
+#define REVERSE_LAST_NEQ \
+    if (neq) { \
+        sortslice slice = *slo; \
+        ++neq; \
+        sortslice_advance(&slice, n - neq); \
+        sortslice_reverse(&slice, neq); \
+        neq = 0; \
+    }
+
+    /* Sticking to only __lt__ compares is confusing and error-prone. But in
+     * this routine, almost all uses of IFLT can be captured by tiny macros
+     * giving mnemonic names to the intent. Note that inline functions don't
+     * work for this (IFLT expands to code including `goto fail`).
+     */
+#define IF_NEXT_LARGER IFLT(lo[n-1], lo[n])
+#define IF_NEXT_SMALLER IFLT(lo[n], lo[n-1])
+
+    assert(nremaining);
+    /* try ascending run first */
+    for (n = 1; n < nremaining; ++n) {
+        IF_NEXT_SMALLER
+            break;
     }
-    else {
-        for (lo = lo+1; lo < hi; ++lo, ++n) {
-            IFLT(*lo, *(lo-1))
+    if (n == nremaining)
+        return n;
+    /* lo[n] is strictly less */
+    /* If n is 1 now, then the first compare established it's a descending
+     * run, so fall through to the descending case. But if n > 1, there are
+     * n elements in an ascending run terminated by the strictly less lo[n].
+     * If the first key < lo[n-1], *somewhere* along the way the sequence
+     * increased, so we're done (there is no descending run).
+     * Else first key >= lo[n-1], which implies that the entire ascending run
+     * consists of equal elements. In that case, this is a descending run,
+     * and we reverse the all-equal prefix in-place.
+     */
+    if (n > 1) {
+        IFLT(lo[0], lo[n-1])
+            return n;
+        sortslice_reverse(slo, n);
+    }
+    ++n; /* in all cases it's been established that lo[n] has been resolved */
+
+    /* Finish descending run. All-equal subruns are reversed in-place on the
+     * fly. Their original order will be restored at the end by the whole-slice
+     * reversal.
+     */
+    Py_ssize_t neq = 0;
+    for ( ; n < nremaining; ++n) {
+        IF_NEXT_SMALLER {
+            /* This ends the most recent run of equal elements, but still in
+             * the "descending" direction.
+             */
+            REVERSE_LAST_NEQ
+        }
+        else {
+            IF_NEXT_LARGER /* descending run is over */
                 break;
+            else /* not x < y and not y < x implies x == y */
+                ++neq;
         }
     }
+    REVERSE_LAST_NEQ
+    sortslice_reverse(slo, n); /* transform to ascending run */
+
+    /* And after reversing, it's possible this can be extended by a
+     * naturally increasing suffix; e.g., [3, 2, 3, 4, 1] makes an
+     * ascending run from the first 4 elements.
+     */
+    for ( ; n < nremaining; ++n) {
+        IF_NEXT_SMALLER
+            break;
+    }

     return n;
 fail:
     return -1;
+
+#undef REVERSE_LAST_NEQ
+#undef IF_NEXT_SMALLER
+#undef IF_NEXT_LARGER
 }

 /*
@@ -2449,14 +2509,6 @@ merge_compute_minrun(Py_ssize_t n)
     return n + r;
 }

-static void
-reverse_sortslice(sortslice *s, Py_ssize_t n)
-{
-    reverse_slice(s->keys, &s->keys[n]);
-    if (s->values != NULL)
-        reverse_slice(s->values, &s->values[n]);
-}
-
 /* Here we define custom comparison functions to optimize for the cases one commonly
  * encounters in practice: homogeneous lists, often of one of the basic types. */

@@ -2824,15 +2876,12 @@ list_sort_impl(PyListObject *self, PyObject *keyfunc, int reverse)
      */
     minrun = merge_compute_minrun(nremaining);
     do {
-        int descending;
         Py_ssize_t n;

         /* Identify next run. */
-        n = count_run(&ms, lo.keys, lo.keys + nremaining, &descending);
+        n = count_run(&ms, &lo, nremaining);
         if (n < 0)
             goto fail;
-        if (descending)
-            reverse_sortslice(&lo, n);
         /* If short, extend to min(minrun, nremaining). */
         if (n < minrun) {
             const Py_ssize_t force = nremaining <= minrun ?

Objects/listsort.txt

Lines changed: 31 additions & 12 deletions
@@ -212,24 +212,43 @@ A detailed description of timsort follows.

 Runs
 ----
-count_run() returns the # of elements in the next run. A run is either
-"ascending", which means non-decreasing:
+count_run() returns the # of elements in the next run, and, if it's a
+descending run, reverses it in-place. A run is either "ascending", which
+means non-decreasing:

     a0 <= a1 <= a2 <= ...

-or "descending", which means strictly decreasing:
+or "descending", which means non-increasing:

-    a0 > a1 > a2 > ...
+    a0 >= a1 >= a2 >= ...

 Note that a run is always at least 2 long, unless we start at the array's
-last element.
-
-The definition of descending is strict, because the main routine reverses
-a descending run in-place, transforming a descending run into an ascending
-run. Reversal is done via the obvious fast "swap elements starting at each
-end, and converge at the middle" method, and that can violate stability if
-the slice contains any equal elements. Using a strict definition of
-descending ensures that a descending run contains distinct elements.
+last element. If all elements in the array are equal, it can be viewed as
+both ascending and descending. Upon return, the run count_run() identifies
+is always ascending.
+
+Reversal is done via the obvious fast "swap elements starting at each
+end, and converge at the middle" method. That can violate stability if
+the slice contains any equal elements. For that reason, for a long time
+the code used strict inequality (">" rather than ">=") in its definition
+of descending.
+
+Removing that restriction required some complication: when processing a
+descending run, all-equal sub-runs of elements are reversed in-place, on the
+fly. Their original relative order is restored "by magic" via the final
+"reverse the entire run" step.
+
+This makes processing descending runs a little more costly. We only use
+`__lt__` comparisons, so that `x == y` has to be deduced from
+`not x < y and not y < x`. But so long as a run remains strictly decreasing,
+only one of those compares needs to be done per loop iteration. So the primary
+extra cost is paid only when there are equal elements, and they get some
+compensating benefit by not needing to end the descending run.
+
+There's one more trick added since the original: after reversing a descending
+run, it's possible that it can be extended by an adjacent ascending run. For
+example, given [3, 2, 1, 3, 4, 5, 0], the 3-element descending prefix is
+reversed in-place, and then extended by [3, 4, 5].

 If an array is random, it's very unlikely we'll see long runs. If a natural
 run contains less than minrun elements (see next section), the main loop
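
The run-detection steps described in this part of listsort.txt can be modelled in Python. The sketch below is a simplified illustration, not the C implementation: it works on a plain list rather than a sortslice, has no error path, and the `lt` parameter and the nested `reverse` helper are assumptions introduced here so that stability can be checked with keys that hide part of each element.

```python
import operator


def count_run(keys, lt=operator.lt):
    """Model of the relaxed run detection, using only `lt` compares.

    Returns the length of the run starting at keys[0]. A descending
    (non-increasing) run is made ascending in place, with equal
    elements kept in their original relative order.
    """
    nremaining = len(keys)
    assert nremaining > 0

    def reverse(begin, count):
        # Reverse keys[begin : begin + count] in place.
        keys[begin:begin + count] = keys[begin:begin + count][::-1]

    # Try an ascending (non-decreasing) run first.
    n = 1
    while n < nremaining and not lt(keys[n], keys[n - 1]):
        n += 1
    if n == nremaining:
        return n
    # Here keys[n] is strictly less than keys[n - 1].
    if n > 1:
        if lt(keys[0], keys[n - 1]):
            return n          # the prefix increased somewhere: ascending run
        reverse(0, n)         # all-equal prefix of a descending run
    n += 1

    neq = 0  # the most recent neq + 1 elements are all equal
    while n < nremaining:
        if lt(keys[n], keys[n - 1]):
            if neq:           # a sub-run of equal elements just ended
                reverse(n - (neq + 1), neq + 1)
                neq = 0
        elif lt(keys[n - 1], keys[n]):
            break             # descending run is over
        else:
            neq += 1          # not x < y and not y < x: treat as equal
        n += 1
    if neq:
        reverse(n - (neq + 1), neq + 1)
    reverse(0, n)             # transform to an ascending run

    # The reversed run may be extendable by an adjacent ascending run.
    while n < nremaining and not lt(keys[n], keys[n - 1]):
        n += 1
    return n


example = [3, 2, 1, 3, 4, 5, 0]
run_length = count_run(example)
print(run_length, example)  # 6 [1, 2, 3, 3, 4, 5, 0]
```

Reversing each all-equal sub-run before the final whole-run reversal means those sub-runs are reversed twice, which is why their original order survives. Passing something like `lt=lambda a, b: a[0] < b[0]` with `(value, index)` pairs gives a quick stability check in the spirit of `test_small_stability`.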
