Skip to content

Commit 39bf145

Browse files
committed
Allow route decorator stacking.
1 parent 517e151 commit 39bf145

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

starlette_plus/core.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(self, **kwargs: Unpack[RouteOptions]) -> None:
9292
self._prefix: bool = kwargs["prefix"]
9393
self._limits: list[RateLimitData] = kwargs.get("limits", [])
9494
self._is_websocket: bool = kwargs.get("websocket", False)
95-
self._view: View | None = None
95+
self._view: View | Application | None = None
9696
self._include_in_schema: bool = kwargs["include_in_schema"]
9797

9898
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
@@ -117,8 +117,8 @@ def route(
117117
prefix: bool = True,
118118
websocket: bool = False,
119119
include_in_schema: bool = True,
120-
) -> Callable[..., _Route]:
121-
def decorator(coro: Callable[..., RouteCoro]) -> _Route:
120+
) -> Callable[..., Callable[..., RouteCoro]]:
121+
def decorator(coro: Callable[..., RouteCoro]) -> Callable[..., RouteCoro]:
122122
if not asyncio.iscoroutinefunction(coro):
123123
raise RuntimeError("Route callback must be a coroutine function.")
124124

@@ -127,7 +127,7 @@ def decorator(coro: Callable[..., RouteCoro]) -> _Route:
127127
raise ValueError(f"Route callback function must not be named any: {', '.join(disallowed)}")
128128

129129
limits: list[RateLimitData] = getattr(coro, "__limits__", [])
130-
return _Route(
130+
route = _Route(
131131
path=path,
132132
coro=coro,
133133
methods=methods,
@@ -137,6 +137,13 @@ def decorator(coro: Callable[..., RouteCoro]) -> _Route:
137137
include_in_schema=include_in_schema,
138138
)
139139

140+
try:
141+
coro.__routes__.append(route) # type: ignore
142+
except AttributeError:
143+
setattr(coro, "__routes__", [route])
144+
145+
return coro
146+
140147
return decorator
141148

142149

@@ -215,8 +222,11 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
215222
self.__routes__ = []
216223

217224
name: str = cls.__name__
225+
members: list[Any] = [
226+
r for (_, m) in inspect.getmembers(self, predicate=lambda m: hasattr(m, "__routes__")) for r in m.__routes__
227+
]
218228

219-
for _, member in inspect.getmembers(self, predicate=lambda m: isinstance(m, _Route)):
229+
for member in members:
220230
member._view = self
221231
path: str = member._path
222232

@@ -294,8 +304,11 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
294304
prefix = cls.__prefix__ or name
295305

296306
self.__routes__ = []
307+
members: list[Any] = [
308+
r for (_, m) in inspect.getmembers(self, predicate=lambda m: hasattr(m, "__routes__")) for r in m.__routes__
309+
]
297310

298-
for _, member in inspect.getmembers(self, predicate=lambda m: isinstance(m, _Route)):
311+
for member in members:
299312
member._view = self
300313
path: str = member._path
301314

0 commit comments

Comments
 (0)