Skip to content

Commit 044199e

Browse files
committed
handle rpc client exceptions
1 parent 51ef078 commit 044199e

File tree

3 files changed

+76
-8
lines changed

3 files changed

+76
-8
lines changed

services/api-server/src/simcore_service_api_server/exceptions/service_errors_utils.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from collections.abc import Callable, Coroutine, Mapping
34
from contextlib import contextmanager
@@ -8,6 +9,7 @@
89
import httpx
910
from fastapi import HTTPException, status
1011
from pydantic import ValidationError
12+
from servicelib.rabbitmq._errors import RemoteMethodNotRegisteredError
1113
from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError
1214

1315
from ..models.schemas.errors import ErrorGet
@@ -50,8 +52,12 @@ class ToApiTuple(NamedTuple):
5052

5153

5254
# service to public-api status maps
53-
E = TypeVar("E", bound=BaseBackEndError)
54-
HttpStatusMap: TypeAlias = Mapping[ServiceHTTPStatus, E]
55+
BackEndErrorType = TypeVar("BackEndErrorType", bound=BaseBackEndError)
56+
RpcExceptionType = TypeVar(
57+
"RpcExceptionType", bound=Exception
58+
) # need more specific rpc exception base class
59+
HttpStatusMap: TypeAlias = Mapping[ServiceHTTPStatus, BackEndErrorType]
60+
RabbitMqRpcExceptionMap: TypeAlias = Mapping[RpcExceptionType, BackEndErrorType]
5561

5662

5763
def _get_http_exception_kwargs(
@@ -98,6 +104,7 @@ def _get_http_exception_kwargs(
98104
def service_exception_handler(
99105
service_name: str,
100106
http_status_map: HttpStatusMap,
107+
rpc_exception_map: RabbitMqRpcExceptionMap,
101108
**context,
102109
):
103110
status_code: int
@@ -126,35 +133,64 @@ def service_exception_handler(
126133
status_code=status_code, detail=detail, headers=headers
127134
) from exc
128135

136+
except BaseException as exc: # currently no baseclass for rpc errors
137+
if (
138+
type(exc) == asyncio.TimeoutError
139+
): # https://github.com/ITISFoundation/osparc-simcore/blob/master/packages/service-library/src/servicelib/rabbitmq/_client_rpc.py#L76
140+
raise HTTPException(
141+
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
142+
detail="Request to backend timed out",
143+
) from exc
144+
if type(exc) in {
145+
asyncio.exceptions.CancelledError,
146+
RuntimeError,
147+
RemoteMethodNotRegisteredError,
148+
}: # https://github.com/ITISFoundation/osparc-simcore/blob/master/packages/service-library/src/servicelib/rabbitmq/_client_rpc.py#L76
149+
raise HTTPException(
150+
status_code=status.HTTP_502_BAD_GATEWAY, detail="Request to failed"
151+
) from exc
152+
if backend_error_type := rpc_exception_map.get(type(exc)):
153+
raise backend_error_type(**context) from exc
154+
raise
155+
129156

130157
def service_exception_mapper(
131158
*,
132159
service_name: str,
133-
http_status_map: HttpStatusMap,
160+
http_status_map: HttpStatusMap = {},
161+
rpc_exception_map: RabbitMqRpcExceptionMap = {},
134162
) -> Callable[
135163
[Callable[Concatenate[Self, P], Coroutine[Any, Any, R]]],
136164
Callable[Concatenate[Self, P], Coroutine[Any, Any, R]],
137165
]:
138166
def _decorator(member_func: Callable[Concatenate[Self, P], Coroutine[Any, Any, R]]):
139-
_assert_correct_kwargs(func=member_func, status_map=http_status_map)
167+
_assert_correct_kwargs(
168+
func=member_func,
169+
exception_types=set(http_status_map.values()).union(
170+
set(rpc_exception_map.values())
171+
),
172+
)
140173

141174
@wraps(member_func)
142175
async def _wrapper(self: Self, *args: P.args, **kwargs: P.kwargs) -> R:
143-
with service_exception_handler(service_name, http_status_map, **kwargs):
176+
with service_exception_handler(
177+
service_name, http_status_map, rpc_exception_map, **kwargs
178+
):
144179
return await member_func(self, *args, **kwargs)
145180

146181
return _wrapper
147182

148183
return _decorator
149184

150185

151-
def _assert_correct_kwargs(func: Callable, status_map: HttpStatusMap):
186+
def _assert_correct_kwargs(func: Callable, exception_types: set[BackEndErrorType]):
152187
_required_kwargs = {
153188
name
154189
for name, param in signature(func).parameters.items()
155190
if param.kind == param.KEYWORD_ONLY
156191
}
157-
for exc_type in status_map.values():
192+
for exc_type in exception_types:
193+
assert isinstance(exc_type, type) # nosec
158194
_exception_inputs = exc_type.named_fields()
159195
assert _exception_inputs.issubset(
160196
_required_kwargs

services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
from dataclasses import dataclass
2+
from functools import partial
23

34
from fastapi import FastAPI
45
from fastapi_pagination import Page, create_page
56
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
67
from servicelib.rabbitmq.rpc_interfaces.webserver.licenses.licensed_items import (
78
get_licensed_items as _get_licensed_items,
89
)
10+
from simcore_service_api_server.exceptions.service_errors_utils import (
11+
service_exception_mapper,
12+
)
913
from simcore_service_api_server.models.pagination import PaginationParams
1014

1115
from ..models.schemas.model_adapter import LicensedItemGet
1216

17+
_exception_mapper = partial(service_exception_mapper, service_name="WebApiServer")
18+
1319

1420
@dataclass
1521
class WbApiRpcClient:
1622
_rabbitmq_rpc_client: RabbitMQRPCClient
1723

24+
@_exception_mapper(rpc_exception_map={})
1825
async def get_licensed_items(
1926
self, product_name: str, page_params: PaginationParams
2027
) -> Page[LicensedItemGet]:

services/api-server/tests/unit/test_licensed_items.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import pytest
24
from fastapi import status
35
from httpx import AsyncClient, BasicAuth
@@ -10,6 +12,7 @@
1012
from pydantic import TypeAdapter
1113
from pytest_mock import MockerFixture
1214
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
15+
from servicelib.rabbitmq._errors import RemoteMethodNotRegisteredError
1316
from simcore_service_api_server._meta import API_VTAG
1417
from simcore_service_api_server.models.pagination import Page
1518
from simcore_service_api_server.models.schemas.model_adapter import LicensedItemGet
@@ -32,7 +35,7 @@ def _get_dummy_rpc_client():
3235

3336
@pytest.fixture
3437
async def mock_wb_api_server_rcp(
35-
mock_rabbitmq_rpc_client: MockerFixture,
38+
mock_rabbitmq_rpc_client: MockerFixture, exception_to_raise: Exception | None
3639
) -> MockerFixture:
3740
async def _get_backend_licensed_items(
3841
rabbitmq_rpc_client: RabbitMQRPCClient,
@@ -41,6 +44,8 @@ async def _get_backend_licensed_items(
4144
offset: int,
4245
limit: int,
4346
) -> _LicensedItemGetPage:
47+
if exception_to_raise is not None:
48+
raise exception_to_raise
4449
extra = _LicensedItemGet.model_config.get("json_schema_extra")
4550
assert isinstance(extra, dict)
4651
examples = extra.get("examples")
@@ -58,9 +63,29 @@ async def _get_backend_licensed_items(
5863
return mock_rabbitmq_rpc_client
5964

6065

66+
@pytest.mark.parametrize("exception_to_raise", [None])
6167
async def test_get_licensed_items(
6268
mock_wb_api_server_rcp: MockerFixture, client: AsyncClient, auth: BasicAuth
6369
):
6470
resp = await client.get(f"{API_VTAG}/licensed-items/page", auth=auth)
6571
assert resp.status_code == status.HTTP_200_OK
6672
TypeAdapter(Page[LicensedItemGet]).validate_json(resp.text)
73+
74+
75+
@pytest.mark.parametrize("exception_to_raise", [asyncio.TimeoutError()])
76+
async def test_get_licensed_items_timeout(
77+
mock_wb_api_server_rcp: MockerFixture, client: AsyncClient, auth: BasicAuth
78+
):
79+
resp = await client.get(f"{API_VTAG}/licensed-items/page", auth=auth)
80+
assert resp.status_code == status.HTTP_504_GATEWAY_TIMEOUT
81+
82+
83+
@pytest.mark.parametrize(
84+
"exception_to_raise",
85+
[asyncio.CancelledError(), RuntimeError(), RemoteMethodNotRegisteredError()],
86+
)
87+
async def test_get_licensed_items_502(
88+
mock_wb_api_server_rcp: MockerFixture, client: AsyncClient, auth: BasicAuth
89+
):
90+
resp = await client.get(f"{API_VTAG}/licensed-items/page", auth=auth)
91+
assert resp.status_code == status.HTTP_502_BAD_GATEWAY

0 commit comments

Comments
 (0)