Skip to content

Commit 9bbf838

Browse files
Full strict typing (#73)
* strict typing for all sub-packages * testing strict typing for all sub-packages
1 parent e6331c7 commit 9bbf838

File tree

6 files changed

+142
-90
lines changed

6 files changed

+142
-90
lines changed

asyncstdlib/_typing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
T4 = TypeVar("T4")
5454
T5 = TypeVar("T5")
5555
R = TypeVar("R", covariant=True)
56+
C = TypeVar("C", bound=Callable[..., Any])
5657
AC = TypeVar("AC", bound=Callable[..., Awaitable[Any]])
5758

5859
#: Hashable Key
@@ -72,7 +73,7 @@ def __lt__(self: LT, other: LT) -> bool:
7273

7374

7475
class SupportsAdd(Protocol):
75-
def __add__(self: ADD, other: ADD) -> bool:
76+
def __add__(self: ADD, other: ADD) -> ADD:
7677
raise NotImplementedError
7778

7879

asyncstdlib/builtins.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,19 @@
2727
__ANEXT_DEFAULT = Sentinel("<no default>")
2828

2929

30-
async def anext(iterator: AsyncIterator[T], default=__ANEXT_DEFAULT) -> T:
30+
@overload
31+
async def anext(iterator: AsyncIterator[T]) -> T:
32+
...
33+
34+
35+
@overload
36+
async def anext(iterator: AsyncIterator[T], default: T) -> T:
37+
...
38+
39+
40+
async def anext(
41+
iterator: AsyncIterator[T], default: Union[Sentinel, T] = __ANEXT_DEFAULT
42+
) -> T:
3143
"""
3244
Retrieve the next item from the async iterator
3345
@@ -47,7 +59,7 @@ async def anext(iterator: AsyncIterator[T], default=__ANEXT_DEFAULT) -> T:
4759
except StopAsyncIteration:
4860
if default is __ANEXT_DEFAULT:
4961
raise
50-
return default
62+
return default # type: ignore
5163

5264

5365
__ITER_DEFAULT = Sentinel("<no default>")
@@ -132,7 +144,7 @@ async def any(iterable: AnyIterable[T]) -> bool:
132144
def zip(
133145
__it1: AnyIterable[T1],
134146
*,
135-
strict=False,
147+
strict: bool = ...,
136148
) -> AsyncIterator[Tuple[T1]]:
137149
...
138150

@@ -142,7 +154,7 @@ def zip(
142154
__it1: AnyIterable[T1],
143155
__it2: AnyIterable[T2],
144156
*,
145-
strict=False,
157+
strict: bool = ...,
146158
) -> AsyncIterator[Tuple[T1, T2]]:
147159
...
148160

@@ -153,7 +165,7 @@ def zip(
153165
__it2: AnyIterable[T2],
154166
__it3: AnyIterable[T3],
155167
*,
156-
strict=False,
168+
strict: bool = ...,
157169
) -> AsyncIterator[Tuple[T1, T2, T3]]:
158170
...
159171

@@ -165,7 +177,7 @@ def zip(
165177
__it3: AnyIterable[T3],
166178
__it4: AnyIterable[T4],
167179
*,
168-
strict=False,
180+
strict: bool = ...,
169181
) -> AsyncIterator[Tuple[T1, T2, T3, T4]]:
170182
...
171183

@@ -178,7 +190,7 @@ def zip(
178190
__it4: AnyIterable[T4],
179191
__it5: AnyIterable[T5],
180192
*,
181-
strict=False,
193+
strict: bool = ...,
182194
) -> AsyncIterator[Tuple[T1, T2, T3, T4, T5]]:
183195
...
184196

@@ -191,13 +203,13 @@ def zip(
191203
__it4: AnyIterable[Any],
192204
__it5: AnyIterable[Any],
193205
*iterables: AnyIterable[Any],
194-
strict=False,
206+
strict: bool = ...,
195207
) -> AsyncIterator[Tuple[Any, ...]]:
196208
...
197209

198210

199211
async def zip(
200-
*iterables: AnyIterable[Any], strict=False
212+
*iterables: AnyIterable[Any], strict: bool = False
201213
) -> AsyncIterator[Tuple[Any, ...]]:
202214
"""
203215
Create an async iterator that aggregates elements from each of the (async) iterables
@@ -240,15 +252,22 @@ async def zip(
240252
await aclose()
241253

242254

243-
async def _zip_inner(aiters):
255+
async def _zip_inner(
256+
aiters: Tuple[AsyncIterator[T], ...]
257+
) -> AsyncIterator[Tuple[T, ...]]:
258+
"""Direct zip transposing tuple-of-iterators to iterator-of-tuples"""
244259
try:
245260
while True:
246261
yield (*[await anext(it) for it in aiters],)
247262
except StopAsyncIteration:
248263
return
249264

250265

251-
async def _zip_inner_strict(aiters):
266+
async def _zip_inner_strict(
267+
aiters: Tuple[AsyncIterator[T], ...]
268+
) -> AsyncIterator[Tuple[T, ...]]:
269+
"""Length aware zip checking that all iterators are equal length"""
270+
# track index of the last iterator we tried to anext
252271
tried = 0
253272
try:
254273
while True:
@@ -541,13 +560,12 @@ async def _min_max(
541560
:param invert: compute ``max`` if ``True`` and ``min`` otherwise
542561
"""
543562
async with ScopedIter(iterable) as item_iter:
544-
best = await anext(item_iter, default=__MIN_MAX_DEFAULT)
545-
if best is __MIN_MAX_DEFAULT:
546-
if default is __MIN_MAX_DEFAULT:
547-
name = "max" if invert else "min"
548-
raise ValueError(f"{name}() arg is an empty sequence")
549-
return default
550-
if key is None:
563+
best = await anext(item_iter, default=default)
564+
# this implies that item_iter is empty and default is __MIN_MAX_DEFAULT
565+
if best is __MIN_MAX_DEFAULT: # type: ignore
566+
name = "max" if invert else "min"
567+
raise ValueError(f"{name}() arg is an empty sequence")
568+
elif key is None:
551569
async for item in item_iter:
552570
if invert ^ (item < best):
553571
best = item
@@ -587,7 +605,9 @@ async def filter(
587605
yield item
588606

589607

590-
async def enumerate(iterable: AnyIterable[T], start=0) -> AsyncIterator[Tuple[int, T]]:
608+
async def enumerate(
609+
iterable: AnyIterable[T], start: int = 0
610+
) -> AsyncIterator[Tuple[int, T]]:
591611
"""
592612
An async iterator of running count and element in an (async) iterable
593613
@@ -685,11 +705,6 @@ async def set(iterable: Union[Iterable[T], AsyncIterable[T]] = ()) -> Set[T]:
685705
return {element async for element in aiter(iterable)}
686706

687707

688-
async def _identity(x: T) -> T:
689-
"""Asynchronous identity function, returns its argument unchanged"""
690-
return x
691-
692-
693708
@overload
694709
async def sorted(
695710
iterable: AnyIterable[LT], *, key: None = ..., reverse: bool = ...
@@ -707,7 +722,7 @@ async def sorted(
707722
async def sorted(
708723
iterable: AnyIterable[T],
709724
*,
710-
key: Optional[Callable[[T], Any]] = None,
725+
key: Optional[Callable[[T], LT]] = None,
711726
reverse: bool = False,
712727
) -> List[T]:
713728
"""
@@ -730,11 +745,15 @@ async def sorted(
730745
It is guaranteed to be worst-case O(n log n) runtime.
731746
"""
732747
if key is None:
748+
# TODO: is this a worthwhile optimisation?
733749
try:
734750
return _sync_builtins.sorted(iterable, reverse=reverse) # type: ignore
735751
except TypeError:
736-
pass
737-
key = _awaitify(key) if key is not None else _identity
738-
keyed_items = [(await key(item), item) async for item in aiter(iterable)]
739-
keyed_items.sort(key=lambda ki: ki[0], reverse=reverse)
740-
return [item for key, item in keyed_items]
752+
items = [item async for item in aiter(iterable)]
753+
items.sort(reverse=reverse)
754+
return items
755+
else:
756+
async_key = _awaitify(key)
757+
keyed_items = [(await async_key(item), item) async for item in aiter(iterable)]
758+
keyed_items.sort(key=lambda ki: ki[0], reverse=reverse)
759+
return [item for key, item in keyed_items]

asyncstdlib/contextlib.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,21 @@
1515
from functools import partial
1616
import sys
1717

18-
from ._typing import Protocol, AsyncContextManager, ContextManager, T
18+
from ._typing import Protocol, AsyncContextManager, ContextManager, T, C
1919
from ._core import awaitify
2020
from ._utility import public_module, slot_get as _slot_get
2121

2222

23+
AnyContextManager = Union[AsyncContextManager[T], ContextManager[T]]
24+
25+
2326
# typing.AsyncContextManager uses contextlib.AbstractAsyncContextManager if available,
2427
# and a custom implementation otherwise. No need to replicate it.
2528
AbstractContextManager = AsyncContextManager
2629

2730

2831
class ACloseable(Protocol):
29-
async def aclose(self):
32+
async def aclose(self) -> None:
3033
"""Asynchronously close this object"""
3134

3235

@@ -58,29 +61,31 @@ async def Context(*args, **kwargs):
5861
"""
5962

6063
@wraps(func)
61-
def helper(*args, **kwds):
64+
def helper(*args: Any, **kwds: Any) -> AsyncContextManager[T]:
6265
return _AsyncGeneratorContextManager(func, args, kwds)
6366

6467
return helper
6568

6669

67-
class _AsyncGeneratorContextManager:
68-
def __init__(self, func, args, kwds):
70+
class _AsyncGeneratorContextManager(Generic[T]):
71+
def __init__(
72+
self, func: Callable[..., AsyncGenerator[T, None]], args: Any, kwds: Any
73+
):
6974
self.gen = func(*args, **kwds)
7075
self.__doc__ = getattr(func, "__doc__", type(self).__doc__)
7176

72-
async def __aenter__(self):
77+
async def __aenter__(self) -> T:
7378
try:
7479
return await self.gen.__anext__()
7580
except StopAsyncIteration:
7681
raise RuntimeError("generator did not yield to __aenter__") from None
7782

78-
async def __aexit__(self, exc_type, exc_val, exc_tb):
83+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
7984
if exc_type is None:
8085
try:
8186
await self.gen.__anext__()
8287
except StopAsyncIteration:
83-
return
88+
return False
8489
else:
8590
raise RuntimeError("generator did not stop after __aexit__")
8691
else:
@@ -99,6 +104,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
99104
except exc_type as exc:
100105
if exc is not exc_val:
101106
raise
107+
return False
102108
else:
103109
raise RuntimeError("generator did not stop after throw() in __aexit__")
104110

@@ -134,8 +140,9 @@ def __init__(self, thing: AC):
134140
async def __aenter__(self) -> AC:
135141
return self.thing
136142

137-
async def __aexit__(self, exc_type, exc_val, exc_tb):
143+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
138144
await self.thing.aclose()
145+
return False
139146

140147

141148
closing = Closing
@@ -175,7 +182,7 @@ def __init__(self: "NullContext[None]", enter_result: None = ...) -> None:
175182
def __init__(self: "NullContext[T]", enter_result: T) -> None:
176183
...
177184

178-
def __init__(self, enter_result=None):
185+
def __init__(self, enter_result: Optional[T] = None):
179186
self.enter_result = enter_result
180187

181188
@overload
@@ -186,11 +193,11 @@ async def __aenter__(self: "NullContext[None]") -> None:
186193
async def __aenter__(self: "NullContext[T]") -> T:
187194
...
188195

189-
async def __aenter__(self):
196+
async def __aenter__(self) -> Optional[T]:
190197
return self.enter_result
191198

192-
async def __aexit__(self, exc_type, exc_val, exc_tb):
193-
pass
199+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
200+
return False
194201

195202

196203
nullcontext = NullContext
@@ -199,8 +206,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
199206
SE = TypeVar(
200207
"SE",
201208
bound=Union[
202-
AsyncContextManager,
203-
ContextManager,
209+
AsyncContextManager[Any],
210+
ContextManager[Any],
204211
Callable[[Any, BaseException, Any], Optional[bool]],
205212
Callable[[Any, BaseException, Any], Awaitable[Optional[bool]]],
206213
],
@@ -228,11 +235,13 @@ class ExitStack:
228235
There are no separate methods to distinguish async and regular arguments.
229236
"""
230237

231-
def __init__(self):
238+
def __init__(self) -> None:
232239
self._exit_callbacks: Deque[Callable[..., Awaitable[Optional[bool]]]] = deque()
233240

234241
@staticmethod
235-
async def _aexit_callback(callback, exc_type, exc_val, tb):
242+
async def _aexit_callback(
243+
callback: Callable[[], Awaitable[Any]], exc_type: Any, exc_val: Any, tb: Any
244+
) -> bool:
236245
"""Invoke a callback as if it were an ``__aexit__`` method"""
237246
await callback()
238247
return False # callbacks never suppress exceptions
@@ -298,7 +307,7 @@ def push(self, exit: SE) -> SE:
298307
self._exit_callbacks.append(aexit)
299308
return exit
300309

301-
def callback(self, callback: Callable, *args, **kwargs):
310+
def callback(self, callback: C, *args: Any, **kwargs: Any) -> C:
302311
"""
303312
Registers an arbitrary callback to be called with arguments on unwinding
304313
@@ -312,11 +321,11 @@ def callback(self, callback: Callable, *args, **kwargs):
312321
This method does not change its argument, and can be used as a context manager.
313322
"""
314323
self._exit_callbacks.append(
315-
partial(self._aexit_callback, awaitify(partial(callback, *args, **kwargs)))
324+
partial(self._aexit_callback, partial(awaitify(callback), *args, **kwargs))
316325
)
317326
return callback
318327

319-
async def enter_context(self, cm: AsyncContextManager):
328+
async def enter_context(self, cm: AnyContextManager[T]) -> T:
320329
"""
321330
Enter the supplied context manager, and register it for exit if successful
322331
@@ -353,9 +362,9 @@ async def enter_context(self, cm: AsyncContextManager):
353362
else:
354363
context_value = await _slot_get(cm, "__aenter__")()
355364
self._exit_callbacks.append(aexit)
356-
return context_value
365+
return context_value # type: ignore
357366

358-
async def aclose(self):
367+
async def aclose(self) -> None:
359368
"""
360369
Immediately unwind the context stack
361370
@@ -371,7 +380,7 @@ def _stitch_context(
371380
exception: BaseException,
372381
context: BaseException,
373382
base_context: Optional[BaseException],
374-
):
383+
) -> None:
375384
"""
376385
Emulate that `exception` was caused by an unhandled `context`
377386
@@ -392,10 +401,10 @@ def _stitch_context(
392401
# we expect it to reference
393402
exception.__context__ = context
394403

395-
async def __aenter__(self):
404+
async def __aenter__(self) -> "ExitStack":
396405
return self
397406

398-
async def __aexit__(self, exc_type, exc_val, tb):
407+
async def __aexit__(self, exc_type: Any, exc_val: Any, tb: Any) -> bool:
399408
received_exc = exc_type is not None
400409
# Even if we don't handle an exception *right now*, we may be part
401410
# of an exception handler unwinding gracefully. This is our __context__.

0 commit comments

Comments
 (0)