Skip to content

Commit 58c0db5

Browse files
authored
Allow for module server to be async
1 parent 03721bb commit 58c0db5

File tree

1 file changed

+47
-10
lines changed

1 file changed

+47
-10
lines changed

shiny/module.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
__all__ = ("current_namespace", "resolve_id", "ui", "server", "ResolvedId")
4-
5-
from typing import TYPE_CHECKING, Callable, TypeVar
3+
import functools
4+
from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar, overload
65

76
from ._docstring import no_example
87
from ._namespaces import (
@@ -13,10 +12,13 @@
1312
resolve_id,
1413
)
1514
from ._typing_extensions import Concatenate, ParamSpec
15+
from ._utils import is_async_callable, not_is_async_callable
1616

1717
if TYPE_CHECKING:
1818
from .session import Inputs, Outputs, Session
1919

20+
__all__ = ("current_namespace", "resolve_id", "ui", "server", "ResolvedId")
21+
2022
P = ParamSpec("P")
2123
R = TypeVar("R")
2224

@@ -34,15 +36,50 @@ def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R:
3436

3537

3638
@no_example()
39+
# Use overloads so the function type stays the same for when the user calls it
40+
@overload
41+
def server(
42+
fn: Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]],
43+
) -> Callable[Concatenate[str, P], Awaitable[R]]: ...
44+
@overload
3745
def server(
3846
fn: Callable[Concatenate[Inputs, Outputs, Session, P], R],
39-
) -> Callable[Concatenate[str, P], R]:
47+
) -> Callable[Concatenate[str, P], R]: ...
48+
def server(
49+
fn: (
50+
Callable[Concatenate[Inputs, Outputs, Session, P], R]
51+
| Callable[Concatenate[Inputs, Outputs, Session, P], Awaitable[R]]
52+
),
53+
) -> Callable[Concatenate[str, P], R] | Callable[Concatenate[str, P], Awaitable[R]]:
4054
from .session import require_active_session, session_context
4155

42-
def wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R:
43-
sess = require_active_session(None)
44-
child_sess = sess.make_scope(id)
45-
with session_context(child_sess):
46-
return fn(child_sess.input, child_sess.output, child_sess, *args, **kwargs)
56+
if is_async_callable(fn):
4757

48-
return wrapper
58+
@functools.wraps(fn)
59+
async def async_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R:
60+
sess = require_active_session(None)
61+
child_sess = sess.make_scope(id)
62+
with session_context(child_sess):
63+
return await fn(
64+
child_sess.input, child_sess.output, child_sess, *args, **kwargs
65+
)
66+
67+
return async_wrapper
68+
69+
# Required for type narrowing. `TypeIs` did not seem to work as expected here.
70+
if not_is_async_callable(fn):
71+
72+
@functools.wraps(fn)
73+
def sync_wrapper(id: Id, *args: P.args, **kwargs: P.kwargs) -> R:
74+
sess = require_active_session(None)
75+
child_sess = sess.make_scope(id)
76+
with session_context(child_sess):
77+
return fn(
78+
child_sess.input, child_sess.output, child_sess, *args, **kwargs
79+
)
80+
81+
return sync_wrapper
82+
83+
raise RuntimeError(
84+
"The provided function must be either synchronous or asynchronous."
85+
)

0 commit comments

Comments
 (0)