Skip to content

Commit b1cfc7b

Browse files
committed
Remove unnecessary route calls.
1 parent 8c9362e commit b1cfc7b

File tree

1 file changed

+52
-44
lines changed

1 file changed

+52
-44
lines changed

starlette_plus/core.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,19 @@
1919
import inspect
2020
import logging
2121
from collections.abc import Callable, Coroutine, Iterator, Sequence
22+
from functools import partial
2223
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypedDict, Unpack
2324

2425
from starlette.applications import Starlette
25-
from starlette.requests import Request
26-
from starlette.responses import Response
26+
from starlette.middleware import Middleware
2727
from starlette.routing import Mount, Route, WebSocketRoute
2828
from starlette.types import Receive, Scope, Send
29-
from starlette.websockets import WebSocket
3029

3130
from .types_.core import RouteCoro
3231

3332

3433
if TYPE_CHECKING:
35-
from starlette.middleware import Middleware
36-
from starlette.types import Message, Receive, Scope, Send
34+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
3735

3836
from .types_.core import Methods, RouteOptions
3937
from .types_.limiter import BucketType, ExemptCallable, RateLimitData
@@ -55,6 +53,34 @@ class ApplicationOptions(TypedDict, total=False):
5553
__all__ = ("Application", "View", "route", "limit")
5654

5755

56+
class LoggingMiddleware:
57+
def __init__(self, app: ASGIApp) -> None:
58+
self.app = app
59+
60+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
61+
if scope["type"] != "http":
62+
await self.app(scope, receive, send)
63+
return
64+
65+
method: str = scope["method"]
66+
path: str = scope["path"]
67+
client: str = f"{scope['client'][0]}:{scope['client'][1]}"
68+
version: str = scope["http_version"]
69+
70+
async def inspect_response(message: Message) -> None:
71+
nonlocal method, path, client, version
72+
73+
if message["type"] == "http.response.start":
74+
status_code: int = message.get("status", 200)
75+
msg: str = f'{client} - "{method} {path} HTTP/{version}" '
76+
77+
access_logger.info(msg, extra={"status": status_code})
78+
79+
await send(message)
80+
81+
await self.app(scope, receive, inspect_response)
82+
83+
5884
class _Route:
5985
def __init__(self, **kwargs: Unpack[RouteOptions]) -> None:
6086
self._path: str = kwargs["path"]
@@ -65,17 +91,6 @@ def __init__(self, **kwargs: Unpack[RouteOptions]) -> None:
6591
self._is_websocket: bool = kwargs.get("websocket", False)
6692
self._view: View | None = None
6793

68-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Response | None:
69-
request: Request | WebSocket = (
70-
WebSocket(scope, receive, send) if scope["type"] == "websocket" else Request(scope, receive, send)
71-
)
72-
73-
response: Response | None = await self._coro(self._view, request)
74-
if response is None:
75-
response = Response(status_code=500, content="Internal Server Error")
76-
77-
await response(scope, receive, send)
78-
7994

8095
LimitDecorator: TypeAlias = Callable[..., RouteCoro] | _Route
8196
T_LimitDecorator: TypeAlias = Callable[..., LimitDecorator]
@@ -144,7 +159,10 @@ def __init__(self, *args: Any, **kwargs: Unpack[ApplicationOptions]) -> None:
144159
self._access_log: bool = kwargs.pop("access_log", True)
145160
views: list[View] = kwargs.pop("views", [])
146161

147-
super().__init__(*args, **kwargs) # type: ignore
162+
middleware_: list[Middleware] = kwargs.pop("middleware", [])
163+
middleware_.insert(0, Middleware(LoggingMiddleware)) if self._access_log else None
164+
165+
super().__init__(*args, **kwargs, middleware=middleware_) # type: ignore
148166

149167
self.add_view(self)
150168
for view in views:
@@ -167,39 +185,23 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
167185
setattr(member, method, member._coro)
168186

169187
new: WebSocketRoute | Route
188+
endpoint: partial[RouteCoro] = partial(member._coro, self)
170189

171190
if member._is_websocket:
172-
new = WebSocketRoute(path=path, endpoint=member, name=f"{name}.{member._coro.__name__}")
191+
new = WebSocketRoute(path=path, endpoint=endpoint, name=f"{name}.{member._coro.__name__}")
173192
else:
174-
new = Route(path=path, endpoint=member, methods=member._methods, name=f"{name}.{member._coro.__name__}")
193+
new = Route(
194+
path=path,
195+
endpoint=endpoint,
196+
methods=member._methods,
197+
name=f"{name}.{member._coro.__name__}",
198+
)
175199

176200
new.limits = getattr(member, "_limits", []) # type: ignore
177201
self.__routes__.append(new)
178202

179203
return self
180204

181-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
182-
if scope["type"] != "http" or not self._access_log:
183-
return await super().__call__(scope, receive, send)
184-
185-
method: str = scope["method"]
186-
path: str = scope["path"]
187-
client: str = f"{scope['client'][0]}:{scope['client'][1]}"
188-
version: str = scope["http_version"]
189-
190-
async def inspect_response(message: Message) -> None:
191-
nonlocal method, path, client
192-
193-
if message["type"] == "http.response.start":
194-
status_code: int = message.get("status", 200)
195-
msg: str = f'{client} - "{method} {path} HTTP/{version}" '
196-
197-
access_logger.info(msg, extra={"status": status_code})
198-
199-
await send(message)
200-
201-
await super().__call__(scope, receive, inspect_response)
202-
203205
@property
204206
def prefix(self) -> str:
205207
return self._prefix
@@ -225,7 +227,7 @@ def add_view(self, view: View | Self) -> None:
225227
new = Route(path, endpoint=route_.endpoint, methods=methods, name=route_.name)
226228

227229
new.limits = route_.limits # type: ignore
228-
self.router.routes.append(new)
230+
self.routes.append(new)
229231

230232
if isinstance(view, View):
231233
self._views.append(view)
@@ -259,11 +261,17 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
259261
setattr(member, method, member._coro)
260262

261263
new: WebSocketRoute | Route
264+
endpoint: partial[RouteCoro] = partial(member._coro, self)
262265

263266
if member._is_websocket:
264-
new = WebSocketRoute(path=path, endpoint=member, name=f"{name}.{member._coro.__name__}")
267+
new = WebSocketRoute(path=path, endpoint=endpoint, name=f"{name}.{member._coro.__name__}")
265268
else:
266-
new = Route(path=path, endpoint=member, methods=member._methods, name=f"{name}.{member._coro.__name__}")
269+
new = Route(
270+
path=path,
271+
endpoint=endpoint,
272+
methods=member._methods,
273+
name=f"{name}.{member._coro.__name__}",
274+
)
267275

268276
new.limits = getattr(member, "_limits", []) # type: ignore
269277
self.__routes__.append(new)

0 commit comments

Comments
 (0)