Skip to content

Commit ae08cc0

Browse files
committed
Partial partial() support
1 parent 00c51c5 commit ae08cc0

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

async_lru/__init__.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@
4242
_T = TypeVar("_T")
4343
_R = TypeVar("_R")
4444
_P = ParamSpec("_P")
45-
_Coro = Coroutine[Any, Any, _R]
46-
_CB = Callable[_P, _Coro[_R]]
47-
_CBP = Union[_CB[_P, _R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
4845

4946

5047
@final
@@ -71,7 +68,7 @@ def cancel(self) -> None:
7168
class _LRUCacheWrapper(Generic[_P, _R]):
7269
def __init__(
7370
self,
74-
fn: _CB[_P, _R],
71+
fn: Callable[_P, Coroutine[Any, Any, _R]],
7572
maxsize: Optional[int],
7673
typed: bool,
7774
ttl: Optional[float],
@@ -299,8 +296,8 @@ def _make_wrapper(
299296
maxsize: Optional[int],
300297
typed: bool,
301298
ttl: Optional[float] = None,
302-
) -> Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]]:
303-
def wrapper(fn: _CBP[_P, _R]) -> _LRUCacheWrapper[_P, _R]:
299+
) -> Callable[[Callable[_P, Coroutine[Any, Any, _R]]], _LRUCacheWrapper[_P, _R]]:
300+
def wrapper(fn: Callable[_P, Coroutine[Any, Any, _R]]) -> _LRUCacheWrapper[_P, _R]:
304301
origin = fn
305302

306303
while isinstance(origin, (partial, partialmethod)):
@@ -313,7 +310,7 @@ def wrapper(fn: _CBP[_P, _R]) -> _LRUCacheWrapper[_P, _R]:
313310
if hasattr(fn, "_make_unbound_method"):
314311
fn = fn._make_unbound_method()
315312

316-
return _LRUCacheWrapper(cast(_CB[_P, _R], fn), maxsize, typed, ttl)
313+
return _LRUCacheWrapper(fn, maxsize, typed, ttl)
317314

318315
return wrapper
319316

@@ -324,32 +321,33 @@ def alru_cache(
324321
typed: bool = False,
325322
*,
326323
ttl: Optional[float] = None,
327-
) -> Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]]:
324+
) -> Callable[[Callable[_P, Coroutine[Any, Any, _R]]], _LRUCacheWrapper[_P, _R]]:
328325
...
329326

330327

331328
@overload
332-
def alru_cache(
333-
maxsize: _CBP[_P, _R],
329+
def alru_cache( # type: ignore[misc]
330+
maxsize: Callable[_P, Coroutine[Any, Any, _R]],
334331
/,
335332
) -> _LRUCacheWrapper[_P, _R]:
336333
...
337334

338335

339336
def alru_cache(
340-
maxsize: Union[Optional[int], _CBP[_P, _R]] = 128,
337+
maxsize: Union[Optional[int], Callable[_P, Coroutine[Any, Any, _R]]] = 128,
341338
typed: bool = False,
342339
*,
343340
ttl: Optional[float] = None,
344341
) -> Union[
345-
Callable[[_CBP[_P, _R]], _LRUCacheWrapper[_P, _R]], _LRUCacheWrapper[_P, _R]
342+
Callable[[Callable[_P, Coroutine[Any, Any, _R]]], _LRUCacheWrapper[_P, _R]], _LRUCacheWrapper[_P, _R]
346343
]:
347344
if maxsize is None or isinstance(maxsize, int):
348345
return _make_wrapper(maxsize, typed, ttl)
349346
else:
350347
fn = maxsize
351348

352-
if callable(fn) or hasattr(fn, "_make_unbound_method"):
349+
# partialmethod is not callable() at runtime.
350+
if callable(fn) or hasattr(fn, "_make_unbound_method"): # type: ignore[unreachable]
353351
return _make_wrapper(128, False, None)(fn)
354352

355353
raise NotImplementedError(f"{fn!r} decorating is not supported")

tests/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ async def coro(val: int) -> int:
7878
assert await coro_wrapped2() == 2
7979

8080

81+
async def test_alru_cache_partial_typing() -> None:
82+
"""Test that mypy produces call-arg errors correctly."""
83+
84+
async def coro(val: int) -> int:
85+
return val
86+
87+
coro_wrapped1 = alru_cache(coro)
88+
with pytest.raises(ValueError):
89+
await coro_wrapped1(1, 1) # type: ignore[call-arg]
90+
91+
coro_wrapped2 = alru_cache(partial(coro, 2))
92+
with pytest.raises(ValueError):
93+
await coro_wrapped2(4) == 2 # type: ignore[call-arg]
94+
95+
8196
async def test_alru_cache_await_same_result_async(
8297
check_lru: Callable[..., None]
8398
) -> None:

tests/test_exception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def coro(val: int) -> None:
3333
reason="Memory leak is not fixed for PyPy3.9",
3434
condition=sys.implementation.name == "pypy",
3535
)
36-
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None:
36+
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None: # type: ignore[misc]
3737
class CustomClass:
3838
...
3939

0 commit comments

Comments
 (0)