Skip to content

Commit d2ae8dd

Browse files
committed
Fix issue with uneccessary iteration
1 parent 66cf276 commit d2ae8dd

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/async_utils/_graphs.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections.abc import Generator, Iterator
18+
import heapq
1819

1920
from . import _typings as t
2021

@@ -25,12 +26,12 @@
2526
class CanHashAndCompareLT(typing.Protocol):
2627
def __hash__(self) -> int: ...
2728

28-
def __lt__(self, other: typing.Self, /) -> bool: ...
29+
def __lt__(self, other: typing.Any, /) -> bool: ...
2930

3031
class CanHashAndCompareGT(typing.Protocol):
3132
def __hash__(self) -> int: ...
3233

33-
def __gt__(self, other: typing.Self, /) -> bool: ...
34+
def __gt__(self, other: typing.Any, /) -> bool: ...
3435

3536
else:
3637

@@ -62,7 +63,7 @@ def cycle(self) -> list[T]:
6263
return self.args[0]
6364

6465

65-
class NodeData[T]:
66+
class NodeData[T: CanHashAndCompare]:
6667
__slots__ = ("dependants", "ndependencies", "node")
6768

6869
def __init__(self, node: T) -> None:
@@ -74,6 +75,9 @@ def __init_subclass__(cls) -> t.Never:
7475
msg = "Don't subclass this"
7576
raise RuntimeError(msg)
7677

78+
def __lt__(self, other: t.Self) -> bool: # falback for tuple sorting
79+
return True
80+
7781
__final__ = True
7882

7983

@@ -174,14 +178,16 @@ def __iter__(self) -> Generator[T, None, None]:
174178
return self.__iter()
175179

176180
def __iter(self) -> Generator[T, None, None]:
177-
while ready := [
178-
i.node for i in self._nodemap.values() if not i.ndependencies
179-
]:
180-
next_node = min(ready)
181-
self._nodemap[next_node].ndependencies = -1
181+
ready = [(n, i) for n, i in self._nodemap.items() if not i.ndependencies]
182+
heapq.heapify(ready)
183+
while ready:
184+
next_node, info = heapq.heappop(ready)
185+
info.ndependencies = -1
182186

183187
yield next_node
184188

185-
for dep in self._nodemap[next_node].dependants:
189+
for dep in info.dependants:
186190
dep_info = self._nodemap[dep]
187191
dep_info.ndependencies -= 1
192+
if not dep_info.ndependencies:
193+
heapq.heappush(ready, (dep, dep_info))

0 commit comments

Comments
 (0)