Description
Feature or enhancement
Has this already been discussed elsewhere?
No response given
Links to previous discussion of this feature:
https://discuss.python.org/t/optimize-heapq-merge-with-for-break-else-pattern/32931?u=pochmann
(@serhiy-storchaka said to do this on GitHub instead)
Proposal:
@rhettinger just mentioned a willingness to optimize heapq.merge:
How messy and convoluted are we willing to make the code to save a few cycles? In some places like heapq.merge and random.__init_subclass__, the answer is that we go quite far. Elsewhere, we aim for simplicity.
I propose optimizing heapq.merge with what I call the for-break(-else) pattern: to get only the next element of an iterator, loop over it and leave the loop with an unconditional break (the else clause handles exhaustion). In my experience this is significantly faster than calling next() or __next__, and depending on the case it can also lead to simpler code.
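Here is a minimal sketch of the pattern in isolation. The helper names first_via_next, first_via_for_break and refill are hypothetical and not part of heapq. The last helper shows a detail the proposal relies on: a for-loop target can be any assignment target, including a list slot such as s[0], so the next value is written straight into the heap entry.

```python
def first_via_next(it, default=None):
    """Return the next element of *it* using next() and StopIteration."""
    try:
        return next(it)
    except StopIteration:
        return default


def first_via_for_break(it, default=None):
    """Return the next element of *it* using the for-break-else pattern."""
    for value in it:
        break  # unconditional break: consume exactly one element
    else:
        value = default  # loop body never ran, so *it* was already exhausted
    return value


def refill(entry, it):
    """Store the next element of *it* in entry[0]; return False if exhausted.

    The loop target is a list slot, as in the proposal's `for s[0] in it:`.
    """
    for entry[0] in it:
        break
    else:
        return False  # nothing assigned; entry is left unchanged
    return True


it = iter([10, 20, 30])
first = first_via_for_break(it)  # 10
second = next(it)                # 20: the for loop consumed only one element
entry = [None, 0]
ok = refill(entry, it)           # True, entry becomes [30, 0]
done = refill(entry, it)         # False, entry unchanged
print(first, second, ok, done, entry)  # 10 20 True False [30, 0]
```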
Benchmark for merging three sorted lists of 10,000 elements each (mean ± stdev of the 5 fastest of 100 runs; full script below):
6.94 ± 0.05 ms merge_proposal
7.31 ± 0.07 ms merge
Python: 3.11.4 (main, Jun 24 2023, 10:18:04) [GCC 13.1.1 20230429]
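These numbers time the whole merge. To check how much of the gap comes from the element-fetch step alone, a micro-benchmark along these lines could be used. It is a sketch: drain_next and drain_for_break are hypothetical names, and no results are claimed here because they depend on the machine and Python version.

```python
from timeit import timeit


def drain_next(it):
    """Consume *it* one element per step via its bound __next__ (current style)."""
    nxt = it.__next__
    count = 0
    try:
        while True:
            nxt()
            count += 1
    except StopIteration:
        pass
    return count


def drain_for_break(it):
    """Consume *it* one element per step via for-break-else (proposed style)."""
    count = 0
    while True:
        for _ in it:
            count += 1
            break  # take exactly one element, then leave the for loop
        else:
            return count  # for loop found nothing: iterator exhausted


data = list(range(10 ** 5))
for f in (drain_next, drain_for_break):
    # Best of 10 single runs, to reduce noise from other processes.
    best = min(timeit(lambda: f(iter(data)), number=1) for _ in range(10))
    print(f'{best * 1e3:6.2f} ms  {f.__name__}')
```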
Here's a comparison with the current implementation:
Initialization: put the first element of each non-empty iterator into the heap (note that I store the iterator itself instead of its __next__):
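Storing the iterator instead of its bound __next__ is safe for the same reason the current code is. Heap entries are lists, which compare element by element. Because order is unique per input (and direction is just ±1), any tie on value is settled at index 1, so the third item is never compared. A small check, assuming entries shaped like the proposal's [value, order, iterator]:

```python
import heapq

a, b = iter([5, 6]), iter([5, 7])

# Equal values: the tie is settled by the unique order field,
# so the iterators in the third slot are never compared.
h = [[5, 1, b], [5, 0, a]]
heapq.heapify(h)
print(h[0][1])  # 0

# If order were not unique, list comparison would reach the iterators,
# which do not support ordering.
raised = False
try:
    [5, 0, a] < [5, 0, b]
except TypeError:
    raised = True
print(raised)  # True
```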
Current:
```python
for order, it in enumerate(map(iter, iterables)):
    try:
        next = it.__next__
        h_append([next(), order * direction, next])
    except StopIteration:
        pass
```
Proposal:
```python
for order, it in enumerate(map(iter, iterables)):
    for value in it:
        h_append([value, order * direction, it])
        break
```
Merging while multiple iterators remain:
Current:
```python
while len(h) > 1:
    try:
        while True:
            value, order, next = s = h[0]
            yield value
            s[0] = next()           # raises StopIteration when exhausted
            _heapreplace(h, s)      # restore heap condition
    except StopIteration:
        _heappop(h)                 # remove empty iterator
```
Proposal:
```python
while len(h) > 1:
    while True:
        value, order, it = s = h[0]
        yield value
        for s[0] in it:
            _heapreplace(h, s)      # restore heap condition
            break
        else:
            _heappop(h)             # remove empty iterator
            break
```
End when only one iterator remains:
Current:
```python
# fast case when only a single iterator remains
value, order, next = h[0]
yield value
yield from next.__self__
```
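The current code keeps the bound method it.__next__ in each heap entry, so at the end it recovers the underlying iterator through the method's __self__ attribute. A quick check of that relationship for a list iterator and for a generator (squares is a hypothetical example generator). Storing the iterator directly, as in the proposal, removes this indirection.

```python
def squares(n):
    for i in range(n):
        yield i * i


results = []
for it in (iter([1, 2, 3]), squares(4)):
    nxt = it.__next__          # what the current merge() keeps in each heap entry
    assert nxt.__self__ is it  # the bound method still references its iterator
    first = nxt()
    rest = list(nxt.__self__)  # drains the rest, like `yield from next.__self__`
    results.append((first, rest))

print(results)  # [(1, [2, 3]), (0, [1, 4, 9])]
```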
Proposal:
```python
# fast case when only a single iterator remains
value, order, it = h[0]
yield value
yield from it
```
Benchmark script:
```python
import random
import sys
from collections import deque
from heapq import *
# The private max-heap helpers are not exported by `from heapq import *`;
# import them explicitly so that reverse=True also works (CPython 3.11 names).
from heapq import _heapify_max, _heappop_max, _heapreplace_max
from statistics import mean, stdev
from timeit import timeit


def merge_proposal(*iterables, key=None, reverse=False):
    '''Merge multiple sorted inputs into a single sorted output.

    Similar to sorted(itertools.chain(*iterables)) but returns a generator,
    does not pull the data into memory all at once, and assumes that each of
    the input streams is already sorted (smallest to largest).

    >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
    [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]

    If *key* is not None, applies a key function to each element to determine
    its sort order.

    >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len))
    ['dog', 'cat', 'fish', 'horse', 'kangaroo']

    '''
    h = []
    h_append = h.append

    if reverse:
        _heapify = _heapify_max
        _heappop = _heappop_max
        _heapreplace = _heapreplace_max
        direction = -1
    else:
        _heapify = heapify
        _heappop = heappop
        _heapreplace = heapreplace
        direction = 1

    if key is None:
        for order, it in enumerate(map(iter, iterables)):
            for value in it:  # take only the first element, if any
                h_append([value, order * direction, it])
                break
        _heapify(h)
        while len(h) > 1:
            while True:
                value, order, it = s = h[0]
                yield value
                for s[0] in it:
                    _heapreplace(h, s)  # restore heap condition
                    break
                else:
                    _heappop(h)  # remove empty iterator
                    break
        if h:
            # fast case when only a single iterator remains
            value, order, it = h[0]
            yield value
            yield from it
        return

    # Omitted the code for non-None key case


funcs = merge, merge_proposal

n = 10 ** 4
iterables = [
    sorted(random.choices(range(n), k=n))
    for _ in range(3)
]

# Sanity check: both functions must produce the same output.
expect = list(merge(*iterables))
for f in funcs:
    result = list(f(*iterables))
    print(result == expect, f.__name__)

times = {f: [] for f in funcs}


def stats(f):
    # Mean and stdev of the 5 fastest runs, in milliseconds.
    ts = [t * 1e3 for t in sorted(times[f])[:5]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '


for _ in range(100):
    for f in funcs:
        # deque(..., 0) consumes the generator without storing its output.
        t = timeit(lambda: deque(f(*iterables), 0), number=1)
        times[f].append(t)

for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)

print('Python:', sys.version)
```
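The benchmark omits the key != None branch. As a rough sketch of how that branch might apply the same pattern, here is a hypothetical merge_key_sketch. It assumes ascending order only (reverse is not handled) and uses heap entries of the form [key(value), order, value, iterator], so the original value is kept alongside its key. It is an illustration, not the proposed implementation.

```python
from heapq import heapify, heappop, heapreplace


def merge_key_sketch(*iterables, key):
    """Hypothetical key-function merge using the for-break(-else) pattern.

    Ascending order only; reverse=True is not handled in this sketch.
    Heap entries are [key(value), order, value, iterator].
    """
    h = []
    for order, it in enumerate(map(iter, iterables)):
        for value in it:  # first element of each non-empty input
            h.append([key(value), order, value, it])
            break
    heapify(h)
    while len(h) > 1:
        while True:
            key_value, order, value, it = s = h[0]
            yield value
            for value in it:  # advance the winning input by one element
                s[0] = key(value)
                s[2] = value
                heapreplace(h, s)  # restore heap condition
                break
            else:
                heappop(h)  # remove empty iterator
                break
    if h:
        # fast case when only a single iterator remains
        key_value, order, value, it = h[0]
        yield value
        yield from it


print(list(merge_key_sketch(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)))
# ['dog', 'cat', 'fish', 'horse', 'kangaroo']
```

Ties on the key are broken by input order, which keeps the merge stable in the same way as the key=None branch.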