|
8 | 8 | import inspect |
9 | 9 | import warnings |
10 | 10 |
|
11 | | -from typing import Any, Callable, Awaitable, TypeVar, Coroutine |
| 11 | +from typing import Any, Callable, Awaitable, TypeVar, Coroutine, ParamSpec |
| 12 | +from typing_extensions import TypeIs |
12 | 13 |
|
13 | 14 | T = TypeVar("T") |
| 15 | +P = ParamSpec("P") |
14 | 16 |
|
15 | 17 |
|
16 | 18 | class _ExecutionContext: |
@@ -68,14 +70,16 @@ def run(self, coro: Coroutine[Any, Any, T]) -> T: |
68 | 70 | execution_context = _ExecutionContext() |
69 | 71 |
|
70 | 72 |
|
71 | | -def is_coroutine_fn(fn: Callable[..., Any]) -> bool: |
| 73 | +def is_coroutine_fn( |
| 74 | + fn: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]], |
| 75 | +) -> TypeIs[Callable[P, Coroutine[Any, Any, T]]]: |
72 | 76 | if isinstance(fn, (staticmethod, classmethod)): |
73 | 77 | return inspect.iscoroutinefunction(fn.__func__) |
74 | 78 | else: |
75 | 79 | return inspect.iscoroutinefunction(fn) |
76 | 80 |
|
77 | 81 |
|
78 | | -def to_async_call(fn: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: |
| 82 | +def to_async_call(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: |
79 | 83 | if is_coroutine_fn(fn): |
80 | 84 | return fn |
81 | 85 | return lambda *args, **kwargs: asyncio.to_thread(fn, *args, **kwargs) |
0 commit comments