|
1 | 1 | import http |
| 2 | +import inspect |
2 | 3 | import logging |
| 4 | +from collections import UserDict |
| 5 | + |
| 6 | +from blacksheep.server.errors import ServerErrorDetailsHandler |
3 | 7 |
|
4 | 8 | from .contents import Content, TextContent |
5 | | -from .exceptions import HTTPException, InternalServerError, NotFound |
| 9 | +from .exceptions import ( |
| 10 | + HTTPException, |
| 11 | + InternalServerError, |
| 12 | + InvalidExceptionHandler, |
| 13 | + NotFound, |
| 14 | +) |
6 | 15 | from .messages import Response |
7 | 16 | from .utils import get_class_instance_hierarchy |
8 | 17 |
|
|
12 | 21 | ValidationError = None |
13 | 22 |
|
14 | 23 |
|
| 24 | +class ExceptionHandlersDict(UserDict): |
| 25 | + |
| 26 | + def __setitem__(self, key, item) -> None: |
| 27 | + if not inspect.iscoroutinefunction(item): |
| 28 | + raise InvalidExceptionHandler() |
| 29 | + signature = inspect.Signature.from_callable(item) |
| 30 | + if len(signature.parameters) != 3 and not any( |
| 31 | + param |
| 32 | + for param in signature.parameters |
| 33 | + if signature.parameters[param].kind == 2 |
| 34 | + ): |
| 35 | + raise InvalidExceptionHandler() |
| 36 | + return super().__setitem__(key, item) |
| 37 | + |
| 38 | + |
15 | 39 | async def handle_not_found(app, request, http_exception): |
16 | 40 | return Response(404, content=TextContent("Resource not found")) |
17 | 41 |
|
@@ -58,9 +82,12 @@ def __init__(self, show_error_details, router): |
58 | 82 | self.exceptions_handlers = self.init_exceptions_handlers() |
59 | 83 | self.show_error_details = show_error_details |
60 | 84 | self.logger = get_logger() |
| 85 | + self.server_error_details_handler: ServerErrorDetailsHandler |
61 | 86 |
|
62 | 87 | def init_exceptions_handlers(self): |
63 | | - default_handlers = {404: handle_not_found, 400: handle_bad_request} |
| 88 | + default_handlers = ExceptionHandlersDict( |
| 89 | + {404: handle_not_found, 400: handle_bad_request} |
| 90 | + ) |
64 | 91 | if ValidationError is not None: |
65 | 92 | default_handlers[ValidationError] = ( |
66 | 93 | _default_pydantic_validation_error_handler |
@@ -159,7 +186,16 @@ async def _apply_exception_handler(self, request, exc, exception_handler): |
159 | 186 | try: |
160 | 187 | return await exception_handler(self, request, exc) |
161 | 188 | except Exception as server_ex: |
162 | | - return await self.handle_exception(request, server_ex) |
| 189 | + # If the exception happens in the user-defined exception handler, |
| 190 | + # we need to fallback to the default handlers. |
| 191 | + self.logger.error( |
| 192 | + "Unhandled exception in exception_handler: %s", |
| 193 | + exception_handler.__name__, |
| 194 | + ) |
| 195 | + if self.show_error_details: |
| 196 | + return self.server_error_details_handler.produce_response(request, exc) |
| 197 | + |
| 198 | + return await handle_internal_server_error(self, request, server_ex) |
163 | 199 |
|
164 | 200 | async def handle_http_exception(self, request, http_exception): |
165 | 201 | exception_handler = self.get_http_exception_handler(http_exception) |
|
0 commit comments