diff --git a/sentry_sdk/ai/monitoring.py b/sentry_sdk/ai/monitoring.py index e3f372c3ba..9dd1aa132c 100644 --- a/sentry_sdk/ai/monitoring.py +++ b/sentry_sdk/ai/monitoring.py @@ -10,7 +10,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Callable, Any + from typing import Optional, Callable, Awaitable, Any, Union, TypeVar + + F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]]) _ai_pipeline_name = ContextVar("ai_pipeline_name", default=None) @@ -26,9 +28,9 @@ def get_ai_pipeline_name(): def ai_track(description, **span_kwargs): - # type: (str, Any) -> Callable[..., Any] + # type: (str, Any) -> Callable[[F], F] def decorator(f): - # type: (Callable[..., Any]) -> Callable[..., Any] + # type: (F) -> F def sync_wrapped(*args, **kwargs): # type: (Any, Any) -> Any curr_pipeline = _ai_pipeline_name.get() @@ -88,9 +90,9 @@ async def async_wrapped(*args, **kwargs): return res if inspect.iscoroutinefunction(f): - return wraps(f)(async_wrapped) + return wraps(f)(async_wrapped) # type: ignore else: - return wraps(f)(sync_wrapped) + return wraps(f)(sync_wrapped) # type: ignore return decorator