diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index d5612a86..4db11f11 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -7,9 +7,6 @@ on: - '[0-9].[0-9]+' # matches to backport branches, e.g. 3.6 tags: [ 'v*' ] pull_request: - branches: - - master - - '[0-9].[0-9]+' jobs: diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 00000000..71ff0a8f --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,24 @@ +[mypy] +files = async_lru, tests +check_untyped_defs = True +follow_imports_for_stubs = True +disallow_any_decorated = True +disallow_any_generics = True +disallow_any_unimported = True +disallow_incomplete_defs = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_untyped_decorators = True +disallow_untyped_defs = True +enable_error_code = ignore-without-code, possibly-undefined, redundant-expr, redundant-self, truthy-bool, truthy-iterable, unused-awaitable +implicit_reexport = False +no_implicit_optional = True +pretty = True +show_column_numbers = True +show_error_codes = True +strict_equality = True +warn_incomplete_stub = True +warn_redundant_casts = True +warn_return_any = True +warn_unreachable = True +warn_unused_ignores = True diff --git a/async_lru/__init__.py b/async_lru/__init__.py index 447e9cdb..a800ac84 100644 --- a/async_lru/__init__.py +++ b/async_lru/__init__.py @@ -4,7 +4,6 @@ import sys from functools import _CacheInfo, _make_key, partial, partialmethod from typing import ( - Any, Callable, Coroutine, Generic, @@ -16,12 +15,17 @@ TypedDict, TypeVar, Union, - cast, final, overload, ) +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + + if sys.version_info >= (3, 11): from typing import Self else: @@ -38,9 +42,7 @@ _T = TypeVar("_T") _R = TypeVar("_R") -_Coro = Coroutine[Any, Any, _R] -_CB = Callable[..., _Coro[_R]] -_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"] +_P = ParamSpec("_P") @final @@ -64,10 +66,10 @@ def cancel(self) -> None: @final -class _LRUCacheWrapper(Generic[_R]): +class _LRUCacheWrapper(Generic[_P, _R]): def __init__( self, - fn: _CB[_R], + fn: Callable[_P, Coroutine[object, object, _R]], maxsize: Optional[int], typed: bool, ttl: Optional[float], @@ -110,7 +112,7 @@ def __init__( self.__misses = 0 self.__tasks: Set["asyncio.Task[_R]"] = set() - def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool: + def cache_invalidate(self, /, *args: _P.args, **kwargs: _P.kwargs) -> bool: key = _make_key(args, kwargs, self.__typed) cache_item = self.__cache.pop(key, None) @@ -192,7 +194,7 @@ def _task_done_callback( fut.set_result(task.result()) - async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: + async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R: if self.__closed: raise RuntimeError(f"alru_cache is closed for {self}") @@ -211,7 +213,7 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: fut = loop.create_future() coro = self.__wrapped__(*fn_args, **fn_kwargs) - task: asyncio.Task[_R] = loop.create_task(coro) + task = loop.create_task(coro) self.__tasks.add(task) task.add_done_callback(partial(self._task_done_callback, fut, key)) @@ -224,9 +226,19 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: self._cache_miss(key) return await asyncio.shield(fut) + @overload + def __get__(self, instance: _T, owner: None) -> Self: + ... + + @overload + def __get__( + self, instance: _T, owner: Type[_T] + ) -> "_LRUCacheWrapperInstanceMethod[_P, _R, _T]": + ... + def __get__( self, instance: _T, owner: Optional[Type[_T]] - ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]: + ) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_P, _R, _T]"]: if owner is None: return self else: @@ -234,10 +246,10 @@ def __get__( @final -class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]): +class _LRUCacheWrapperInstanceMethod(Generic[_P, _R, _T]): def __init__( self, - wrapper: _LRUCacheWrapper[_R], + wrapper: _LRUCacheWrapper[_P, _R], instance: _T, ) -> None: try: @@ -272,7 +284,7 @@ def __init__( self.__instance = instance self.__wrapper = wrapper - def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool: + def cache_invalidate(self, /, *args: _P.args, **kwargs: _P.kwargs) -> bool: return self.__wrapper.cache_invalidate(self.__instance, *args, **kwargs) def cache_clear(self) -> None: @@ -289,16 +301,18 @@ def cache_info(self) -> _CacheInfo: def cache_parameters(self) -> _CacheParameters: return self.__wrapper.cache_parameters() - async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R: - return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) + async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R: + return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) # type: ignore[arg-type] def _make_wrapper( maxsize: Optional[int], typed: bool, ttl: Optional[float] = None, -) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]: - def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]: +) -> Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]]: + def wrapper( + fn: Callable[_P, Coroutine[object, object, _R]] + ) -> _LRUCacheWrapper[_P, _R]: origin = fn while isinstance(origin, (partial, partialmethod)): @@ -311,7 +325,7 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]: if hasattr(fn, "_make_unbound_method"): fn = fn._make_unbound_method() - wrapper = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl) + wrapper = _LRUCacheWrapper(fn, maxsize, typed, ttl) if sys.version_info >= (3, 12): wrapper = inspect.markcoroutinefunction(wrapper) return wrapper @@ -325,30 +339,34 @@ def alru_cache( typed: bool = False, *, ttl: Optional[float] = None, -) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]: +) -> Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]]: ... @overload def alru_cache( - maxsize: _CBP[_R], + maxsize: Callable[_P, Coroutine[object, object, _R]], /, -) -> _LRUCacheWrapper[_R]: +) -> _LRUCacheWrapper[_P, _R]: ... def alru_cache( - maxsize: Union[Optional[int], _CBP[_R]] = 128, + maxsize: Union[Optional[int], Callable[_P, Coroutine[object, object, _R]]] = 128, typed: bool = False, *, ttl: Optional[float] = None, -) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]: +) -> Union[ + Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]], + _LRUCacheWrapper[_P, _R], +]: if maxsize is None or isinstance(maxsize, int): return _make_wrapper(maxsize, typed, ttl) else: - fn = cast(_CB[_R], maxsize) + fn = maxsize - if callable(fn) or hasattr(fn, "_make_unbound_method"): + # partialmethod is not callable() at runtime. + if callable(fn) or hasattr(fn, "_make_unbound_method"): # type: ignore[unreachable] return _make_wrapper(128, False, None)(fn) raise NotImplementedError(f"{fn!r} decorating is not supported") diff --git a/setup.cfg b/setup.cfg index 95fde295..d812e394 100644 --- a/setup.cfg +++ b/setup.cfg @@ -82,8 +82,3 @@ junit_family=xunit2 asyncio_mode=auto timeout=15 xfail_strict = true - -[mypy] -strict=True -pretty=True -packages=async_lru, tests diff --git a/tests/conftest.py b/tests/conftest.py index 36147d4d..b0456ceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,26 @@ +import sys from functools import _CacheInfo -from typing import Callable +from typing import Callable, TypeVar import pytest -from async_lru import _R, _LRUCacheWrapper +from async_lru import _LRUCacheWrapper + + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + + +_T = TypeVar("_T") +_P = ParamSpec("_P") @pytest.fixture def check_lru() -> Callable[..., None]: def _check_lru( - wrapped: _LRUCacheWrapper[_R], + wrapped: _LRUCacheWrapper[_P, _T], *, hits: int, misses: int, diff --git a/tests/test_basic.py b/tests/test_basic.py index ef234f0e..90cc1fb9 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -88,6 +88,21 @@ async def coro(val: int) -> int: assert await coro_wrapped2() == 2 +async def test_alru_cache_partial_typing() -> None: + """Test that mypy produces call-arg errors correctly.""" + + async def coro(val: int) -> int: + return val + + coro_wrapped1 = alru_cache(coro) + with pytest.raises(TypeError): + await coro_wrapped1(1, 1) # type: ignore[call-arg] + + coro_wrapped2 = alru_cache(partial(coro, 2)) + with pytest.raises(TypeError): + await coro_wrapped2(4) == 2 # type: ignore[call-arg] + + async def test_alru_cache_await_same_result_async( check_lru: Callable[..., None] ) -> None: diff --git a/tests/test_exception.py b/tests/test_exception.py index 054ea3ae..51b19aef 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -33,7 +33,7 @@ async def coro(val: int) -> None: reason="Memory leak is not fixed for PyPy3.9", condition=sys.implementation.name == "pypy", ) -async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None: +async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None: # type: ignore[misc] class CustomClass: ...