Skip to content

Commit 028721d

Browse files
authored
Add fixes for type stub generation (#828)
1 parent a0df2ed commit 028721d

File tree

15 files changed

+57
-42
lines changed

15 files changed

+57
-42
lines changed

shiny/_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _server(inputs: Inputs, outputs: Outputs, session: Session):
166166
cast("Tag | TagList", ui), lib_prefix=self.lib_prefix
167167
)
168168

169-
def init_starlette_app(self):
169+
def init_starlette_app(self) -> starlette.applications.Starlette:
170170
routes: list[starlette.routing.BaseRoute] = [
171171
starlette.routing.WebSocketRoute("/websocket/", self._on_connect_cb),
172172
starlette.routing.Route("/", self._on_root_request_cb, methods=["GET"]),
@@ -400,7 +400,7 @@ def _render_page_from_file(self, file: Path, lib_prefix: str) -> RenderedHTML:
400400
return rendered
401401

402402

403-
def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]):
403+
def is_uifunc(x: Path | Tag | TagList | Callable[[Request], Tag | TagList]) -> bool:
404404
if (
405405
isinstance(x, Path)
406406
or isinstance(x, Tag)

shiny/_namespaces.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from contextlib import contextmanager
77
from contextvars import ContextVar, Token
8-
from typing import Pattern, Union, overload
8+
from typing import Generator, Pattern, Union, overload
99

1010

1111
class ResolvedId(str):
@@ -82,7 +82,7 @@ def resolve_id_or_none(id: Id | None) -> ResolvedId | None:
8282
re_valid_id: Pattern[str] = re.compile("^\\.?\\w+$")
8383

8484

85-
def validate_id(id: str):
85+
def validate_id(id: str) -> None:
8686
if not re_valid_id.match(id):
8787
raise ValueError(
8888
f"The string '{id}' is not a valid id; only letters, numbers, and "
@@ -97,7 +97,7 @@ def validate_id(id: str):
9797

9898

9999
@contextmanager
100-
def namespace_context(id: Id | None):
100+
def namespace_context(id: Id | None) -> Generator[None, None, None]:
101101
namespace = resolve_id(id) if id else Root
102102
token: Token[ResolvedId | None] = _current_namespace.set(namespace)
103103
try:

shiny/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import secrets
1212
import socketserver
1313
import tempfile
14-
from typing import Any, Awaitable, Callable, Optional, TypeVar, cast
14+
from typing import Any, Awaitable, Callable, Generator, Optional, TypeVar, cast
1515

1616
from ._typing_extensions import ParamSpec, TypeGuard
1717

@@ -200,7 +200,7 @@ def private_random_int(min: int, max: int) -> str:
200200

201201

202202
@contextlib.contextmanager
203-
def private_seed():
203+
def private_seed() -> Generator[None, None, None]:
204204
state = random.getstate()
205205
global own_random_state
206206
try:

shiny/express/_is_express.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def __init__(self):
5656
super().__init__()
5757
self.found_shiny_express_import = False
5858

59-
def visit_Import(self, node: ast.Import):
59+
def visit_Import(self, node: ast.Import) -> None:
6060
if any(alias.name == "shiny.express" for alias in node.names):
6161
self.found_shiny_express_import = True
6262

63-
def visit_ImportFrom(self, node: ast.ImportFrom):
63+
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
6464
if node.module == "shiny.express":
6565
self.found_shiny_express_import = True
6666
elif node.module == "shiny" and any(
@@ -69,9 +69,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom):
6969
self.found_shiny_express_import = True
7070

7171
# Visit top-level nodes.
72-
def visit_Module(self, node: ast.Module):
72+
def visit_Module(self, node: ast.Module) -> None:
7373
super().generic_visit(node)
7474

7575
# Don't recurse into any nodes, so the we'll only ever look at top-level nodes.
76-
def generic_visit(self, node: ast.AST):
76+
def generic_visit(self, node: ast.AST) -> None:
7777
pass

shiny/express/_output.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import contextlib
44
import sys
55
from contextlib import AbstractContextManager
6-
from typing import Callable, TypeVar, cast, overload
6+
from typing import Callable, Generator, TypeVar, cast, overload
77

88
from .. import ui
99
from .._typing_extensions import ParamSpec
@@ -109,7 +109,7 @@ def suspend_display(
109109

110110

111111
@contextlib.contextmanager
112-
def suspend_display_ctxmgr():
112+
def suspend_display_ctxmgr() -> Generator[None, None, None]:
113113
oldhook = sys.displayhook
114114
sys.displayhook = null_displayhook
115115
try:

shiny/express/_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ def set_result(x: object):
136136
_top_level_recall_context_manager_has_been_replaced = False
137137

138138

139-
def reset_top_level_recall_context_manager():
139+
def reset_top_level_recall_context_manager() -> None:
140140
global _top_level_recall_context_manager
141141
global _top_level_recall_context_manager_has_been_replaced
142142
_top_level_recall_context_manager = RecallContextManager(_DEFAULT_PAGE_FUNCTION)
143143
_top_level_recall_context_manager_has_been_replaced = False
144144

145145

146-
def get_top_level_recall_context_manager():
146+
def get_top_level_recall_context_manager() -> RecallContextManager[Tag]:
147147
return _top_level_recall_context_manager
148148

149149

shiny/express/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from pathlib import Path
44

5+
from .._app import App
56
from ._run import wrap_express_app
67
from ._utils import unescape_from_var_name
78

89

910
# If someone requests shiny.express.app:_2f_path_2f_to_2f_app_2e_py, then we will call
1011
# wrap_express_app(Path("/path/to/app.py")) and return the result.
11-
def __getattr__(name: str):
12+
def __getattr__(name: str) -> App:
1213
name = unescape_from_var_name(name)
1314
return wrap_express_app(Path(name))

shiny/express/display_decorator/_display_body.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def unwrap(fn: TFunc) -> TFunc:
4848
display_body_attr = "__display_body__"
4949

5050

51-
def display_body_unwrap_inplace():
51+
def display_body_unwrap_inplace() -> Callable[[TFunc], TFunc]:
5252
"""
5353
Like `display_body`, but far more violent. This will attempt to traverse any
5454
decorators between this one and the function, and then modify the function _in
@@ -76,7 +76,7 @@ def decorator(fn: TFunc) -> TFunc:
7676
return decorator
7777

7878

79-
def display_body():
79+
def display_body() -> Callable[[TFunc], TFunc]:
8080
def decorator(fn: TFunc) -> TFunc:
8181
if fn.__code__ in code_cache:
8282
fcode = code_cache[fn.__code__]
@@ -197,7 +197,9 @@ def _transform_function_ast(node: ast.AST) -> ast.AST:
197197
return func_node
198198

199199

200-
def compare_decorated_code_objects(func_ast: ast.FunctionDef):
200+
def compare_decorated_code_objects(
201+
func_ast: ast.FunctionDef,
202+
) -> Callable[[types.CodeType, types.CodeType], bool]:
201203
linenos = [*[x.lineno for x in func_ast.decorator_list], func_ast.lineno]
202204

203205
def comparator(candidate: types.CodeType, target: types.CodeType) -> bool:

shiny/express/layout.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from .. import ui
99
from ..types import MISSING, MISSING_TYPE
10+
from ..ui._accordion import AccordionPanel
1011
from ..ui._layout_columns import BreakpointsUser
12+
from ..ui._navs import NavPanel, NavSet, NavSetCard
1113
from ..ui.css import CssUnit
1214
from . import _run
1315
from ._recall_context import RecallContextManager, wrap_recall_context_manager
@@ -39,7 +41,7 @@
3941
# ======================================================================================
4042
# Page functions
4143
# ======================================================================================
42-
def set_page(page_fn: RecallContextManager[Tag]):
44+
def set_page(page_fn: RecallContextManager[Tag]) -> None:
4345
"""Set the page function for the current Shiny express app."""
4446
_run.replace_top_level_recall_context_manager(page_fn, force=True)
4547

@@ -162,7 +164,7 @@ def layout_column_wrap(
162164
gap: Optional[CssUnit] = None,
163165
class_: Optional[str] = None,
164166
**kwargs: TagAttrValue,
165-
):
167+
) -> RecallContextManager[Tag]:
166168
"""
167169
A grid-like, column-first layout
168170
@@ -252,7 +254,7 @@ def layout_columns(
252254
class_: Optional[str] = None,
253255
height: Optional[CssUnit] = None,
254256
**kwargs: TagAttrValue,
255-
):
257+
) -> RecallContextManager[Tag]:
256258
"""
257259
Create responsive, column-based grid layouts, based on a 12-column grid.
258260
@@ -346,7 +348,9 @@ def layout_columns(
346348
)
347349

348350

349-
def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue):
351+
def column(
352+
width: int, *, offset: int = 0, **kwargs: TagAttrValue
353+
) -> RecallContextManager[Tag]:
350354
"""
351355
Responsive row-column based layout
352356
@@ -381,7 +385,7 @@ def column(width: int, *, offset: int = 0, **kwargs: TagAttrValue):
381385
)
382386

383387

384-
def row(**kwargs: TagAttrValue):
388+
def row(**kwargs: TagAttrValue) -> RecallContextManager[Tag]:
385389
"""
386390
Responsive row-column based layout
387391
@@ -419,7 +423,7 @@ def card(
419423
fill: bool = True,
420424
class_: Optional[str] = None,
421425
**kwargs: TagAttrValue,
422-
):
426+
) -> RecallContextManager[Tag]:
423427
"""
424428
A Bootstrap card component
425429
@@ -481,7 +485,7 @@ def accordion(
481485
width: Optional[CssUnit] = None,
482486
height: Optional[CssUnit] = None,
483487
**kwargs: TagAttrValue,
484-
):
488+
) -> RecallContextManager[Tag]:
485489
"""
486490
Create a vertically collapsing accordion.
487491
@@ -537,7 +541,7 @@ def accordion_panel(
537541
value: Optional[str] | MISSING_TYPE = MISSING,
538542
icon: Optional[TagChild] = None,
539543
**kwargs: TagAttrValue,
540-
):
544+
) -> RecallContextManager[AccordionPanel]:
541545
"""
542546
Single accordion panel.
543547
@@ -583,7 +587,7 @@ def navset(
583587
selected: Optional[str] = None,
584588
header: TagChild = None,
585589
footer: TagChild = None,
586-
):
590+
) -> RecallContextManager[NavSet]:
587591
"""
588592
Render a set of nav items
589593
@@ -635,7 +639,7 @@ def navset_card(
635639
sidebar: Optional[ui.Sidebar] = None,
636640
header: TagChild = None,
637641
footer: TagChild = None,
638-
):
642+
) -> RecallContextManager[NavSetCard]:
639643
"""
640644
Render a set of nav items inside a card container.
641645
@@ -687,7 +691,7 @@ def nav_panel(
687691
*,
688692
value: Optional[str] = None,
689693
icon: TagChild = None,
690-
):
694+
) -> RecallContextManager[NavPanel]:
691695
"""
692696
Create a nav item pointing to some internal content.
693697
@@ -803,7 +807,7 @@ def page_fillable(
803807
title: Optional[str] = None,
804808
lang: Optional[str] = None,
805809
**kwargs: TagAttrValue,
806-
):
810+
) -> RecallContextManager[Tag]:
807811
"""
808812
Creates a fillable page.
809813
@@ -854,7 +858,7 @@ def page_sidebar(
854858
window_title: str | MISSING_TYPE = MISSING,
855859
lang: Optional[str] = None,
856860
**kwargs: TagAttrValue,
857-
):
861+
) -> RecallContextManager[Tag]:
858862
"""
859863
Create a page with a sidebar and a title.
860864

shiny/http_staticfiles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class StaticFiles:
5050
def __init__(self, *, directory: str | os.PathLike[str]):
5151
self.dir = pathlib.Path(os.path.realpath(os.path.normpath(directory)))
5252

53-
async def __call__(self, scope: Scope, receive: Receive, send: Send):
53+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5454
if scope["type"] != "http":
5555
raise AssertionError("StaticFiles can't handle non-http request")
5656
path = scope["path"]

0 commit comments

Comments
 (0)