|
| 1 | +from typing import ( |
| 2 | + Generic, |
| 3 | + AsyncIterator, |
| 4 | + Tuple, |
| 5 | + List, |
| 6 | + Optional, |
| 7 | + Callable, |
| 8 | + Any, |
| 9 | + overload, |
| 10 | + Awaitable, |
| 11 | +) |
| 12 | +import heapq as _heapq |
| 13 | + |
| 14 | +from .builtins import enumerate as a_enumerate, zip as a_zip |
| 15 | +from ._core import aiter, awaitify, ScopedIter, borrow |
| 16 | +from ._typing import AnyIterable, LT, T, SupportsLT |
| 17 | + |
| 18 | + |
| 19 | +class _KeyIter(Generic[LT]): |
| 20 | + __slots__ = ("head", "tail", "reverse", "head_key", "key") |
| 21 | + |
| 22 | + @overload |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + head: T, |
| 26 | + tail: AsyncIterator[T], |
| 27 | + reverse: bool, |
| 28 | + head_key: LT, |
| 29 | + key: Callable[[T], Awaitable[LT]], |
| 30 | + ): |
| 31 | + pass |
| 32 | + |
| 33 | + @overload |
| 34 | + def __init__( |
| 35 | + self, head: LT, tail: AsyncIterator[LT], reverse: bool, head_key: LT, key: None |
| 36 | + ): |
| 37 | + pass |
| 38 | + |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + head: Any, |
| 42 | + tail: AsyncIterator[Any], |
| 43 | + reverse: bool, |
| 44 | + head_key: LT, |
| 45 | + key: Any, |
| 46 | + ): |
| 47 | + self.head = head |
| 48 | + self.head_key = head_key |
| 49 | + self.tail = tail |
| 50 | + self.key = key |
| 51 | + self.reverse = reverse |
| 52 | + |
| 53 | + @overload |
| 54 | + @classmethod |
| 55 | + def from_iters( |
| 56 | + cls, |
| 57 | + iterables: Tuple[AnyIterable[T], ...], |
| 58 | + reverse: bool, |
| 59 | + key: Callable[[T], Awaitable[LT]], |
| 60 | + ) -> "AsyncIterator[_KeyIter[LT]]": |
| 61 | + pass |
| 62 | + |
| 63 | + @overload |
| 64 | + @classmethod |
| 65 | + def from_iters( |
| 66 | + cls, iterables: Tuple[AnyIterable[LT], ...], reverse: bool, key: None |
| 67 | + ) -> "AsyncIterator[_KeyIter[LT]]": |
| 68 | + pass |
| 69 | + |
| 70 | + @classmethod |
| 71 | + async def from_iters( |
| 72 | + cls, |
| 73 | + iterables: Tuple[AnyIterable[Any], ...], |
| 74 | + reverse: bool, |
| 75 | + key: Optional[Callable[[Any], Any]], |
| 76 | + ) -> "AsyncIterator[_KeyIter[Any]]": |
| 77 | + for iterable in iterables: |
| 78 | + iterator = aiter(iterable) |
| 79 | + try: |
| 80 | + head = await iterator.__anext__() |
| 81 | + except StopAsyncIteration: |
| 82 | + pass |
| 83 | + else: |
| 84 | + head_key = await key(head) if key is not None else head |
| 85 | + yield cls(head, iterator, reverse, head_key, key) |
| 86 | + |
| 87 | + async def pull_head(self) -> bool: |
| 88 | + """ |
| 89 | + Pull the next ``head`` element from the iterator and signal success |
| 90 | + """ |
| 91 | + try: |
| 92 | + self.head = head = await self.tail.__anext__() |
| 93 | + except StopAsyncIteration: |
| 94 | + return False |
| 95 | + else: |
| 96 | + self.head_key = await self.key(head) if self.key is not None else head |
| 97 | + return True |
| 98 | + |
| 99 | + def __lt__(self, other: "_KeyIter[LT]") -> bool: |
| 100 | + return self.reverse ^ (self.head_key < other.head_key) |
| 101 | + |
| 102 | + |
| 103 | +@overload |
| 104 | +def merge( |
| 105 | + *iterables: AnyIterable[LT], key: None = ..., reverse: bool = ... |
| 106 | +) -> AsyncIterator[LT]: |
| 107 | + pass |
| 108 | + |
| 109 | + |
| 110 | +@overload |
| 111 | +def merge( |
| 112 | + *iterables: AnyIterable[T], |
| 113 | + key: Callable[[T], Awaitable[LT]] = ..., |
| 114 | + reverse: bool = ... |
| 115 | +) -> AsyncIterator[T]: |
| 116 | + pass |
| 117 | + |
| 118 | + |
| 119 | +@overload |
| 120 | +def merge( |
| 121 | + *iterables: AnyIterable[T], key: Callable[[T], LT] = ..., reverse: bool = ... |
| 122 | +) -> AsyncIterator[T]: |
| 123 | + pass |
| 124 | + |
| 125 | + |
| 126 | +async def merge( |
| 127 | + *iterables: AnyIterable[Any], |
| 128 | + key: Optional[Callable[[Any], Any]] = None, |
| 129 | + reverse: bool = False |
| 130 | +) -> AsyncIterator[Any]: |
| 131 | + """ |
| 132 | + Merge all pre-sorted (async) ``iterables`` into a single sorted iterator |
| 133 | +
|
| 134 | + This works similar to ``sorted(chain(*iterables), key=key, reverse=reverse)`` but |
| 135 | + operates lazily: at any moment only one item of each iterable is stored for the |
| 136 | + comparison. This allows merging streams of pre-sorted items, such as timestamped |
| 137 | + records from multiple sources. |
| 138 | +
|
| 139 | + The optional ``key`` argument specifies a one-argument (async) callable, which |
| 140 | + provides a substitute for determining the sort order of each item. |
| 141 | + The special value and default :py:data:`None` represents the identity function, |
| 142 | + comparing items directly. |
| 143 | +
|
| 144 | + The default sort order is ascending, that is items with ``a < b`` imply ``a`` |
| 145 | + is yielded before ``b``. Use ``reverse=True`` for descending sort order. |
| 146 | + The ``iterables`` must be pre-sorted in the same order. |
| 147 | + """ |
| 148 | + a_key = awaitify(key) if key is not None else None |
| 149 | + # sortable iterators with (reverse) position to ensure stable sort for ties |
| 150 | + iter_heap: List[Tuple[_KeyIter[Any], int]] = [ |
| 151 | + (itr, idx if not reverse else -idx) |
| 152 | + async for idx, itr in a_enumerate( |
| 153 | + _KeyIter.from_iters(iterables, reverse, a_key) |
| 154 | + ) |
| 155 | + ] |
| 156 | + try: |
| 157 | + _heapq.heapify(iter_heap) |
| 158 | + # there are at least two iterators that need merging |
| 159 | + while len(iter_heap) > 1: |
| 160 | + while True: |
| 161 | + itr, idx = iter_heap[0] |
| 162 | + yield itr.head |
| 163 | + if await itr.pull_head(): |
| 164 | + _heapq.heapreplace(iter_heap, (itr, idx)) |
| 165 | + else: |
| 166 | + _heapq.heappop(iter_heap) |
| 167 | + break |
| 168 | + # there is only one iterator left, no need for merging |
| 169 | + if iter_heap: |
| 170 | + itr, idx = iter_heap[0] |
| 171 | + yield itr.head |
| 172 | + async for item in itr.tail: |
| 173 | + yield item |
| 174 | + finally: |
| 175 | + for itr, _ in iter_heap: |
| 176 | + if hasattr(itr.tail, "aclose"): |
| 177 | + await itr.tail.aclose() # type: ignore |
| 178 | + |
| 179 | + |
class ReverseLT(Generic[LT]):
    """
    Helper to reverse ``a < b`` ordering

    Wraps a ``key`` so that ``ReverseLT(a) < ReverseLT(b)`` holds exactly when
    ``b < a``. Two wrappers are equal if neither key is smaller than the other;
    only ``<`` is ever used on the keys. Defining equality allows tuples that
    start with a wrapper to fall back to their following elements on ties.
    Since ``__eq__`` is defined without ``__hash__``, instances are not hashable.
    """

    __slots__ = ("key",)

    def __init__(self, key: LT):
        self.key = key

    def __lt__(self, other: "ReverseLT[LT]") -> bool:
        return other.key < self.key

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ReverseLT):
            return NotImplemented
        return not (self.key < other.key or other.key < self.key)
| 190 | + |
| 191 | + |
# Python's heapq provides a *min*-heap
# When finding the n largest items, heapq tracks the *minimum* item still large enough.
# In other words, during search we maintain opposite sort order than what is requested.
# We turn the min-heap into a max-sort in the end.
async def _largest(
    iterable: AsyncIterator[T],
    n: int,
    key: Callable[[T], Awaitable[LT]],
    reverse: bool,
) -> List[T]:
    """
    Collect the ``n`` best items of ``iterable``, sorted from best to worst

    With ``reverse=False`` the best items are those with the largest ``key``.
    With ``reverse=True`` every key is wrapped in :py:class:`ReverseLT`, which
    makes the items with the smallest ``key`` the best ones.

    :param iterable: source of the items to select from
    :param n: maximum number of items to return
    :param key: async callable providing the sort key of an item
    :param reverse: whether to select the smallest instead of the largest keys
    :return: at most ``n`` items, best first; empty if no item was fetched
    """
    ordered: Callable[[SupportsLT], SupportsLT] = (
        ReverseLT if reverse else lambda x: x  # type: ignore
    )
    async with ScopedIter(iterable) as iterator:
        # Each entry carries a *descending* insertion index to solve ties.
        # Among equal keys the latest item has the smallest index: it sits at the
        # top of the min-heap to be replaced first and ends up last in the final
        # descending sort. Earlier items thus win ties in both search directions.
        n_heap = [
            (ordered(await key(item)), -index, item)
            async for index, item in a_zip(range(n), borrow(iterator))
        ]
        if not n_heap:
            return []
        _heapq.heapify(n_heap)
        worst_key = n_heap[0][0]
        next_index = -n
        async for item in iterator:
            item_key = ordered(await key(item))
            # only strictly better items replace the current worst one
            if worst_key < item_key:
                _heapq.heapreplace(n_heap, (item_key, next_index, item))
                worst_key = n_heap[0][0]
                next_index -= 1
    n_heap.sort(reverse=True)
    return [item for _, _, item in n_heap]
| 225 | + |
| 226 | + |
| 227 | +async def _identity(x: T) -> T: |
| 228 | + return x |
| 229 | + |
| 230 | + |
async def nlargest(
    iterable: AsyncIterator[T],
    n: int,
    key: Optional[Callable[[Any], Any]] = None,
) -> List[T]:
    """
    Return a sorted list of the ``n`` largest elements from the (async) iterable

    The optional ``key`` argument specifies a one-argument (async) callable, which
    provides a substitute for determining the sort order of each item.
    The special value and default :py:data:`None` represents the identity function,
    comparing items directly.

    The result is equivalent to ``sorted(iterable, key=key, reverse=True)[:n]``,
    but ``iterable`` is consumed lazily and items are discarded eagerly.
    """
    # ``key`` may be sync or async; the search always awaits its key function
    a_key: Callable[[Any], Awaitable[Any]] = (
        awaitify(key) if key is not None else _identity  # type: ignore
    )
    return await _largest(iterable=iterable, n=n, key=a_key, reverse=False)
| 251 | + |
| 252 | + |
async def nsmallest(
    iterable: AsyncIterator[T],
    n: int,
    key: Optional[Callable[[Any], Any]] = None,
) -> List[T]:
    """
    Return a sorted list of the ``n`` smallest elements from the (async) iterable

    Provides the reverse functionality to :py:func:`~.nlargest`.

    The optional ``key`` argument specifies a one-argument (async) callable, which
    provides a substitute for determining the sort order of each item.
    The special value and default :py:data:`None` represents the identity function,
    comparing items directly.

    The elements are provided in ascending order, similar to
    ``sorted(iterable, key=key)[:n]``.
    """
    # ``key`` may be sync or async; the search always awaits its key function
    a_key: Callable[[Any], Awaitable[Any]] = (
        awaitify(key) if key is not None else _identity  # type: ignore
    )
    return await _largest(iterable=iterable, n=n, key=a_key, reverse=True)
0 commit comments