11from __future__ import annotations
22
3- __all__ = ("current_namespace" , "resolve_id" , "ui" , "server" , "ResolvedId" )
4-
5- from typing import TYPE_CHECKING , Callable , TypeVar
3+ import functools
4+ from typing import TYPE_CHECKING , Awaitable , Callable , TypeVar , overload
65
76from ._docstring import no_example
87from ._namespaces import (
1312 resolve_id ,
1413)
1514from ._typing_extensions import Concatenate , ParamSpec
15+ from ._utils import is_async_callable , not_is_async_callable
1616
1717if TYPE_CHECKING :
1818 from .session import Inputs , Outputs , Session
1919
20+ __all__ = ("current_namespace" , "resolve_id" , "ui" , "server" , "ResolvedId" )
21+
2022P = ParamSpec ("P" )
2123R = TypeVar ("R" )
2224
@@ -34,15 +36,50 @@ def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R:
3436
3537
3638@no_example ()
39+ # Use overloads so the function type stays the same for when the user calls it
40+ @overload
41+ def server (
42+ fn : Callable [Concatenate [Inputs , Outputs , Session , P ], Awaitable [R ]],
43+ ) -> Callable [Concatenate [str , P ], Awaitable [R ]]: ...
44+ @overload
3745def server (
3846 fn : Callable [Concatenate [Inputs , Outputs , Session , P ], R ],
39- ) -> Callable [Concatenate [str , P ], R ]:
47+ ) -> Callable [Concatenate [str , P ], R ]: ...
48+ def server (
49+ fn : (
50+ Callable [Concatenate [Inputs , Outputs , Session , P ], R ]
51+ | Callable [Concatenate [Inputs , Outputs , Session , P ], Awaitable [R ]]
52+ ),
53+ ) -> Callable [Concatenate [str , P ], R ] | Callable [Concatenate [str , P ], Awaitable [R ]]:
4054 from .session import require_active_session , session_context
4155
42- def wrapper (id : Id , * args : P .args , ** kwargs : P .kwargs ) -> R :
43- sess = require_active_session (None )
44- child_sess = sess .make_scope (id )
45- with session_context (child_sess ):
46- return fn (child_sess .input , child_sess .output , child_sess , * args , ** kwargs )
56+ if is_async_callable (fn ):
4757
48- return wrapper
58+ @functools .wraps (fn )
59+ async def async_wrapper (id : Id , * args : P .args , ** kwargs : P .kwargs ) -> R :
60+ sess = require_active_session (None )
61+ child_sess = sess .make_scope (id )
62+ with session_context (child_sess ):
63+ return await fn (
64+ child_sess .input , child_sess .output , child_sess , * args , ** kwargs
65+ )
66+
67+ return async_wrapper
68+
69+ # Required for type narrowing. `TypeIs` did not seem to work as expected here.
70+ if not_is_async_callable (fn ):
71+
72+ @functools .wraps (fn )
73+ def sync_wrapper (id : Id , * args : P .args , ** kwargs : P .kwargs ) -> R :
74+ sess = require_active_session (None )
75+ child_sess = sess .make_scope (id )
76+ with session_context (child_sess ):
77+ return fn (
78+ child_sess .input , child_sess .output , child_sess , * args , ** kwargs
79+ )
80+
81+ return sync_wrapper
82+
83+ raise RuntimeError (
84+ "The provided function must be either synchronous or asynchronous."
85+ )
0 commit comments