11import functools
2- from typing import Callable , TypeVar
2+ from typing import Awaitable , Callable , TypeVar , overload
33
44from .._docstring import add_example
55from .._typing_extensions import Concatenate , ParamSpec
6+ from .._utils import is_async_callable , not_is_async_callable
67from ..module import Id
78from ..session ._session import Inputs , Outputs , Session
89from ..session ._utils import require_active_session , session_context
1617
1718
1819@add_example (ex_dir = "../api-examples/express_module" )
20+ # Use overloads so the function type stays the same for when the user calls it
21+ @overload
22+ def module (
23+ fn : Callable [Concatenate [Inputs , Outputs , Session , P ], Awaitable [R ]],
24+ ) -> Callable [Concatenate [Id , P ], Awaitable [R ]]: ...
25+ @overload
1926def module (
2027 fn : Callable [Concatenate [Inputs , Outputs , Session , P ], R ],
21- ) -> Callable [Concatenate [Id , P ], R ]:
28+ ) -> Callable [Concatenate [Id , P ], R ]: ...
29+ def module (
30+ fn : (
31+ Callable [Concatenate [Inputs , Outputs , Session , P ], R ]
32+ | Callable [Concatenate [Inputs , Outputs , Session , P ], Awaitable [R ]]
33+ ),
34+ ) -> Callable [Concatenate [Id , P ], R ] | Callable [Concatenate [Id , P ], Awaitable [R ]]:
2235 """
2336 Create a Shiny module using Shiny Express syntax
2437
@@ -42,18 +55,43 @@ def module(
4255 """
4356 fn = expressify (fn )
4457
45- @functools .wraps (fn )
46- def wrapper (id : Id , * args : P .args , ** kwargs : P .kwargs ) -> R :
47- parent_session = require_active_session (None )
48- module_session = parent_session .make_scope (id )
49-
50- with session_context (module_session ):
51- return fn (
52- module_session .input ,
53- module_session .output ,
54- module_session ,
55- * args ,
56- ** kwargs ,
57- )
58-
59- return wrapper
58+ if is_async_callable (fn ):
59+ # If the function is async, we need to wrap it in an async wrapper
60+ @functools .wraps (fn )
61+ async def async_wrapper (id : Id , * args : P .args , ** kwargs : P .kwargs ) -> R :
62+ parent_session = require_active_session (None )
63+ module_session = parent_session .make_scope (id )
64+
65+ with session_context (module_session ):
66+ return await fn (
67+ module_session .input ,
68+ module_session .output ,
69+ module_session ,
70+ * args ,
71+ ** kwargs ,
72+ )
73+
74+ return async_wrapper
75+
76+ # Required for type narrowing. `TypeIs` did not seem to work as expected here.
77+ if not_is_async_callable (fn ):
78+
79+ @functools .wraps (fn )
80+ def wrapper (id : Id , * args : P .args , ** kwargs : P .kwargs ) -> R :
81+ parent_session = require_active_session (None )
82+ module_session = parent_session .make_scope (id )
83+
84+ with session_context (module_session ):
85+ return fn (
86+ module_session .input ,
87+ module_session .output ,
88+ module_session ,
89+ * args ,
90+ ** kwargs ,
91+ )
92+
93+ return wrapper
94+
95+ raise RuntimeError (
96+ "The provided function must be either synchronous or asynchronous."
97+ )
0 commit comments