Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions sentry_sdk/ai/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Optional, Callable, Any
from typing import Optional, Callable, Awaitable, Any, Union, TypeVar

F = TypeVar("F", bound=Union[Callable[..., Any], Callable[..., Awaitable[Any]]])

_ai_pipeline_name = ContextVar("ai_pipeline_name", default=None)

Expand All @@ -26,9 +28,9 @@ def get_ai_pipeline_name():


def ai_track(description, **span_kwargs):
# type: (str, Any) -> Callable[..., Any]
# type: (str, Any) -> Callable[[F], F]
def decorator(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
# type: (F) -> F
def sync_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
Expand Down Expand Up @@ -88,9 +90,9 @@ async def async_wrapped(*args, **kwargs):
return res

if inspect.iscoroutinefunction(f):
return wraps(f)(async_wrapped)
return wraps(f)(async_wrapped) # type: ignore
else:
return wraps(f)(sync_wrapped)
return wraps(f)(sync_wrapped) # type: ignore

return decorator

Expand Down