Skip to content
Merged
237 changes: 151 additions & 86 deletions packages/service-library/src/servicelib/aiohttp/rest_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any

from aiohttp import web
from aiohttp.web_exceptions import HTTPError
from aiohttp.web_request import Request
from aiohttp.web_response import StreamResponse
from common_library.error_codes import create_error_code
Expand All @@ -18,6 +19,7 @@
from ..mimetype_constants import MIMETYPE_APPLICATION_JSON
from ..rest_responses import is_enveloped_from_map, is_enveloped_from_text
from ..utils import is_production_environ
from . import status
from .rest_responses import (
create_data_response,
create_http_error,
Expand All @@ -26,6 +28,7 @@
)
from .rest_utils import EnvelopeFactory
from .typing_extension import Handler, Middleware
from .web_exceptions_extension import get_http_error_class_or_none

DEFAULT_API_VERSION = "v0"
_FMSG_INTERNAL_ERROR_USER_FRIENDLY = (
Expand All @@ -42,110 +45,172 @@ def is_api_request(request: web.Request, api_version: str) -> bool:
return bool(request.path.startswith(base_path))


def error_middleware_factory( # noqa: C901
api_version: str,
) -> Middleware:
_is_prod: bool = is_production_environ()
def _handle_unexpected_exception_as_500(
request: web.BaseRequest,
exception: Exception,
*,
skip_internal_error_details: bool,
) -> web.HTTPInternalServerError:
"""Process unexpected exceptions and return them as HTTP errors with proper formatting.

IMPORTANT: this function cannot throw exceptions, as it is called
"""
error_code = create_error_code(exception)
error_context: dict[str, Any] = {
"request.remote": f"{request.remote}",
"request.method": f"{request.method}",
"request.path": f"{request.path}",
}

user_error_msg = _FMSG_INTERNAL_ERROR_USER_FRIENDLY

http_error = create_http_error(
exception,
user_error_msg,
web.HTTPInternalServerError,
skip_internal_error_details=skip_internal_error_details,
error_code=error_code,
)

error_context["http_error"] = http_error

def _process_and_raise_unexpected_error(request: web.BaseRequest, err: Exception):
error_code = create_error_code(err)
error_context: dict[str, Any] = {
"request.remote": f"{request.remote}",
"request.method": f"{request.method}",
"request.path": f"{request.path}",
}

user_error_msg = _FMSG_INTERNAL_ERROR_USER_FRIENDLY
http_error = create_http_error(
err,
_logger.exception(
**create_troubleshotting_log_kwargs(
user_error_msg,
web.HTTPInternalServerError,
skip_internal_error_details=_is_prod,
error=exception,
error_context=error_context,
error_code=error_code,
)
_logger.exception(
**create_troubleshotting_log_kwargs(
user_error_msg,
error=err,
error_context=error_context,
error_code=error_code,
)
)
return http_error


def _handle_http_error(
request: web.BaseRequest, exception: web.HTTPError
) -> web.HTTPError:
"""Handle standard HTTP errors by ensuring they're properly formatted."""
assert request # nosec
exception.content_type = MIMETYPE_APPLICATION_JSON
if exception.reason:
exception.set_status(
exception.status, safe_status_message(message=exception.reason)
)

if not exception.text or not is_enveloped_from_text(exception.text):
error_message = exception.text or exception.reason or "Unexpected error"
error_model = ErrorGet(
errors=[
ErrorItemType.from_error(exception),
],
status=exception.status,
logs=[
LogMessageType(message=error_message, level="ERROR"),
],
message=error_message,
)
exception.text = EnvelopeFactory(error=error_model).as_text()

return exception


def _handle_http_successful(
request: web.Request, exception: web.HTTPSuccessful
) -> web.HTTPSuccessful:
"""Handle successful HTTP responses, ensuring they're properly enveloped."""
assert request # nosec

exception.content_type = MIMETYPE_APPLICATION_JSON
if exception.reason:
exception.set_status(
exception.status, safe_status_message(message=exception.reason)
)

if exception.text:
payload = json_loads(exception.text)
if not is_enveloped_from_map(payload):
payload = wrap_as_envelope(data=payload)
exception.text = json_dumps(payload)

return exception


def _handle_exception_as_http_error(
request: web.Request,
exception: Exception,
status_code: int,
*,
skip_internal_error_details: bool,
) -> HTTPError:
"""
Generic handler for exceptions that map to specific HTTP status codes.
Converts the status code to the appropriate HTTP error class and creates a response.
"""
assert request # nosec

http_error_cls = get_http_error_class_or_none(status_code)
if http_error_cls is None:
msg = (
f"No HTTP error class found for status code {status_code}, falling back to 500",
)
raise http_error
raise ValueError(msg)

return create_http_error(
exception,
f"{exception}",
http_error_cls,
skip_internal_error_details=skip_internal_error_details,
)


def error_middleware_factory(api_version: str) -> Middleware:
_is_prod: bool = is_production_environ()

@web.middleware
async def _middleware_handler(request: web.Request, handler: Handler): # noqa: C901
async def _middleware_handler(request: web.Request, handler: Handler):
"""
Ensure all error raised are properly enveloped and json responses
"""
if not is_api_request(request, api_version):
return await handler(request)

# FIXME: review when to send info to client and when not!
try:
return await handler(request)
try:
result = await handler(request)

except web.HTTPError as err:

err.content_type = MIMETYPE_APPLICATION_JSON
if err.reason:
err.set_status(err.status, safe_status_message(message=err.reason))

if not err.text or not is_enveloped_from_text(err.text):
error_message = err.text or err.reason or "Unexpected error"
error_model = ErrorGet(
errors=[
ErrorItemType.from_error(err),
],
status=err.status,
logs=[
LogMessageType(message=error_message, level="ERROR"),
],
message=error_message,
except web.HTTPError as exc: # 4XX and 5XX raised as exceptions
result = _handle_http_error(request, exc)

except web.HTTPSuccessful as exc: # 2XX rased as exceptions
result = _handle_http_successful(request, exc)

except web.HTTPRedirection as exc: # 3XX raised as exceptions
result = exc

except NotImplementedError as exc:
result = _handle_exception_as_http_error(
request,
exc,
status.HTTP_501_NOT_IMPLEMENTED,
skip_internal_error_details=_is_prod,
)
err.text = EnvelopeFactory(error=error_model).as_text()

raise

except web.HTTPSuccessful as err:
err.content_type = MIMETYPE_APPLICATION_JSON
if err.reason:
err.set_status(err.status, safe_status_message(message=err.reason))

if err.text:
try:
payload = json_loads(err.text)
if not is_enveloped_from_map(payload):
payload = wrap_as_envelope(data=payload)
err.text = json_dumps(payload)
except Exception as other_error: # pylint: disable=broad-except
_process_and_raise_unexpected_error(request, other_error)
raise

except web.HTTPRedirection as err:
_logger.debug("Redirected to %s", err)
raise

except NotImplementedError as err:
http_error = create_http_error(
err,
f"{err}",
web.HTTPNotImplemented,
skip_internal_error_details=_is_prod,
)
raise http_error from err

except TimeoutError as err:
http_error = create_http_error(
err,
f"{err}",
web.HTTPGatewayTimeout,
skip_internal_error_details=_is_prod,

except TimeoutError as exc:
result = _handle_exception_as_http_error(
request,
exc,
status.HTTP_504_GATEWAY_TIMEOUT,
skip_internal_error_details=_is_prod,
)

except Exception as exc: # pylint: disable=broad-except
#
# Last resort for unexpected exceptions (including those raise by the exception handlers!)
#
result = _handle_unexpected_exception_as_500(
request, exc, skip_internal_error_details=_is_prod
)
raise http_error from err

except Exception as err: # pylint: disable=broad-except
_process_and_raise_unexpected_error(request, err)
return result

# adds identifier (mostly for debugging)
setattr( # noqa: B010
Expand Down
38 changes: 23 additions & 15 deletions packages/service-library/src/servicelib/aiohttp/rest_responses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Final, TypedDict
from typing import Any, Final, TypedDict, TypeVar

from aiohttp import web
from aiohttp.web_exceptions import HTTPError
Expand All @@ -10,7 +10,7 @@
from ..mimetype_constants import MIMETYPE_APPLICATION_JSON
from ..rest_constants import RESPONSE_MODEL_POLICY
from ..rest_responses import is_enveloped
from ..status_codes_utils import get_code_description, is_error
from ..status_codes_utils import get_code_description, get_code_display_name, is_error


class EnvelopeDict(TypedDict):
Expand Down Expand Up @@ -69,41 +69,50 @@ def safe_status_message(
return flat_message[: max_length - 3] + "..."


T_HTTPError = TypeVar("T_HTTPError", bound=HTTPError)


def create_http_error(
errors: list[Exception] | Exception,
error_message: str | None = None,
http_error_cls: type[HTTPError] = web.HTTPInternalServerError,
http_error_cls: type[
T_HTTPError
] = web.HTTPInternalServerError, # type: ignore[assignment]
*,
status_reason: str | None = None,
skip_internal_error_details: bool = False,
error_code: ErrorCodeStr | None = None,
) -> HTTPError:
) -> T_HTTPError:
"""
- Response body conforms OAS schema model
- Can skip internal details when 500 status e.g. to avoid transmitting server
exceptions to the client in production
"""
if not isinstance(errors, list):
errors = [errors]

is_internal_error = bool(http_error_cls == web.HTTPInternalServerError)

status_reason = status_reason or get_code_description(http_error_cls.status_code)
status_reason = status_reason or get_code_display_name(http_error_cls.status_code)
error_message = error_message or get_code_description(http_error_cls.status_code)

assert len(status_reason) < MAX_STATUS_MESSAGE_LENGTH # nosec

# WARNING: do not refactor too much this function withouth considering how
# front-end handle errors. i.e. please sync with front-end developers before
# changing the workflows in this function

is_internal_error = bool(http_error_cls == web.HTTPInternalServerError)
if is_internal_error and skip_internal_error_details:
error = ErrorGet.model_validate(
error_model = ErrorGet.model_validate(
{
"status": http_error_cls.status_code,
"message": error_message,
"support_id": error_code,
}
)
else:
if not isinstance(errors, list):
errors = [errors]

items = [ErrorItemType.from_error(err) for err in errors]
error = ErrorGet.model_validate(
error_model = ErrorGet.model_validate(
{
"errors": items, # NOTE: deprecated!
"status": http_error_cls.status_code,
Expand All @@ -113,15 +122,14 @@ def create_http_error(
)

assert not http_error_cls.empty_body # nosec

payload = wrap_as_envelope(
error=error.model_dump(mode="json", **RESPONSE_MODEL_POLICY)
error=error_model.model_dump(mode="json", **RESPONSE_MODEL_POLICY)
)

return http_error_cls(
reason=safe_status_message(status_reason),
text=json_dumps(
payload,
),
text=json_dumps(payload),
content_type=MIMETYPE_APPLICATION_JSON,
)

Expand Down
Loading
Loading