Skip to content

Commit a59bbdf

Browse files
authored
Improve typing in baseapp module (#640)
* Improve typing in baseapp module * Fix type anotation for py313 and less
1 parent 1f8ce5b commit a59bbdf

File tree

2 files changed

+65
-21
lines changed

2 files changed

+65
-21
lines changed

blacksheep/baseapp.py

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import logging
44
from collections import UserDict
5+
import typing
56

67
from blacksheep.server.errors import ServerErrorDetailsHandler
78
from blacksheep.server.routing import Router
@@ -17,8 +18,11 @@
1718
ValidationError = None
1819

1920

20-
class ExceptionHandlersDict(UserDict):
21+
if typing.TYPE_CHECKING:
22+
from .messages import Request
23+
2124

25+
class ExceptionHandlersDict(UserDict):
2226
def __setitem__(self, key, item) -> None:
2327
if not inspect.iscoroutinefunction(item):
2428
raise InvalidExceptionHandler()
@@ -32,15 +36,15 @@ def __setitem__(self, key, item) -> None:
3236
return super().__setitem__(key, item)
3337

3438

35-
async def handle_not_found(app, request, http_exception):
39+
async def handle_not_found(app, request, http_exception) -> Response:
3640
return Response(404, content=TextContent("Resource not found"))
3741

3842

39-
async def handle_internal_server_error(app, request, exception):
43+
async def handle_internal_server_error(app, request, exception) -> Response:
4044
return Response(500, content=TextContent("Internal Server Error"))
4145

4246

43-
async def handle_bad_request(app, request, http_exception):
47+
async def handle_bad_request(app, request, http_exception) -> Response:
4448
if getattr(http_exception, "__context__", None) is not None and callable(
4549
getattr(http_exception.__context__, "json", None)
4650
):
@@ -53,20 +57,20 @@ async def handle_bad_request(app, request, http_exception):
5357
return Response(400, content=TextContent(f"Bad Request: {str(http_exception)}"))
5458

5559

56-
async def _default_pydantic_validation_error_handler(app, request, error):
60+
async def _default_pydantic_validation_error_handler(app, request, error) -> Response:
5761
return Response(
5862
400, content=Content(b"application/json", error.json(indent=4).encode("utf-8"))
5963
)
6064

6165

62-
async def common_http_exception_handler(app, request, http_exception):
66+
async def common_http_exception_handler(app, request, http_exception) -> Response:
6367
return Response(
6468
http_exception.status,
6569
content=TextContent(http.HTTPStatus(http_exception.status).phrase),
6670
)
6771

6872

69-
def get_logger():
73+
def get_logger() -> logging.Logger:
7074
logger = logging.getLogger("blacksheep.server")
7175
logger.setLevel(logging.INFO)
7276
return logger
@@ -82,7 +86,7 @@ def __init__(self, show_error_details, router):
8286
self.logger = get_logger()
8387
self.server_error_details_handler: ServerErrorDetailsHandler
8488

85-
def init_exceptions_handlers(self):
89+
def init_exceptions_handlers(self) -> ExceptionHandlersDict:
8690
default_handlers = ExceptionHandlersDict(
8791
{404: handle_not_found, 400: handle_bad_request}
8892
)
@@ -92,15 +96,23 @@ def init_exceptions_handlers(self):
9296
)
9397
return default_handlers
9498

95-
async def log_unhandled_exc(self, request, exc):
99+
async def log_unhandled_exc(
100+
self,
101+
request: "Request",
102+
exc: Exception,
103+
):
96104
self.logger.error(
97105
'Unhandled exception - "%s %s"',
98106
request.method,
99107
request.url.value.decode(),
100108
exc_info=exc,
101109
)
102110

103-
async def log_handled_exc(self, request, exc):
111+
async def log_handled_exc(
112+
self,
113+
request: "Request",
114+
exc: Exception,
115+
):
104116
if isinstance(exc, HTTPException):
105117
self.logger.info(
106118
'HTTP %s - "%s %s". %s',
@@ -117,7 +129,7 @@ async def log_handled_exc(self, request, exc):
117129
str(exc),
118130
)
119131

120-
async def handle(self, request):
132+
async def handle(self, request: "Request") -> Response:
121133
route = self.router.get_match(request)
122134

123135
if not route:
@@ -133,7 +145,11 @@ async def handle(self, request):
133145
response = await self.handle_request_handler_exception(request, exc)
134146
return response or Response(204)
135147

136-
async def handle_request_handler_exception(self, request, exc):
148+
async def handle_request_handler_exception(
149+
self,
150+
request: "Request",
151+
exc: Exception,
152+
) -> Response:
137153
if isinstance(exc, HTTPException):
138154
await self.log_handled_exc(request, exc)
139155
return await self.handle_http_exception(request, exc)
@@ -143,29 +159,46 @@ async def handle_request_handler_exception(self, request, exc):
143159
await self.log_unhandled_exc(request, exc)
144160
return await self.handle_exception(request, exc)
145161

146-
def get_http_exception_handler(self, http_exception):
162+
def get_http_exception_handler(
163+
self, http_exception: HTTPException
164+
) -> typing.Callable[
165+
["BaseApplication", "Request", Exception], typing.Awaitable[Response]
166+
]:
147167
handler = self.get_exception_handler(http_exception, stop_at=HTTPException)
148168
if handler:
149169
return handler
150170
return self.exceptions_handlers.get(
151171
getattr(http_exception, "status", None), common_http_exception_handler
152172
)
153173

154-
def is_handled_exception(self, exception):
174+
def is_handled_exception(self, exception) -> bool:
155175
for class_type in get_class_instance_hierarchy(exception):
156176
if class_type in self.exceptions_handlers:
157177
return True
158178
return False
159179

160-
def get_exception_handler(self, exception, stop_at):
180+
def get_exception_handler(
181+
self,
182+
exception: Exception,
183+
stop_at: type | None,
184+
) -> (
185+
typing.Callable[
186+
["BaseApplication", "Request", Exception], typing.Awaitable[Response]
187+
]
188+
| None
189+
):
161190
for class_type in get_class_instance_hierarchy(exception):
162191
if stop_at is not None and stop_at is class_type:
163192
return None
164193
if class_type in self.exceptions_handlers:
165194
return self.exceptions_handlers[class_type]
166195
return None
167196

168-
async def handle_internal_server_error(self, request, exc):
197+
async def handle_internal_server_error(
198+
self,
199+
request: "Request",
200+
exc,
201+
) -> Response:
169202
if self.show_error_details:
170203
return self.server_error_details_handler.produce_response(request, exc)
171204
error = InternalServerError(exc)
@@ -179,7 +212,14 @@ async def handle_internal_server_error(self, request, exc):
179212
)
180213
return Response(500, content=TextContent("Internal Server Error"))
181214

182-
async def _apply_exception_handler(self, request, exc, exception_handler):
215+
async def _apply_exception_handler(
216+
self,
217+
request: "Request",
218+
exc: Exception,
219+
exception_handler: typing.Callable[
220+
["BaseApplication", "Request", Exception], typing.Awaitable[Response]
221+
],
222+
):
183223
try:
184224
return await exception_handler(self, request, exc)
185225
except Exception as server_ex:
@@ -194,15 +234,19 @@ async def _apply_exception_handler(self, request, exc, exception_handler):
194234

195235
return await handle_internal_server_error(self, request, server_ex)
196236

197-
async def handle_http_exception(self, request, http_exception):
237+
async def handle_http_exception(
238+
self,
239+
request: "Request",
240+
http_exception: HTTPException,
241+
) -> Response:
198242
exception_handler = self.get_http_exception_handler(http_exception)
199243
if exception_handler:
200244
return await self._apply_exception_handler(
201245
request, http_exception, exception_handler
202246
)
203247
return await self.handle_exception(request, http_exception)
204248

205-
async def handle_exception(self, request, exc):
249+
async def handle_exception(self, request: "Request", exc: Exception) -> Response:
206250
exception_handler = self.get_exception_handler(exc, None)
207251
if exception_handler:
208252
return await self._apply_exception_handler(request, exc, exception_handler)

blacksheep/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ def join_fragments(*args: AnyStr) -> str:
3131
)
3232

3333

34-
def get_class_hierarchy(cls: Type[T]):
34+
def get_class_hierarchy(cls: Type[T]) -> tuple[Type[T], ...]:
3535
return cls.__mro__
3636

3737

38-
def get_class_instance_hierarchy(instance: T):
38+
def get_class_instance_hierarchy(instance: T) -> tuple[Type[T], ...]:
3939
return get_class_hierarchy(type(instance))
4040

4141

0 commit comments

Comments
 (0)