Skip to content

Commit e6331c7

Browse files
functools typing (#72)
* typing for internal modules * typing for functools
1 parent 820a544 commit e6331c7

File tree

5 files changed

+112
-43
lines changed

5 files changed

+112
-43
lines changed

asyncstdlib/_lrucache.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
Tuple,
1414
Dict,
1515
Union,
16+
Hashable,
17+
overload,
18+
cast,
1619
)
1720
from functools import update_wrapper
1821
from collections import OrderedDict
1922

20-
from ._typing import Protocol, TypedDict, C
23+
from ._typing import Protocol, TypedDict, AC
2124
from ._utility import public_module
2225

2326

@@ -56,16 +59,16 @@ class CacheParameters(TypedDict):
5659

5760

5861
@public_module("asyncstdlib.functools")
59-
class LRUAsyncCallable(Protocol[C]):
62+
class LRUAsyncCallable(Protocol[AC]):
6063
"""
6164
:py:class:`~typing.Protocol` of a LRU cache wrapping a callable to an awaitable
6265
"""
6366

6467
#: The callable wrapped by this cache
65-
__wrapped__: C
68+
__wrapped__: AC
6669

6770
#: Get the result of ``await __wrapped__(...)`` from the cache or evaluation
68-
__call__: C
71+
__call__: AC
6972

7073
def cache_parameters(self) -> CacheParameters:
7174
"""Get the parameters of the cache"""
@@ -80,8 +83,22 @@ def cache_clear(self) -> None:
8083
"""Evict all call argument patterns and their results from the cache"""
8184

8285

86+
@overload
87+
def lru_cache(maxsize: AC, typed: bool = ...) -> LRUAsyncCallable[AC]:
88+
...
89+
90+
91+
@overload
92+
def lru_cache(
93+
maxsize: Optional[int] = ..., typed: bool = ...
94+
) -> Callable[[AC], LRUAsyncCallable[AC]]:
95+
...
96+
97+
8398
@public_module("asyncstdlib.functools")
84-
def lru_cache(maxsize: Optional[Union[int, Callable]] = 128, typed: bool = False):
99+
def lru_cache(
100+
maxsize: Optional[Union[int, AC]] = 128, typed: bool = False
101+
) -> Union[LRUAsyncCallable[AC], Callable[[AC], LRUAsyncCallable[AC]]]:
85102
"""
86103
Least Recently Used cache for async functions
87104
@@ -127,14 +144,16 @@ def lru_cache(maxsize: Optional[Union[int, Callable]] = 128, typed: bool = False
127144
maxsize = 0 if maxsize < 0 else maxsize
128145
elif callable(maxsize):
129146
# used as function decorator, first arg is the function to be wrapped
130-
fast_wrapper = _bounded_lru(function=maxsize, maxsize=128, typed=typed)
147+
fast_wrapper = _bounded_lru(
148+
function=cast(AC, maxsize), maxsize=128, typed=typed
149+
)
131150
return update_wrapper(fast_wrapper, maxsize)
132151
elif maxsize is not None:
133152
raise TypeError(
134153
"first argument to 'lru_cache' must be an int, a callable or None"
135154
)
136155

137-
def lru_decorator(function: C) -> LRUAsyncCallable[C]:
156+
def lru_decorator(function: AC) -> LRUAsyncCallable[AC]:
138157
assert not callable(maxsize)
139158
if maxsize is None:
140159
wrapper = _unbound_lru(function=function, typed=typed)
@@ -148,28 +167,44 @@ def lru_decorator(function: C) -> LRUAsyncCallable[C]:
148167

149168

150169
class CallKey:
170+
"""Representation of a call suitable as a ``dict`` key and for equality testing"""
171+
151172
__slots__ = "_hash", "values"
152173

153-
def __init__(self, values):
174+
def __init__(self, values: Tuple[Hashable, ...]):
175+
# we may need the hash very often so caching helps
154176
self._hash = hash(values)
155177
self.values = values
156178

157-
def __hash__(self):
179+
def __hash__(self) -> int:
158180
return self._hash
159181

160-
def __eq__(self, other):
161-
return type(self) is type(other) and self.values == other.values
182+
def __eq__(self, other: object) -> bool:
183+
return type(self) is type(other) and self.values == other.values # type: ignore
162184

163-
# DEVIATION: fast_types tuple vs. set contains is faster +40%/pypy vs. +20%/cpython
185+
# DEVIATION: fast_types tuple vs. set contains is faster +40%/pypy and +20%/cpython
164186
@classmethod
165187
def from_call(
166188
cls,
167-
args: Tuple,
168-
kwds: Dict,
189+
args: Tuple[Hashable, ...],
190+
kwds: Dict[str, Hashable],
169191
typed: bool,
170-
fast_types=(int, str),
171-
kwarg_sentinel=object(),
192+
fast_types: Tuple[type, ...] = (int, str),
193+
kwarg_sentinel: Hashable = object(),
172194
) -> "Union[CallKey, int, str]":
195+
"""
196+
Create a key based on call arguments
197+
198+
:param args: positional call arguments
199+
:param kwds: keyword call arguments
200+
:param typed: whether to compare arguments by strict type as well
201+
:param fast_types: types which do not need wrapping
202+
:param kwarg_sentinel: internal marker, stick with default
203+
:return: representation of the call arguments
204+
205+
The `fast_types` and `kwarg_sentinel` primarily are arguments to make them
206+
pre-initialised locals for speed; their defaults should be optimal already.
207+
"""
173208
key = args if not kwds else (*args, kwarg_sentinel, *kwds.items())
174209
if typed:
175210
key += (
@@ -178,16 +213,16 @@ def from_call(
178213
else (*map(type, args), *map(type, kwds.values()))
179214
)
180215
elif len(key) == 1 and type(key[0]) in fast_types:
181-
return key[0]
216+
return key[0] # type: ignore
182217
return cls(key)
183218

184219

185-
def _empty_lru(function: C, typed: bool) -> LRUAsyncCallable[C]:
220+
def _empty_lru(function: AC, typed: bool) -> LRUAsyncCallable[AC]:
186221
"""Wrap the async ``function`` in an async LRU cache without any capacity"""
187222
# cache statistics
188223
misses = 0
189224

190-
async def wrapper(*args, **kwargs):
225+
async def wrapper(*args: Hashable, **kwargs: Hashable) -> Any:
191226
nonlocal misses
192227
misses += 1
193228
return await function(*args, **kwargs)
@@ -198,7 +233,7 @@ def cache_parameters() -> CacheParameters:
198233
def cache_info() -> CacheInfo:
199234
return CacheInfo(0, misses, 0, 0)
200235

201-
def cache_clear():
236+
def cache_clear() -> None:
202237
nonlocal misses
203238
misses = 0
204239

@@ -208,7 +243,7 @@ def cache_clear():
208243
return wrapper # type: ignore
209244

210245

211-
def _unbound_lru(function: C, typed: bool) -> LRUAsyncCallable[C]:
246+
def _unbound_lru(function: AC, typed: bool) -> LRUAsyncCallable[AC]:
212247
"""Wrap the async ``function`` in an async LRU cache with infinite capacity"""
213248
# local lookup
214249
make_key = CallKey.from_call
@@ -218,7 +253,7 @@ def _unbound_lru(function: C, typed: bool) -> LRUAsyncCallable[C]:
218253
# cache content
219254
cache: Dict[Union[CallKey, int, str], Any] = {}
220255

221-
async def wrapper(*args, **kwargs):
256+
async def wrapper(*args: Hashable, **kwargs: Hashable) -> Any:
222257
nonlocal hits, misses
223258
key = make_key(args, kwargs, typed=typed)
224259
try:
@@ -241,7 +276,7 @@ def cache_parameters() -> CacheParameters:
241276
def cache_info() -> CacheInfo:
242277
return CacheInfo(hits, misses, None, len(cache))
243278

244-
def cache_clear():
279+
def cache_clear() -> None:
245280
nonlocal hits, misses
246281
misses = 0
247282
hits = 0
@@ -253,7 +288,7 @@ def cache_clear():
253288
return wrapper # type: ignore
254289

255290

256-
def _bounded_lru(function: C, typed: bool, maxsize: int) -> LRUAsyncCallable[C]:
291+
def _bounded_lru(function: AC, typed: bool, maxsize: int) -> LRUAsyncCallable[AC]:
257292
"""Wrap the async ``function`` in an async LRU cache with fixed capacity"""
258293
# local lookup
259294
make_key = CallKey.from_call
@@ -264,7 +299,7 @@ def _bounded_lru(function: C, typed: bool, maxsize: int) -> LRUAsyncCallable[C]:
264299
cache: OrderedDict[Union[int, str, CallKey], Any] = OrderedDict()
265300
filled = False
266301

267-
async def wrapper(*args, **kwargs):
302+
async def wrapper(*args: Hashable, **kwargs: Hashable) -> Any:
268303
nonlocal hits, misses, filled
269304
key = make_key(args, kwargs, typed=typed)
270305
try:
@@ -298,7 +333,7 @@ def cache_parameters() -> CacheParameters:
298333
def cache_info() -> CacheInfo:
299334
return CacheInfo(hits, misses, maxsize, len(cache))
300335

301-
def cache_clear():
336+
def cache_clear() -> None:
302337
nonlocal hits, misses, filled
303338
misses = 0
304339
hits = 0

asyncstdlib/_typing.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
"async typing" definitions here.
66
"""
77
import sys
8-
from typing import TypeVar, Hashable, Union, AsyncIterable, Iterable, Callable
8+
from typing import (
9+
TypeVar,
10+
Hashable,
11+
Union,
12+
AsyncIterable,
13+
Iterable,
14+
Callable,
15+
Any,
16+
Awaitable,
17+
)
918

1019
if sys.version_info >= (3, 8):
1120
from typing import Protocol, AsyncContextManager, ContextManager, TypedDict
@@ -29,7 +38,7 @@
2938
"T4",
3039
"T5",
3140
"R",
32-
"C",
41+
"AC",
3342
"HK",
3443
"LT",
3544
"ADD",
@@ -44,7 +53,7 @@
4453
T4 = TypeVar("T4")
4554
T5 = TypeVar("T5")
4655
R = TypeVar("R", covariant=True)
47-
C = TypeVar("C", bound=Callable)
56+
AC = TypeVar("AC", bound=Callable[..., Awaitable[Any]])
4857

4958
#: Hashable Key
5059
HK = TypeVar("HK", bound=Hashable)

asyncstdlib/_utility.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypeVar, Any, Optional
1+
from typing import TypeVar, Any, Optional, Callable
22

33
from ._typing import Protocol
44

@@ -16,7 +16,9 @@ class Definition(Protocol):
1616
D = TypeVar("D", bound=Definition)
1717

1818

19-
def public_module(module_name: str, qual_name: Optional[str] = None):
19+
def public_module(
20+
module_name: str, qual_name: Optional[str] = None
21+
) -> Callable[[D], D]:
2022
"""Set the module name of a function or class"""
2123

2224
def decorator(thing: D) -> D:

asyncstdlib/functools.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1-
from typing import Callable, Awaitable, Union, Any
2-
3-
from ._typing import T, C, AnyIterable
1+
from typing import (
2+
Callable,
3+
Awaitable,
4+
Union,
5+
Any,
6+
Generic,
7+
Generator,
8+
Optional,
9+
overload,
10+
)
11+
12+
from ._typing import T, AC, AnyIterable
413
from ._core import ScopedIter, awaitify as _awaitify, Sentinel
514
from .builtins import anext
615
from ._utility import public_module
@@ -18,7 +27,7 @@
1827
]
1928

2029

21-
def cache(user_function: C) -> LRUAsyncCallable[C]:
30+
def cache(user_function: AC) -> LRUAsyncCallable[AC]:
2231
"""
2332
Simple unbounded cache, aka memoization, for async functions
2433
@@ -31,25 +40,25 @@ def cache(user_function: C) -> LRUAsyncCallable[C]:
3140
__REDUCE_SENTINEL = Sentinel("<no default>")
3241

3342

34-
class AwaitableValue:
43+
class AwaitableValue(Generic[T]):
3544
"""Helper to provide an arbitrary value in ``await``"""
3645

3746
__slots__ = ("value",)
3847

39-
def __init__(self, value):
48+
def __init__(self, value: T):
4049
self.value = value
4150

4251
# noinspection PyUnreachableCode
43-
def __await__(self):
52+
def __await__(self) -> Generator[None, None, T]:
4453
return self.value
4554
yield # type: ignore # pragma: no cover
4655

47-
def __repr__(self):
56+
def __repr__(self) -> str:
4857
return f"{self.__class__.__name__}({self.value!r})"
4958

5059

5160
@public_module(__name__, "cached_property")
52-
class CachedProperty:
61+
class CachedProperty(Generic[T]):
5362
"""
5463
Transform a method into an attribute whose value is cached
5564
@@ -95,7 +104,7 @@ def __init__(self, getter: Callable[[Any], Awaitable[T]]):
95104
self._name = getter.__name__
96105
self.__doc__ = getter.__doc__
97106

98-
def __set_name__(self, owner, name):
107+
def __set_name__(self, owner: Any, name: str) -> None:
99108
# Check whether we can store anything on the instance
100109
# Note that this is a failsafe, and might fail ugly.
101110
# People who are clever enough to avoid this heuristic
@@ -107,12 +116,22 @@ def __set_name__(self, owner, name):
107116
)
108117
self._name = name
109118

110-
def __get__(self, instance, owner):
119+
@overload
120+
def __get__(self, instance: None, owner: type) -> "CachedProperty[T]":
121+
...
122+
123+
@overload
124+
def __get__(self, instance: object, owner: Optional[type]) -> Awaitable[T]:
125+
...
126+
127+
def __get__(
128+
self, instance: Optional[object], owner: Optional[type]
129+
) -> Union["CachedProperty[T]", Awaitable[T]]:
111130
if instance is None:
112131
return self
113132
return self._get_attribute(instance)
114133

115-
async def _get_attribute(self, instance) -> T:
134+
async def _get_attribute(self, instance: object) -> T:
116135
value = await self.__wrapped__(instance)
117136
instance.__dict__[self._name] = AwaitableValue(value)
118137
return value

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ warn_unreachable = true
5151
[[tool.mypy.overrides]]
5252
module = [
5353
"asyncstdlib.asynctools",
54+
"asyncstdlib.functools",
5455
"asyncstdlib._core",
56+
"asyncstdlib._utility",
57+
"asyncstdlib._lrucache",
58+
"asyncstdlib._typing",
5559
]
5660
disallow_any_generics = true
5761
disallow_subclassing_any = true

0 commit comments

Comments
 (0)