Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from aiohttp import web
from aiohttp.web_request import Request
from aiohttp.web_response import StreamResponse
from models_library.errors_classes import OsparcErrorMixin
from models_library.utils.json_serialization import json_dumps
from servicelib.error_codes import create_error_code

from ..logging_utils import create_troubleshotting_log_message, get_log_record_extra
from ..mimetype_constants import MIMETYPE_APPLICATION_JSON
from ..utils import is_production_environ
from .rest_models import ErrorItemType, ErrorType, LogMessageType
Expand All @@ -28,6 +31,7 @@
from .typing_extension import Handler, Middleware

DEFAULT_API_VERSION = "v0"
FMSG_INTERNAL_ERROR_USER_FRIENDLY = "Oops! Something went wrong, but we've noted it down and we'll sort it out ASAP. Thanks for your patience! [{}]"


_logger = logging.getLogger(__name__)
Expand All @@ -40,29 +44,41 @@ def is_api_request(request: web.Request, api_version: str) -> bool:

def error_middleware_factory(
api_version: str,
log_exceptions: bool = True,
) -> Middleware:
_is_prod: bool = is_production_environ()

def _process_and_raise_unexpected_error(request: web.BaseRequest, err: Exception):

error_code = create_error_code(err)
error_context: dict[str, Any] = {
"request.remote": f"{request.remote}",
"request.method": f"{request.method}",
"request.path": f"{request.path}",
}
if isinstance(err, OsparcErrorMixin):
error_context.update(err.error_context())

frontend_msg = FMSG_INTERNAL_ERROR_USER_FRIENDLY.format(error_code)
log_msg = create_troubleshotting_log_message(
message_to_user=frontend_msg,
error=err,
error_code=error_code,
error_context=error_context,
)

http_error = create_http_error(
err,
"Unexpected Server error",
frontend_msg,
web.HTTPInternalServerError,
skip_internal_error_details=_is_prod,
)

if log_exceptions:
_logger.error(
'Unexpected server error "%s" from access: %s "%s %s". Responding with status %s',
type(err),
request.remote,
request.method,
request.path,
http_error.status,
exc_info=err,
stack_info=True,
)
_logger.exception(
log_msg,
extra=get_log_record_extra(
error_code=error_code,
user_id=error_context.get("user_id"),
),
)
raise http_error

@web.middleware
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def create_http_error(
error = ErrorType(
errors=items,
status=http_error_cls.status_code,
message=items[0].message if items else default_message,
message=default_message,
)

assert not http_error_cls.empty_body # nosec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ async def _string_list_task(
await asyncio.sleep(sleep_time)
task_progress.update(message="generated item", percent=index / num_strings)
if fail:
raise RuntimeError("We were asked to fail!!")
msg = "We were asked to fail!!"
raise RuntimeError(msg)

# NOTE: this code is used just for the sake of not returning the default 200
return web.json_response(
Expand Down
182 changes: 153 additions & 29 deletions packages/service-library/tests/aiohttp/test_rest_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import asyncio
import json
import logging
from dataclasses import dataclass
from typing import Any

Expand All @@ -14,10 +15,12 @@
from models_library.utils.json_serialization import json_dumps
from servicelib.aiohttp import status
from servicelib.aiohttp.rest_middlewares import (
FMSG_INTERNAL_ERROR_USER_FRIENDLY,
envelope_middleware_factory,
error_middleware_factory,
)
from servicelib.aiohttp.rest_responses import is_enveloped, unwrap_envelope
from servicelib.error_codes import parse_error_code


@dataclass
Expand All @@ -26,9 +29,13 @@ class Data:
y: str = "foo"


class SomeUnexpectedError(Exception):
...


class Handlers:
@staticmethod
async def get_health_wrong(request: web.Request):
async def get_health_wrong(_request: web.Request):
return {
"name": __name__.split(".")[0],
"version": "1.0",
Expand All @@ -37,7 +44,7 @@ async def get_health_wrong(request: web.Request):
}

@staticmethod
async def get_health(request: web.Request):
async def get_health(_request: web.Request):
return {
"name": __name__.split(".")[0],
"version": "1.0",
Expand All @@ -46,42 +53,81 @@ async def get_health(request: web.Request):
}

@staticmethod
async def get_dict(request: web.Request):
async def get_dict(_request: web.Request):
return {"x": 3, "y": "3"}

@staticmethod
async def get_envelope(request: web.Request):
async def get_envelope(_request: web.Request):
data = {"x": 3, "y": "3"}
return {"error": None, "data": data}

@staticmethod
async def get_list(request: web.Request):
async def get_list(_request: web.Request):
return [{"x": 3, "y": "3"}] * 3

@staticmethod
async def get_attobj(request: web.Request):
async def get_obj(_request: web.Request):
return Data(3, "3")

@staticmethod
async def get_string(request: web.Request):
async def get_string(_request: web.Request):
return "foo"

@staticmethod
async def get_number(request: web.Request):
async def get_number(_request: web.Request):
return 3

@staticmethod
async def get_mixed(request: web.Request):
async def get_mixed(_request: web.Request):
return [{"x": 3, "y": "3", "z": [Data(3, "3")] * 2}] * 3

@classmethod
def get(cls, suffix):
def returns_value(cls, suffix):
handlers = cls()
coro = getattr(handlers, "get_" + suffix)
loop = asyncio.get_event_loop()
data = loop.run_until_complete(coro(None))
returned_value = loop.run_until_complete(coro(None))
return json.loads(json_dumps(returned_value))

EXPECTED_RAISE_UNEXPECTED_REASON = "Unexpected error"

@classmethod
async def raise_exception(cls, request: web.Request):
exc_name = request.query.get("exc")
match exc_name:
case NotImplementedError.__name__:
raise NotImplementedError
case asyncio.TimeoutError.__name__:
raise asyncio.TimeoutError
case web.HTTPOk.__name__:
raise web.HTTPOk # 2XX
case web.HTTPUnauthorized.__name__:
raise web.HTTPUnauthorized # 4XX
case web.HTTPServiceUnavailable.__name__:
raise web.HTTPServiceUnavailable # 5XX
case _: # unexpected
raise SomeUnexpectedError(cls.EXPECTED_RAISE_UNEXPECTED_REASON)

return json.loads(json_dumps(data))
@staticmethod
async def raise_error(_request: web.Request):
raise web.HTTPNotFound

@staticmethod
async def raise_error_with_reason(_request: web.Request):
raise web.HTTPNotFound(reason="I did not find it")

@staticmethod
async def raise_success(_request: web.Request):
raise web.HTTPOk

@staticmethod
async def raise_success_with_reason(_request: web.Request):
raise web.HTTPOk(reason="I'm ok")

@staticmethod
async def raise_success_with_text(_request: web.Request):
# NOTE: explicitly NOT enveloped!
raise web.HTTPOk(reason="I'm ok", text=json.dumps({"ok": True}))


@pytest.fixture
Expand All @@ -91,17 +137,36 @@ def client(event_loop, aiohttp_client):
# routes
app.router.add_routes(
[
web.get("/v1/health", Handlers.get_health, name="get_health"),
web.get("/v1/dict", Handlers.get_dict, name="get_dict"),
web.get("/v1/envelope", Handlers.get_envelope, name="get_envelope"),
web.get("/v1/list", Handlers.get_list, name="get_list"),
web.get("/v1/attobj", Handlers.get_attobj, name="get_attobj"),
web.get("/v1/string", Handlers.get_string, name="get_string"),
web.get("/v1/number", Handlers.get_number, name="get_number"),
web.get("/v1/mixed", Handlers.get_mixed, name="get_mixed"),
web.get(path, handler, name=handler.__name__)
for path, handler in [
("/v1/health", Handlers.get_health),
("/v1/dict", Handlers.get_dict),
("/v1/envelope", Handlers.get_envelope),
("/v1/list", Handlers.get_list),
("/v1/obj", Handlers.get_obj),
("/v1/string", Handlers.get_string),
("/v1/number", Handlers.get_number),
("/v1/mixed", Handlers.get_mixed),
# custom use cases
("/v1/raise_exception", Handlers.raise_exception),
("/v1/raise_error", Handlers.raise_error),
("/v1/raise_error_with_reason", Handlers.raise_error_with_reason),
("/v1/raise_success", Handlers.raise_success),
("/v1/raise_success_with_reason", Handlers.raise_success_with_reason),
("/v1/raise_success_with_text", Handlers.raise_success_with_text),
]
]
)

app.router.add_routes(
[
web.get(
"/free/raise_exception",
Handlers.raise_exception,
name="raise_exception_without_middleware",
)
]
)
# middlewares
app.middlewares.append(error_middleware_factory(api_version="/v1"))
app.middlewares.append(envelope_middleware_factory(api_version="/v1"))
Expand All @@ -112,14 +177,14 @@ def client(event_loop, aiohttp_client):
@pytest.mark.parametrize(
"path,expected_data",
[
("/health", Handlers.get("health")),
("/dict", Handlers.get("dict")),
("/envelope", Handlers.get("envelope")["data"]),
("/list", Handlers.get("list")),
("/attobj", Handlers.get("attobj")),
("/string", Handlers.get("string")),
("/number", Handlers.get("number")),
("/mixed", Handlers.get("mixed")),
("/health", Handlers.returns_value("health")),
("/dict", Handlers.returns_value("dict")),
("/envelope", Handlers.returns_value("envelope")["data"]),
("/list", Handlers.returns_value("list")),
("/obj", Handlers.returns_value("obj")),
("/string", Handlers.returns_value("string")),
("/number", Handlers.returns_value("number")),
("/mixed", Handlers.returns_value("mixed")),
],
)
async def test_envelope_middleware(path: str, expected_data: Any, client: TestClient):
Expand All @@ -133,7 +198,7 @@ async def test_envelope_middleware(path: str, expected_data: Any, client: TestCl
assert data == expected_data


async def test_404_not_found(client: TestClient):
async def test_404_not_found_when_entrypoint_not_exposed(client: TestClient):
response = await client.get("/some-invalid-address-outside-api")
payload = await response.text()
assert response.status == status.HTTP_404_NOT_FOUND, payload
Expand All @@ -147,3 +212,62 @@ async def test_404_not_found(client: TestClient):
data, error = unwrap_envelope(payload)
assert error
assert not data


async def test_raised_unhandled_exception(
client: TestClient, caplog: pytest.LogCaptureFixture
):
caplog.set_level(logging.ERROR)
response = await client.get("/v1/raise_exception")

# respond the client with 500
assert response.status == status.HTTP_500_INTERNAL_SERVER_ERROR

# response model
data, error = unwrap_envelope(await response.json())
assert not data
assert error

# user friendly message with OEC reference
assert "OEC" in error["message"]
parsed_oec = parse_error_code(error["message"]).pop()
assert FMSG_INTERNAL_ERROR_USER_FRIENDLY.format(parsed_oec) == error["message"]

# avoids details
assert not error.get("errors")
assert not error.get("logs")

# - log sufficient information to diagnose the issue
#
# ERROR servicelib.aiohttp.rest_middlewares:rest_middlewares.py:75 Oops! Something went wrong, but we've noted it down and we'll sort it out ASAP. Thanks for your patience! [OEC:128594540599840].
# {
# "exception_details": "Unexpected error",
# "error_code": "OEC:128594540599840",
# "context": {
# "request.remote": "127.0.0.1",
# "request.method": "GET",
# "request.path": "/v1/raise_exception"
# },
# "tip": null
# }
# Traceback (most recent call last):
# File "/osparc-simcore/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py", line 94, in _middleware_handler
# return await handler(request)
# ^^^^^^^^^^^^^^^^^^^^^^
# File "/osparc-simcore/packages/service-library/src/servicelib/aiohttp/rest_middlewares.py", line 186, in _middleware_handler
# resp = await handler(request)
# ^^^^^^^^^^^^^^^^^^^^^^
# File "/osparc-simcore/packages/service-library/tests/aiohttp/test_rest_middlewares.py", line 109, in raise_exception
# raise SomeUnexpectedError(cls.EXPECTED_RAISE_UNEXPECTED_REASON)
# tests.aiohttp.test_rest_middlewares.SomeUnexpectedError: Unexpected error

assert response.method in caplog.text
assert response.url.path in caplog.text
assert "exception_details" in caplog.text
assert "request.remote" in caplog.text
assert "context" in caplog.text
assert SomeUnexpectedError.__name__ in caplog.text
assert Handlers.EXPECTED_RAISE_UNEXPECTED_REASON in caplog.text

# log OEC
assert "OEC:" in caplog.text
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from ..users.api import get_user_id_from_gid, get_user_role
from ..users.exceptions import UserDefaultWalletNotFoundError
from ..utils_aiohttp import envelope_json_response
from ..wallets.errors import WalletNotEnoughCreditsError
from ..wallets.errors import WalletAccessForbiddenError, WalletNotEnoughCreditsError
from . import nodes_utils, projects_api
from ._common_models import ProjectPathParams, RequestContext
from ._nodes_api import NodeScreenshot, get_node_screenshots
Expand Down Expand Up @@ -120,6 +120,10 @@ async def wrapper(request: web.Request) -> web.StreamResponse:
raise web.HTTPConflict(reason=f"{exc}") from exc
except CatalogForbiddenError as exc:
raise web.HTTPForbidden(reason=f"{exc}") from exc
except WalletAccessForbiddenError as exc:
raise web.HTTPForbidden(
reason=f"Payment required, but the user lacks access to the project's linked wallet.: {exc}"
) from exc

return wrapper

Expand Down
Loading
Loading