diff --git a/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py b/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py index 064355fdbd2..c8c731a2b8c 100644 --- a/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py +++ b/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py @@ -8,6 +8,7 @@ from typing import Any from aiohttp import web +from aiohttp.web_exceptions import HTTPError from aiohttp.web_request import Request from aiohttp.web_response import StreamResponse from common_library.error_codes import create_error_code @@ -18,6 +19,7 @@ from ..mimetype_constants import MIMETYPE_APPLICATION_JSON from ..rest_responses import is_enveloped_from_map, is_enveloped_from_text from ..utils import is_production_environ +from . import status from .rest_responses import ( create_data_response, create_http_error, @@ -26,6 +28,7 @@ ) from .rest_utils import EnvelopeFactory from .typing_extension import Handler, Middleware +from .web_exceptions_extension import get_http_error_class_or_none DEFAULT_API_VERSION = "v0" _FMSG_INTERNAL_ERROR_USER_FRIENDLY = ( @@ -42,110 +45,172 @@ def is_api_request(request: web.Request, api_version: str) -> bool: return bool(request.path.startswith(base_path)) -def error_middleware_factory( # noqa: C901 - api_version: str, -) -> Middleware: - _is_prod: bool = is_production_environ() +def _handle_unexpected_exception_as_500( + request: web.BaseRequest, + exception: Exception, + *, + skip_internal_error_details: bool, +) -> web.HTTPInternalServerError: + """Process unexpected exceptions and return them as HTTP errors with proper formatting. + + IMPORTANT: this function cannot throw exceptions, as it is called + """ + error_code = create_error_code(exception) + error_context: dict[str, Any] = { + "request.remote": f"{request.remote}", + "request.method": f"{request.method}", + "request.path": f"{request.path}", + } + + user_error_msg = _FMSG_INTERNAL_ERROR_USER_FRIENDLY + + http_error = create_http_error( + exception, + user_error_msg, + web.HTTPInternalServerError, + skip_internal_error_details=skip_internal_error_details, + error_code=error_code, + ) + + error_context["http_error"] = http_error - def _process_and_raise_unexpected_error(request: web.BaseRequest, err: Exception): - error_code = create_error_code(err) - error_context: dict[str, Any] = { - "request.remote": f"{request.remote}", - "request.method": f"{request.method}", - "request.path": f"{request.path}", - } - - user_error_msg = _FMSG_INTERNAL_ERROR_USER_FRIENDLY - http_error = create_http_error( - err, + _logger.exception( + **create_troubleshotting_log_kwargs( user_error_msg, - web.HTTPInternalServerError, - skip_internal_error_details=_is_prod, + error=exception, + error_context=error_context, error_code=error_code, ) - _logger.exception( - **create_troubleshotting_log_kwargs( - user_error_msg, - error=err, - error_context=error_context, - error_code=error_code, - ) + ) + return http_error + + +def _handle_http_error( + request: web.BaseRequest, exception: web.HTTPError +) -> web.HTTPError: + """Handle standard HTTP errors by ensuring they're properly formatted.""" + assert request # nosec + exception.content_type = MIMETYPE_APPLICATION_JSON + if exception.reason: + exception.set_status( + exception.status, safe_status_message(message=exception.reason) + ) + + if not exception.text or not is_enveloped_from_text(exception.text): + error_message = exception.text or exception.reason or "Unexpected error" + error_model = ErrorGet( + errors=[ + ErrorItemType.from_error(exception), + ], + status=exception.status, + logs=[ + LogMessageType(message=error_message, level="ERROR"), + ], + message=error_message, + ) + exception.text = EnvelopeFactory(error=error_model).as_text() + + return exception + + +def _handle_http_successful( + request: web.Request, exception: web.HTTPSuccessful +) -> web.HTTPSuccessful: + """Handle successful HTTP responses, ensuring they're properly enveloped.""" + assert request # nosec + + exception.content_type = MIMETYPE_APPLICATION_JSON + if exception.reason: + exception.set_status( + exception.status, safe_status_message(message=exception.reason) + ) + + if exception.text: + payload = json_loads(exception.text) + if not is_enveloped_from_map(payload): + payload = wrap_as_envelope(data=payload) + exception.text = json_dumps(payload) + + return exception + + +def _handle_exception_as_http_error( + request: web.Request, + exception: Exception, + status_code: int, + *, + skip_internal_error_details: bool, +) -> HTTPError: + """ + Generic handler for exceptions that map to specific HTTP status codes. + Converts the status code to the appropriate HTTP error class and creates a response. + """ + assert request # nosec + + http_error_cls = get_http_error_class_or_none(status_code) + if http_error_cls is None: + msg = ( + f"No HTTP error class found for status code {status_code}, falling back to 500", ) - raise http_error + raise ValueError(msg) + + return create_http_error( + exception, + f"{exception}", + http_error_cls, + skip_internal_error_details=skip_internal_error_details, + ) + + +def error_middleware_factory(api_version: str) -> Middleware: + _is_prod: bool = is_production_environ() @web.middleware - async def _middleware_handler(request: web.Request, handler: Handler): # noqa: C901 + async def _middleware_handler(request: web.Request, handler: Handler): """ Ensure all error raised are properly enveloped and json responses """ if not is_api_request(request, api_version): return await handler(request) - # FIXME: review when to send info to client and when not! try: - return await handler(request) + try: + result = await handler(request) - except web.HTTPError as err: - - err.content_type = MIMETYPE_APPLICATION_JSON - if err.reason: - err.set_status(err.status, safe_status_message(message=err.reason)) - - if not err.text or not is_enveloped_from_text(err.text): - error_message = err.text or err.reason or "Unexpected error" - error_model = ErrorGet( - errors=[ - ErrorItemType.from_error(err), - ], - status=err.status, - logs=[ - LogMessageType(message=error_message, level="ERROR"), - ], - message=error_message, + except web.HTTPError as exc: # 4XX and 5XX raised as exceptions + result = _handle_http_error(request, exc) + + except web.HTTPSuccessful as exc: # 2XX rased as exceptions + result = _handle_http_successful(request, exc) + + except web.HTTPRedirection as exc: # 3XX raised as exceptions + result = exc + + except NotImplementedError as exc: + result = _handle_exception_as_http_error( + request, + exc, + status.HTTP_501_NOT_IMPLEMENTED, + skip_internal_error_details=_is_prod, ) - err.text = EnvelopeFactory(error=error_model).as_text() - - raise - - except web.HTTPSuccessful as err: - err.content_type = MIMETYPE_APPLICATION_JSON - if err.reason: - err.set_status(err.status, safe_status_message(message=err.reason)) - - if err.text: - try: - payload = json_loads(err.text) - if not is_enveloped_from_map(payload): - payload = wrap_as_envelope(data=payload) - err.text = json_dumps(payload) - except Exception as other_error: # pylint: disable=broad-except - _process_and_raise_unexpected_error(request, other_error) - raise - - except web.HTTPRedirection as err: - _logger.debug("Redirected to %s", err) - raise - - except NotImplementedError as err: - http_error = create_http_error( - err, - f"{err}", - web.HTTPNotImplemented, - skip_internal_error_details=_is_prod, - ) - raise http_error from err - - except TimeoutError as err: - http_error = create_http_error( - err, - f"{err}", - web.HTTPGatewayTimeout, - skip_internal_error_details=_is_prod, + + except TimeoutError as exc: + result = _handle_exception_as_http_error( + request, + exc, + status.HTTP_504_GATEWAY_TIMEOUT, + skip_internal_error_details=_is_prod, + ) + + except Exception as exc: # pylint: disable=broad-except + # + # Last resort for unexpected exceptions (including those raise by the exception handlers!) + # + result = _handle_unexpected_exception_as_500( + request, exc, skip_internal_error_details=_is_prod ) - raise http_error from err - except Exception as err: # pylint: disable=broad-except - _process_and_raise_unexpected_error(request, err) + return result # adds identifier (mostly for debugging) setattr( # noqa: B010 diff --git a/packages/service-library/src/servicelib/aiohttp/rest_responses.py b/packages/service-library/src/servicelib/aiohttp/rest_responses.py index 3986de59700..8dfa59cf9b1 100644 --- a/packages/service-library/src/servicelib/aiohttp/rest_responses.py +++ b/packages/service-library/src/servicelib/aiohttp/rest_responses.py @@ -1,4 +1,4 @@ -from typing import Any, Final, TypedDict +from typing import Any, Final, TypedDict, TypeVar from aiohttp import web from aiohttp.web_exceptions import HTTPError @@ -10,7 +10,7 @@ from ..mimetype_constants import MIMETYPE_APPLICATION_JSON from ..rest_constants import RESPONSE_MODEL_POLICY from ..rest_responses import is_enveloped -from ..status_codes_utils import get_code_description, is_error +from ..status_codes_utils import get_code_description, get_code_display_name, is_error class EnvelopeDict(TypedDict): @@ -69,32 +69,38 @@ def safe_status_message( return flat_message[: max_length - 3] + "..." +T_HTTPError = TypeVar("T_HTTPError", bound=HTTPError) + + def create_http_error( errors: list[Exception] | Exception, error_message: str | None = None, - http_error_cls: type[HTTPError] = web.HTTPInternalServerError, + http_error_cls: type[ + T_HTTPError + ] = web.HTTPInternalServerError, # type: ignore[assignment] *, status_reason: str | None = None, skip_internal_error_details: bool = False, error_code: ErrorCodeStr | None = None, -) -> HTTPError: +) -> T_HTTPError: """ - Response body conforms OAS schema model - Can skip internal details when 500 status e.g. to avoid transmitting server exceptions to the client in production """ - if not isinstance(errors, list): - errors = [errors] - - is_internal_error = bool(http_error_cls == web.HTTPInternalServerError) - status_reason = status_reason or get_code_description(http_error_cls.status_code) + status_reason = status_reason or get_code_display_name(http_error_cls.status_code) error_message = error_message or get_code_description(http_error_cls.status_code) assert len(status_reason) < MAX_STATUS_MESSAGE_LENGTH # nosec + # WARNING: do not refactor too much this function withouth considering how + # front-end handle errors. i.e. please sync with front-end developers before + # changing the workflows in this function + + is_internal_error = bool(http_error_cls == web.HTTPInternalServerError) if is_internal_error and skip_internal_error_details: - error = ErrorGet.model_validate( + error_model = ErrorGet.model_validate( { "status": http_error_cls.status_code, "message": error_message, @@ -102,8 +108,11 @@ def create_http_error( } ) else: + if not isinstance(errors, list): + errors = [errors] + items = [ErrorItemType.from_error(err) for err in errors] - error = ErrorGet.model_validate( + error_model = ErrorGet.model_validate( { "errors": items, # NOTE: deprecated! "status": http_error_cls.status_code, @@ -113,15 +122,14 @@ def create_http_error( ) assert not http_error_cls.empty_body # nosec + payload = wrap_as_envelope( - error=error.model_dump(mode="json", **RESPONSE_MODEL_POLICY) + error=error_model.model_dump(mode="json", **RESPONSE_MODEL_POLICY) ) return http_error_cls( reason=safe_status_message(status_reason), - text=json_dumps( - payload, - ), + text=json_dumps(payload), content_type=MIMETYPE_APPLICATION_JSON, ) diff --git a/packages/service-library/tests/aiohttp/test_rest_middlewares.py b/packages/service-library/tests/aiohttp/test_rest_middlewares.py index de5e80b85ae..8fac9803906 100644 --- a/packages/service-library/tests/aiohttp/test_rest_middlewares.py +++ b/packages/service-library/tests/aiohttp/test_rest_middlewares.py @@ -14,6 +14,7 @@ from aiohttp import web from aiohttp.test_utils import TestClient from common_library.json_serialization import json_dumps +from pytest_mock import MockerFixture from servicelib.aiohttp import status from servicelib.aiohttp.rest_middlewares import ( envelope_middleware_factory, @@ -269,3 +270,111 @@ async def test_raised_unhandled_exception( # log OEC assert "OEC:" in caplog.text + + +async def test_not_implemented_error_is_501(client: TestClient): + """Test that NotImplementedError is correctly mapped to HTTP 501 NOT IMPLEMENTED.""" + response = await client.get( + "/v1/raise_exception", params={"exc": NotImplementedError.__name__} + ) + assert response.status == status.HTTP_501_NOT_IMPLEMENTED + + # Check that the response is properly enveloped + payload = await response.json() + assert is_enveloped(payload) + + # Verify error details + data, error = unwrap_envelope(payload) + assert not data + assert error + assert error.get("status") == status.HTTP_501_NOT_IMPLEMENTED + + +async def test_timeout_error_is_504(client: TestClient): + """Test that TimeoutError is correctly mapped to HTTP 504 GATEWAY TIMEOUT.""" + response = await client.get( + "/v1/raise_exception", params={"exc": asyncio.TimeoutError.__name__} + ) + assert response.status == status.HTTP_504_GATEWAY_TIMEOUT + + # Check that the response is properly enveloped + payload = await response.json() + assert is_enveloped(payload) + + # Verify error details + data, error = unwrap_envelope(payload) + assert not data + assert error + assert error.get("status") == status.HTTP_504_GATEWAY_TIMEOUT + + +async def test_exception_in_non_api_route(client: TestClient): + """Test how exceptions are handled in routes not under the API path.""" + response = await client.get("/free/raise_exception") + + # This should be a raw exception, not processed by our middleware + assert response.status == status.HTTP_500_INTERNAL_SERVER_ERROR + + # Should not be enveloped since it's outside the API path + text = await response.text() + try: + # If it happens to be JSON, check it's not enveloped + payload = json.loads(text) + assert not is_enveloped(payload) + except json.JSONDecodeError: + # If it's not JSON, that's expected too + pass + + +async def test_http_ok_with_text_is_enveloped(client: TestClient): + """Test that HTTPOk with text is properly enveloped.""" + response = await client.get("/v1/raise_success_with_text") + assert response.status == status.HTTP_200_OK + + # Should be enveloped + payload = await response.json() + assert is_enveloped(payload) + + # Check the content was preserved + data, error = unwrap_envelope(payload) + assert not error + assert data + assert data.get("ok") is True + + +async def test_exception_in_handler_returns_500( + client: TestClient, mocker: MockerFixture +): + """Test that exceptions in the handler functions are caught and return 500.""" + + # Mock _handle_http_successful to raise an exception + def mocked_handler(*args, **kwargs): + msg = "Simulated error in handler" + raise ValueError(msg) + + mocker.patch( + "servicelib.aiohttp.rest_middlewares._handle_http_successful", + side_effect=mocked_handler, + ) + + # Trigger a successful HTTP response that will be processed by our mocked handler + response = await client.get( + "/v1/raise_exception", params={"exc": web.HTTPOk.__name__} + ) + + # Should return 500 since our handler raised an exception + assert response.status == status.HTTP_500_INTERNAL_SERVER_ERROR + + # Check that the response is properly enveloped + payload = await response.json() + assert is_enveloped(payload) + + # Verify error details + data, error = unwrap_envelope(payload) + assert not data + assert error + assert error.get("status") == status.HTTP_500_INTERNAL_SERVER_ERROR + + # Make sure there are no detailed error logs in production mode + assert not error.get("errors") + assert not error.get("logs")