Skip to content

Commit 0bd2fd7

Browse files
committed
Initial stab at implementing Stefan Pochmann's spiffy new minrun scheme.
1 parent 60181f4 commit 0bd2fd7

File tree

3 files changed

+137
-41
lines changed

3 files changed

+137
-41
lines changed

Misc/ACKS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,7 @@ Jean-François Piéronne
14801480
Oleg Plakhotnyuk
14811481
Anatoliy Platonov
14821482
Marcel Plch
1483+
Stefan Pochmann
14831484
Kirill Podoprigora
14841485
Remi Pointel
14851486
Jon Poler

Objects/listobject.c

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,10 +1684,7 @@ sortslice_advance(sortslice *slice, Py_ssize_t n)
16841684
/* Avoid malloc for small temp arrays. */
16851685
#define MERGESTATE_TEMP_SIZE 256
16861686

1687-
/* The largest value of minrun. This must be a power of 2, and >= 1, so that
1688-
* the compute_minrun() algorithm guarantees to return a result no larger than
1689-
* this,
1690-
*/
1687+
/* The largest value of minrun. This must be a power of 2, and >= 1 */
16911688
#define MAX_MINRUN 64
16921689
#if ((MAX_MINRUN) < 1) || ((MAX_MINRUN) & ((MAX_MINRUN) - 1))
16931690
#error "MAX_MINRUN must be a power of 2, and >= 1"
@@ -1748,6 +1745,12 @@ struct s_MergeState {
17481745
* of tuples. It may be set to safe_object_compare, but the idea is that hopefully
17491746
* we can assume more, and use one of the special-case compares. */
17501747
int (*tuple_elem_compare)(PyObject *, PyObject *, MergeState *);
1748+
1749+
/* Variables used for minrun computation. The "ideal" minrun length is
1750+
* the infinite precision listlen / 2**e, which is represented as the
1751+
* mathematical value of mr_int + mr_frac / 2**e.
1752+
*/
1753+
Py_ssize_t mr_int, mr_frac, mr_current_frac, mr_e, mr_mask;
17511754
};
17521755

17531756
/* binarysort is the best method for sorting small arrays: it does few
@@ -2209,6 +2212,16 @@ merge_init(MergeState *ms, Py_ssize_t list_size, int has_keyfunc,
22092212
ms->min_gallop = MIN_GALLOP;
22102213
ms->listlen = list_size;
22112214
ms->basekeys = lo->keys;
2215+
2216+
ms->mr_int = list_size;
2217+
ms->mr_e = 0;
2218+
while (ms->mr_int >= MAX_MINRUN) {
2219+
ms->mr_int >>= 1;
2220+
++ms->mr_e;
2221+
}
2222+
/* Shift in Py_ssize_t, not int: mr_e is not bounded by the bit width of int. */
ms->mr_mask = ((Py_ssize_t)1 << ms->mr_e) - 1;
2223+
ms->mr_frac = list_size & ms->mr_mask;
2224+
ms->mr_current_frac = 0;
22122225
}
22132226

22142227
/* Free all the temp memory owned by the MergeState. This must be called
@@ -2686,27 +2699,15 @@ merge_force_collapse(MergeState *ms)
26862699
return 0;
26872700
}
26882701

2689-
/* Compute a good value for the minimum run length; natural runs shorter
2690-
* than this are boosted artificially via binary insertion.
2691-
*
2692-
* If n < MAX_MINRUN return n (it's too small to bother with fancy stuff).
2693-
* Else if n is an exact power of 2, return MAX_MINRUN / 2.
2694-
* Else return an int k, MAX_MINRUN / 2 <= k <= MAX_MINRUN, such that n/k is
2695-
* close to, but strictly less than, an exact power of 2.
2696-
*
2697-
* See listsort.txt for more info.
2698-
*/
2699-
static Py_ssize_t
2700-
merge_compute_minrun(Py_ssize_t n)
2702+
/* Return the next minrun value to use. See listsort.txt. */
2703+
static inline Py_ssize_t
2704+
minrun_next(MergeState *ms)
27012705
{
2702-
Py_ssize_t r = 0; /* becomes 1 if any 1 bits are shifted off */
2703-
2704-
assert(n >= 0);
2705-
while (n >= MAX_MINRUN) {
2706-
r |= n & 1;
2707-
n >>= 1;
2708-
}
2709-
return n + r;
2706+
ms->mr_current_frac += ms->mr_frac;
2707+
assert(ms->mr_current_frac >> ms->mr_e <= 1);
2708+
Py_ssize_t result = ms->mr_int + (ms->mr_current_frac >> ms->mr_e);
2709+
ms->mr_current_frac &= ms->mr_mask;
2710+
return result;
27102711
}
27112712

27122713
/* Here we define custom comparison functions to optimize for the cases one commonly
@@ -3074,7 +3075,6 @@ list_sort_impl(PyListObject *self, PyObject *keyfunc, int reverse)
30743075
/* March over the array once, left to right, finding natural runs,
30753076
* and extending short natural runs to minrun elements.
30763077
*/
3077-
minrun = merge_compute_minrun(nremaining);
30783078
do {
30793079
Py_ssize_t n;
30803080

@@ -3083,6 +3083,7 @@ list_sort_impl(PyListObject *self, PyObject *keyfunc, int reverse)
30833083
if (n < 0)
30843084
goto fail;
30853085
/* If short, extend to min(minrun, nremaining). */
3086+
minrun = minrun_next(&ms);
30863087
if (n < minrun) {
30873088
const Py_ssize_t force = nremaining <= minrun ?
30883089
nremaining : minrun;

Objects/listsort.txt

Lines changed: 110 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ result. This has two primary good effects:
270270

271271
Computing minrun
272272
----------------
273-
If N < MAX_MINRUN, minrun is N. IOW, binary insertion sort is used for the
274-
whole array then; it's hard to beat that given the overheads of trying
273+
If N < MAX_MINRUN, minrun is N. IOW, binary insertion sort is used for the
274+
whole array then; it's hard to beat that given the overheads of trying
275275
something fancier (see note BINSORT).
276276

277277
When N is a power of 2, testing on random data showed that minrun values of
@@ -288,7 +288,6 @@ that 32 isn't a good choice for the general case! Consider N=2112:
288288

289289
>>> divmod(2112, 32)
290290
(66, 0)
291-
>>>
292291

293292
If the data is randomly ordered, we're very likely to end up with 66 runs
294293
each of length 32. The first 64 of these trigger a sequence of perfectly
@@ -301,22 +300,40 @@ to get 64 elements into place).
301300
If we take minrun=33 in this case, then we're very likely to end up with 64
302301
runs each of length 33, and then all merges are perfectly balanced. Better!
303302

304-
What we want to avoid is picking minrun such that in
303+
The original code used a cheap heuristic to pick a minrun that avoided the
304+
very worst cases of imbalance for the final merge, but "pretty bad" cases
305+
still existed.
305306

306-
q, r = divmod(N, minrun)
307+
In 2025, Stefan Pochmann found a much better approach, based on letting minrun
308+
vary a bit from one run to the next. Under his scheme, at _all_ levels of the
309+
merge tree:
307310

308-
q is a power of 2 and r>0 (then the last merge only gets r elements into
309-
place, and r < minrun is small compared to N), or q a little larger than a
310-
power of 2 regardless of r (then we've got a case similar to "2112", again
311-
leaving too little work for the last merge to do).
311+
- The number of runs is a power of 2.
312+
- At most two different run lengths appear.
313+
- When two do appear, the smaller is one less than the larger.
314+
- The lengths of run pairs merged never differ by more than one.
312315

313-
Instead we pick a minrun in range(MAX_MINRUN / 2, MAX_MINRUN + 1) such that
314-
N/minrun is exactly a power of 2, or if that isn't possible, is close to, but
315-
strictly less than, a power of 2. This is easier to do than it may sound:
316-
take the first log2(MAX_MINRUN) bits of N, and add 1 if any of the remaining
317-
bits are set. In fact, that rule covers every case in this section, including
318-
small N and exact powers of 2; merge_compute_minrun() is a deceptively simple
319-
function.
316+
So, in all respects, as perfectly balanced as possible.
317+
318+
For the 2112 case, that also keeps minrun at 33, but we were lucky there
319+
that 2112 is a power of 2 times 33. The new approach doesn't rely on luck.
320+
321+
The basic idea is to conceive of the ideal run length as being a real number
322+
rather than just an integer. For an array of length `n`, let `e` be the
323+
smallest nonnegative int such that n/2**e < MAX_MINRUN. Then mr = n/2**e is the ideal
324+
run length, and obviously mr * 2**e is n, so there are exactly 2**e runs.
325+
326+
Of course runs can't have a fractional length, so we start the i'th (zero-
327+
based) run at index int(mr * i), for i in range(2**e). The differences between
328+
adjacent starting indices are the run lengths, and it's left as an exercise
329+
for the reader to show that they have the nice properties listed above. See
330+
note MINRUN CODE for an executable Python implementation to help make it all
331+
concrete.
332+
333+
The code doesn't actually compute the starting indices, or use floats. Instead
334+
mr is represented as a pair of integers such that the infinite precision mr is
335+
equal to mr_int + mr_frac / 2**e, and only the delta (run length) from one
336+
index to the next is computed.
320337

321338

322339
The Merge Pattern
@@ -820,3 +837,80 @@ partially mitigated by pre-scanning the data to determine whether the data is
820837
homogeneous with respect to type. If so, it is sometimes possible to
821838
substitute faster type-specific comparisons for the slower, generic
822839
PyObject_RichCompareBool.
840+
841+
MINRUN CODE
842+
from itertools import accumulate
843+
try:
844+
from itertools import batched
845+
except ImportError:
846+
from itertools import islice
847+
def batched(xs, k):
848+
it = iter(xs)
849+
while chunk := tuple(islice(it, k)):
850+
yield chunk
851+
852+
MAX_MINRUN = 64
853+
854+
def gen_minruns(n):
855+
# mr_int = minrun's integral part
856+
# mr_frac = minrun's fractional part with mr_e bits and
857+
# mask mr_mask
858+
mr_int = n
859+
mr_e = 0
860+
while mr_int >= MAX_MINRUN:
861+
mr_int >>= 1
862+
mr_e += 1
863+
mr_mask = (1 << mr_e) - 1
864+
mr_frac = n & mr_mask
865+
866+
mr_current_frac = 0
867+
while True:
868+
mr_current_frac += mr_frac
869+
assert mr_current_frac >> mr_e <= 1
870+
yield mr_int + (mr_current_frac >> mr_e)
871+
mr_current_frac &= mr_mask
872+
873+
def chew(n, show=False):
874+
if n < 1:
875+
return
876+
877+
sizes = []
878+
tot = 0
879+
for size in gen_minruns(n):
880+
sizes.append(size)
881+
tot += size
882+
if tot >= n:
883+
break
884+
assert tot == n
885+
print(n, len(sizes))
886+
887+
small, large = 32, 64
888+
while len(sizes) > 1:
889+
assert not len(sizes) & 1
890+
assert len(sizes).bit_count() == 1 # i.e., power of 2
891+
assert sum(sizes) == n
892+
assert min(sizes) >= min(n, small)
893+
assert max(sizes) <= large
894+
895+
d = set(sizes)
896+
assert len(d) <= 2
897+
if len(d) == 2:
898+
lo, hi = sorted(d)
899+
assert lo + 1 == hi
900+
901+
mr = n / len(sizes)
902+
for i, s in enumerate(accumulate(sizes, initial=0)):
903+
assert int(mr * i) == s
904+
905+
newsizes = []
906+
for a, b in batched(sizes, 2):
907+
assert abs(a - b) <= 1
908+
newsizes.append(a + b)
909+
sizes = newsizes
910+
small = large  # lower bound on run lengths at the next merge level
911+
large *= 2
912+
913+
assert sizes[0] == n
914+
915+
for n in range(2_000_001):
916+
chew(n)

0 commit comments

Comments
 (0)