|
1 | 1 | from __future__ import annotations as _annotations
|
2 | 2 |
|
3 | 3 | import asyncio
|
| 4 | +import functools |
| 5 | +import inspect |
4 | 6 | import time
|
5 | 7 | import uuid
|
6 |
| -from collections.abc import AsyncIterable, AsyncIterator, Iterator |
| 8 | +from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator |
7 | 9 | from contextlib import asynccontextmanager, suppress
|
8 | 10 | from dataclasses import dataclass, fields, is_dataclass
|
9 | 11 | from datetime import datetime, timezone
|
10 | 12 | from functools import partial
|
11 | 13 | from types import GenericAlias
|
12 |
| -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union |
| 14 | +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload |
13 | 15 |
|
14 | 16 | from anyio.to_thread import run_sync
|
15 | 17 | from pydantic import BaseModel, TypeAdapter
|
16 | 18 | from pydantic.json_schema import JsonSchemaValue
|
17 |
| -from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict |
| 19 | +from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict |
18 | 20 |
|
19 | 21 | from pydantic_graph._utils import AbstractSpan
|
20 | 22 |
|
@@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str:
|
302 | 304 |
|
303 | 305 | def number_to_datetime(x: int | float) -> datetime:
|
304 | 306 | return TypeAdapter(datetime).validate_python(x)
|
| 307 | + |
| 308 | + |
| 309 | +AwaitableCallable = Callable[..., Awaitable[T]] |
| 310 | + |
| 311 | + |
| 312 | +@overload |
| 313 | +def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ... |
| 314 | + |
| 315 | + |
| 316 | +@overload |
| 317 | +def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ... |
| 318 | + |
| 319 | + |
| 320 | +def is_async_callable(obj: Any) -> Any: |
| 321 | + """Correctly check if a callable is async. |
| 322 | +
|
| 323 | + This function was copied from Starlette: |
| 324 | + https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40 |
| 325 | + """ |
| 326 | + while isinstance(obj, functools.partial): |
| 327 | + obj = obj.func |
| 328 | + |
| 329 | + return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore |
0 commit comments