1919import inspect
2020import logging
2121from collections .abc import Callable , Coroutine , Iterator , Sequence
22+ from functools import partial
2223from typing import TYPE_CHECKING , Any , ClassVar , Self , TypeAlias , TypedDict , Unpack
2324
2425from starlette .applications import Starlette
25- from starlette .requests import Request
26- from starlette .responses import Response
26+ from starlette .middleware import Middleware
2727from starlette .routing import Mount , Route , WebSocketRoute
2828from starlette .types import Receive , Scope , Send
29- from starlette .websockets import WebSocket
3029
3130from .types_ .core import RouteCoro
3231
3332
3433if TYPE_CHECKING :
35- from starlette .middleware import Middleware
36- from starlette .types import Message , Receive , Scope , Send
34+ from starlette .types import ASGIApp , Message , Receive , Scope , Send
3735
3836 from .types_ .core import Methods , RouteOptions
3937 from .types_ .limiter import BucketType , ExemptCallable , RateLimitData
@@ -55,6 +53,34 @@ class ApplicationOptions(TypedDict, total=False):
5553__all__ = ("Application" , "View" , "route" , "limit" )
5654
5755
56+ class LoggingMiddleware :
57+ def __init__ (self , app : ASGIApp ) -> None :
58+ self .app = app
59+
60+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
61+ if scope ["type" ] != "http" :
62+ await self .app (scope , receive , send )
63+ return
64+
65+ method : str = scope ["method" ]
66+ path : str = scope ["path" ]
67+ client : str = f"{ scope ['client' ][0 ]} :{ scope ['client' ][1 ]} "
68+ version : str = scope ["http_version" ]
69+
70+ async def inspect_response (message : Message ) -> None :
71+ nonlocal method , path , client , version
72+
73+ if message ["type" ] == "http.response.start" :
74+ status_code : int = message .get ("status" , 200 )
75+ msg : str = f'{ client } - "{ method } { path } HTTP/{ version } " '
76+
77+ access_logger .info (msg , extra = {"status" : status_code })
78+
79+ await send (message )
80+
81+ await self .app (scope , receive , inspect_response )
82+
83+
5884class _Route :
5985 def __init__ (self , ** kwargs : Unpack [RouteOptions ]) -> None :
6086 self ._path : str = kwargs ["path" ]
@@ -65,17 +91,6 @@ def __init__(self, **kwargs: Unpack[RouteOptions]) -> None:
6591 self ._is_websocket : bool = kwargs .get ("websocket" , False )
6692 self ._view : View | None = None
6793
68- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> Response | None :
69- request : Request | WebSocket = (
70- WebSocket (scope , receive , send ) if scope ["type" ] == "websocket" else Request (scope , receive , send )
71- )
72-
73- response : Response | None = await self ._coro (self ._view , request )
74- if response is None :
75- response = Response (status_code = 500 , content = "Internal Server Error" )
76-
77- await response (scope , receive , send )
78-
7994
8095LimitDecorator : TypeAlias = Callable [..., RouteCoro ] | _Route
8196T_LimitDecorator : TypeAlias = Callable [..., LimitDecorator ]
@@ -144,7 +159,10 @@ def __init__(self, *args: Any, **kwargs: Unpack[ApplicationOptions]) -> None:
144159 self ._access_log : bool = kwargs .pop ("access_log" , True )
145160 views : list [View ] = kwargs .pop ("views" , [])
146161
147- super ().__init__ (* args , ** kwargs ) # type: ignore
162+ middleware_ : list [Middleware ] = kwargs .pop ("middleware" , [])
163+ middleware_ .insert (0 , Middleware (LoggingMiddleware )) if self ._access_log else None
164+
165+ super ().__init__ (* args , ** kwargs , middleware = middleware_ ) # type: ignore
148166
149167 self .add_view (self )
150168 for view in views :
@@ -167,39 +185,23 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
167185 setattr (member , method , member ._coro )
168186
169187 new : WebSocketRoute | Route
188+ endpoint : partial [RouteCoro ] = partial (member ._coro , self )
170189
171190 if member ._is_websocket :
172- new = WebSocketRoute (path = path , endpoint = member , name = f"{ name } .{ member ._coro .__name__ } " )
191+ new = WebSocketRoute (path = path , endpoint = endpoint , name = f"{ name } .{ member ._coro .__name__ } " )
173192 else :
174- new = Route (path = path , endpoint = member , methods = member ._methods , name = f"{ name } .{ member ._coro .__name__ } " )
193+ new = Route (
194+ path = path ,
195+ endpoint = endpoint ,
196+ methods = member ._methods ,
197+ name = f"{ name } .{ member ._coro .__name__ } " ,
198+ )
175199
176200 new .limits = getattr (member , "_limits" , []) # type: ignore
177201 self .__routes__ .append (new )
178202
179203 return self
180204
181- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
182- if scope ["type" ] != "http" or not self ._access_log :
183- return await super ().__call__ (scope , receive , send )
184-
185- method : str = scope ["method" ]
186- path : str = scope ["path" ]
187- client : str = f"{ scope ['client' ][0 ]} :{ scope ['client' ][1 ]} "
188- version : str = scope ["http_version" ]
189-
190- async def inspect_response (message : Message ) -> None :
191- nonlocal method , path , client
192-
193- if message ["type" ] == "http.response.start" :
194- status_code : int = message .get ("status" , 200 )
195- msg : str = f'{ client } - "{ method } { path } HTTP/{ version } " '
196-
197- access_logger .info (msg , extra = {"status" : status_code })
198-
199- await send (message )
200-
201- await super ().__call__ (scope , receive , inspect_response )
202-
203205 @property
204206 def prefix (self ) -> str :
205207 return self ._prefix
@@ -225,7 +227,7 @@ def add_view(self, view: View | Self) -> None:
225227 new = Route (path , endpoint = route_ .endpoint , methods = methods , name = route_ .name )
226228
227229 new .limits = route_ .limits # type: ignore
228- self .router . routes .append (new )
230+ self .routes .append (new )
229231
230232 if isinstance (view , View ):
231233 self ._views .append (view )
@@ -259,11 +261,17 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
259261 setattr (member , method , member ._coro )
260262
261263 new : WebSocketRoute | Route
264+ endpoint : partial [RouteCoro ] = partial (member ._coro , self )
262265
263266 if member ._is_websocket :
264- new = WebSocketRoute (path = path , endpoint = member , name = f"{ name } .{ member ._coro .__name__ } " )
267+ new = WebSocketRoute (path = path , endpoint = endpoint , name = f"{ name } .{ member ._coro .__name__ } " )
265268 else :
266- new = Route (path = path , endpoint = member , methods = member ._methods , name = f"{ name } .{ member ._coro .__name__ } " )
269+ new = Route (
270+ path = path ,
271+ endpoint = endpoint ,
272+ methods = member ._methods ,
273+ name = f"{ name } .{ member ._coro .__name__ } " ,
274+ )
267275
268276 new .limits = getattr (member , "_limits" , []) # type: ignore
269277 self .__routes__ .append (new )
0 commit comments