1010from typing import TYPE_CHECKING
1111
1212if TYPE_CHECKING :
13- from typing import Optional , Callable , Any
13+ from typing import Optional , Callable , Awaitable , Any , Union , TypeVar
14+
15+ F = TypeVar ("F" , bound = Union [Callable [..., Any ], Callable [..., Awaitable [Any ]]])
1416
1517_ai_pipeline_name = ContextVar ("ai_pipeline_name" , default = None )
1618
@@ -26,9 +28,9 @@ def get_ai_pipeline_name():
2628
2729
2830def ai_track (description , ** span_kwargs ):
29- # type: (str, Any) -> Callable[..., Any ]
31+ # type: (str, Any) -> Callable[[F], F ]
3032 def decorator (f ):
31- # type: (Callable[..., Any] ) -> Callable[..., Any]
33+ # type: (F ) -> F
3234 def sync_wrapped (* args , ** kwargs ):
3335 # type: (Any, Any) -> Any
3436 curr_pipeline = _ai_pipeline_name .get ()
@@ -88,9 +90,9 @@ async def async_wrapped(*args, **kwargs):
8890 return res
8991
9092 if inspect .iscoroutinefunction (f ):
91- return wraps (f )(async_wrapped )
93+ return wraps (f )(async_wrapped ) # type: ignore
9294 else :
93- return wraps (f )(sync_wrapped )
95+ return wraps (f )(sync_wrapped ) # type: ignore
9496
9597 return decorator
9698
0 commit comments