Skip to content

Commit cecf2cd

Browse files
cached_property reacts gracefully to hasattr (#100)
* add helper to repeatedly await a coroutine * black formatting * CacheProperty may be awaited 0 or more times
1 parent 204f558 commit cecf2cd

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

asyncstdlib/functools.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Generic,
77
Generator,
88
Optional,
9+
Coroutine,
910
overload,
1011
)
1112

@@ -57,6 +58,25 @@ def __repr__(self) -> str:
5758
return f"{self.__class__.__name__}({self.value!r})"
5859

5960

61+
class _RepeatableCoroutine(Generic[T]):
62+
"""Helper to ``await`` a coroutine also more or less than just once"""
63+
64+
__slots__ = ("call", "args", "kwargs")
65+
66+
def __init__(
67+
self, __call: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any
68+
):
69+
self.call = __call
70+
self.args = args
71+
self.kwargs = kwargs
72+
73+
def __await__(self) -> Generator[Any, Any, T]:
74+
return self.call(*self.args, **self.kwargs).__await__()
75+
76+
def __repr__(self) -> str:
77+
return f"<{self.__class__.__name__} object {self.call.__name__} at {id(self)}>"
78+
79+
6080
@public_module(__name__, "cached_property")
6181
class CachedProperty(Generic[T]):
6282
"""
@@ -129,7 +149,7 @@ def __get__(
129149
) -> Union["CachedProperty[T]", Awaitable[T]]:
130150
if instance is None:
131151
return self
132-
return self._get_attribute(instance)
152+
return _RepeatableCoroutine(self._get_attribute, instance)
133153

134154
async def _get_attribute(self, instance: object) -> T:
135155
value = await self.__wrapped__(instance)

asyncstdlib/heapq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def merge(
111111
def merge(
112112
*iterables: AnyIterable[T],
113113
key: Callable[[T], Awaitable[LT]] = ...,
114-
reverse: bool = ...
114+
reverse: bool = ...,
115115
) -> AsyncIterator[T]:
116116
pass
117117

@@ -126,7 +126,7 @@ def merge(
126126
async def merge(
127127
*iterables: AnyIterable[Any],
128128
key: Optional[Callable[[Any], Any]] = None,
129-
reverse: bool = False
129+
reverse: bool = False,
130130
) -> AsyncIterator[Any]:
131131
"""
132132
Merge all pre-sorted (async) ``iterables`` into a single sorted iterator

unittests/test_itertools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ async def values(gby):
316316
@pytest.mark.parametrize("view", [keys, values])
317317
@sync
318318
async def test_groupby(iterable, key, view):
319-
320319
for akey in (key, awaitify(key)):
321320
assert await view(a.groupby(iterable)) == await view(
322321
itertools.groupby(iterable)

0 commit comments

Comments
 (0)