Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
67d0c99
Use ParamSpec for wrapped signatures
Dreamsorcerer Jul 25, 2023
79cc7e0
Lint
Dreamsorcerer Jul 25, 2023
30c5aa6
Lint
Dreamsorcerer Jul 25, 2023
fa601e1
Merge branch 'master' into paramspec
Dreamsorcerer Oct 26, 2023
58cf7f0
Update setup.cfg
Dreamsorcerer Oct 26, 2023
0e56d88
Update __init__.py
Dreamsorcerer Oct 26, 2023
c8156f9
Merge branch 'master' into paramspec
Dreamsorcerer May 23, 2024
f2e28be
Merge branch 'master' into paramspec
Dreamsorcerer Aug 2, 2024
6577c02
Merge branch 'master' into paramspec
Dreamsorcerer Nov 14, 2024
a07fe6b
Merge branch 'master' into paramspec
asvetlov Jan 2, 2025
00c51c5
Merge branch 'master' into paramspec
asvetlov Jan 24, 2025
f358ce3
Run CI on any PR
Dreamsorcerer Feb 1, 2025
ae08cc0
Partial partial() support
Dreamsorcerer Feb 2, 2025
0f20348
Merge branch 'paramspec' of github.com:aio-libs/async-lru into paramspec
Dreamsorcerer Feb 2, 2025
bf9442d
Update __init__.py
Dreamsorcerer Feb 2, 2025
8067aa7
Update test_basic.py
Dreamsorcerer Feb 2, 2025
66d8591
Update __init__.py
Dreamsorcerer Feb 2, 2025
1d8ffe9
Tweak
Dreamsorcerer Feb 2, 2025
1bfa687
Update __init__.py
Dreamsorcerer Feb 2, 2025
5ffe115
Use overload
Dreamsorcerer Feb 2, 2025
51d93cd
Merge branch 'paramspec' of github.com:aio-libs/async-lru into paramspec
Dreamsorcerer Feb 2, 2025
0950988
Formatting
Dreamsorcerer Feb 2, 2025
bd6a5ee
Merge branch 'master' into paramspec
Dreamsorcerer Jul 11, 2025
21b1eee
Update tests/conftest.py
Dreamsorcerer Jul 11, 2025
e5dea23
Merge branch 'master' into paramspec
Dreamsorcerer Jul 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ on:
- '[0-9].[0-9]+' # matches to backport branches, e.g. 3.6
tags: [ 'v*' ]
pull_request:
branches:
- master
- '[0-9].[0-9]+'


jobs:
Expand Down
24 changes: 24 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[mypy]
files = async_lru, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
disallow_any_generics = True
disallow_any_unimported = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
enable_error_code = ignore-without-code, possibly-undefined, redundant-expr, redundant-self, truthy-bool, truthy-iterable, unused-awaitable
implicit_reexport = False
no_implicit_optional = True
pretty = True
show_column_numbers = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_return_any = True
warn_unreachable = True
warn_unused_ignores = True
70 changes: 44 additions & 26 deletions async_lru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sys
from functools import _CacheInfo, _make_key, partial, partialmethod
from typing import (
Any,
Callable,
Coroutine,
Generic,
Expand All @@ -16,12 +15,17 @@
TypedDict,
TypeVar,
Union,
cast,
final,
overload,
)


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


if sys.version_info >= (3, 11):
from typing import Self
else:
Expand All @@ -38,9 +42,7 @@

_T = TypeVar("_T")
_R = TypeVar("_R")
_Coro = Coroutine[Any, Any, _R]
_CB = Callable[..., _Coro[_R]]
_CBP = Union[_CB[_R], "partial[_Coro[_R]]", "partialmethod[_Coro[_R]]"]
_P = ParamSpec("_P")


@final
Expand All @@ -64,10 +66,10 @@ def cancel(self) -> None:


@final
class _LRUCacheWrapper(Generic[_R]):
class _LRUCacheWrapper(Generic[_P, _R]):
def __init__(
self,
fn: _CB[_R],
fn: Callable[_P, Coroutine[object, object, _R]],
maxsize: Optional[int],
typed: bool,
ttl: Optional[float],
Expand Down Expand Up @@ -110,7 +112,7 @@ def __init__(
self.__misses = 0
self.__tasks: Set["asyncio.Task[_R]"] = set()

def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
def cache_invalidate(self, /, *args: _P.args, **kwargs: _P.kwargs) -> bool:
key = _make_key(args, kwargs, self.__typed)

cache_item = self.__cache.pop(key, None)
Expand Down Expand Up @@ -192,7 +194,7 @@ def _task_done_callback(

fut.set_result(task.result())

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
if self.__closed:
raise RuntimeError(f"alru_cache is closed for {self}")

Expand All @@ -211,7 +213,7 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:

fut = loop.create_future()
coro = self.__wrapped__(*fn_args, **fn_kwargs)
task: asyncio.Task[_R] = loop.create_task(coro)
task = loop.create_task(coro)
self.__tasks.add(task)
task.add_done_callback(partial(self._task_done_callback, fut, key))

Expand All @@ -224,20 +226,30 @@ async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
self._cache_miss(key)
return await asyncio.shield(fut)

@overload
def __get__(self, instance: _T, owner: None) -> Self:
...

@overload
def __get__(
self, instance: _T, owner: Type[_T]
) -> "_LRUCacheWrapperInstanceMethod[_P, _R, _T]":
...

def __get__(
self, instance: _T, owner: Optional[Type[_T]]
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_R, _T]"]:
) -> Union[Self, "_LRUCacheWrapperInstanceMethod[_P, _R, _T]"]:
if owner is None:
return self
else:
return _LRUCacheWrapperInstanceMethod(self, instance)


@final
class _LRUCacheWrapperInstanceMethod(Generic[_R, _T]):
class _LRUCacheWrapperInstanceMethod(Generic[_P, _R, _T]):
def __init__(
self,
wrapper: _LRUCacheWrapper[_R],
wrapper: _LRUCacheWrapper[_P, _R],
instance: _T,
) -> None:
try:
Expand Down Expand Up @@ -272,7 +284,7 @@ def __init__(
self.__instance = instance
self.__wrapper = wrapper

def cache_invalidate(self, /, *args: Hashable, **kwargs: Any) -> bool:
def cache_invalidate(self, /, *args: _P.args, **kwargs: _P.kwargs) -> bool:
return self.__wrapper.cache_invalidate(self.__instance, *args, **kwargs)

def cache_clear(self) -> None:
Expand All @@ -289,16 +301,18 @@ def cache_info(self) -> _CacheInfo:
def cache_parameters(self) -> _CacheParameters:
return self.__wrapper.cache_parameters()

async def __call__(self, /, *fn_args: Any, **fn_kwargs: Any) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs)
async def __call__(self, /, *fn_args: _P.args, **fn_kwargs: _P.kwargs) -> _R:
return await self.__wrapper(self.__instance, *fn_args, **fn_kwargs) # type: ignore[arg-type]


def _make_wrapper(
maxsize: Optional[int],
typed: bool,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
) -> Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]]:
def wrapper(
fn: Callable[_P, Coroutine[object, object, _R]]
) -> _LRUCacheWrapper[_P, _R]:
origin = fn

while isinstance(origin, (partial, partialmethod)):
Expand All @@ -311,7 +325,7 @@ def wrapper(fn: _CBP[_R]) -> _LRUCacheWrapper[_R]:
if hasattr(fn, "_make_unbound_method"):
fn = fn._make_unbound_method()

wrapper = _LRUCacheWrapper(cast(_CB[_R], fn), maxsize, typed, ttl)
wrapper = _LRUCacheWrapper(fn, maxsize, typed, ttl)
if sys.version_info >= (3, 12):
wrapper = inspect.markcoroutinefunction(wrapper)
return wrapper
Expand All @@ -325,30 +339,34 @@ def alru_cache(
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Callable[[_CBP[_R]], _LRUCacheWrapper[_R]]:
) -> Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]]:
...


@overload
def alru_cache(
maxsize: _CBP[_R],
maxsize: Callable[_P, Coroutine[object, object, _R]],
/,
) -> _LRUCacheWrapper[_R]:
) -> _LRUCacheWrapper[_P, _R]:
...


def alru_cache(
maxsize: Union[Optional[int], _CBP[_R]] = 128,
maxsize: Union[Optional[int], Callable[_P, Coroutine[object, object, _R]]] = 128,
typed: bool = False,
*,
ttl: Optional[float] = None,
) -> Union[Callable[[_CBP[_R]], _LRUCacheWrapper[_R]], _LRUCacheWrapper[_R]]:
) -> Union[
Callable[[Callable[_P, Coroutine[object, object, _R]]], _LRUCacheWrapper[_P, _R]],
_LRUCacheWrapper[_P, _R],
]:
if maxsize is None or isinstance(maxsize, int):
return _make_wrapper(maxsize, typed, ttl)
else:
fn = cast(_CB[_R], maxsize)
fn = maxsize

if callable(fn) or hasattr(fn, "_make_unbound_method"):
# partialmethod is not callable() at runtime.
if callable(fn) or hasattr(fn, "_make_unbound_method"): # type: ignore[unreachable]
return _make_wrapper(128, False, None)(fn)

raise NotImplementedError(f"{fn!r} decorating is not supported")
5 changes: 0 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,3 @@ junit_family=xunit2
asyncio_mode=auto
timeout=15
xfail_strict = true

[mypy]
strict=True
pretty=True
packages=async_lru, tests
19 changes: 15 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import sys
from functools import _CacheInfo
from typing import Callable
from typing import Callable, TypeVar

import pytest

from async_lru import _R, _LRUCacheWrapper
from async_lru import _LRUCacheWrapper


if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


_T = TypeVar("_T")
_P = ParamSpec("_P")


@pytest.fixture
def check_lru() -> Callable[..., None]:
def check_lru() -> Callable[..., None]: # type: ignore[misc]
def _check_lru(
wrapped: _LRUCacheWrapper[_R],
wrapped: _LRUCacheWrapper[_P, _T],
*,
hits: int,
misses: int,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,21 @@ async def coro(val: int) -> int:
assert await coro_wrapped2() == 2


async def test_alru_cache_partial_typing() -> None:
"""Test that mypy produces call-arg errors correctly."""

async def coro(val: int) -> int:
return val

coro_wrapped1 = alru_cache(coro)
with pytest.raises(TypeError):
await coro_wrapped1(1, 1) # type: ignore[call-arg]

coro_wrapped2 = alru_cache(partial(coro, 2))
with pytest.raises(TypeError):
await coro_wrapped2(4) == 2 # type: ignore[call-arg]


async def test_alru_cache_await_same_result_async(
check_lru: Callable[..., None]
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def coro(val: int) -> None:
reason="Memory leak is not fixed for PyPy3.9",
condition=sys.implementation.name == "pypy",
)
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None:
async def test_alru_exception_reference_cleanup(check_lru: Callable[..., None]) -> None: # type: ignore[misc]
class CustomClass:
...

Expand Down
Loading