Skip to content

Commit 5c95fe2

Browse files
🐛 make logstreaming callback safer (ITISFoundation#5633)
1 parent c5aa314 commit 5c95fe2

File tree

3 files changed

+81
-42
lines changed

3 files changed

+81
-42
lines changed

packages/service-library/src/servicelib/rabbitmq/_client.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import aio_pika
88
from pydantic import NonNegativeInt
99

10-
from ..logging_utils import log_context
10+
from ..logging_utils import log_catch, log_context
1111
from ._client_base import RabbitMQClientBase
1212
from ._models import MessageHandler, RabbitMessage
1313
from ._utils import (
@@ -82,7 +82,9 @@ async def _on_message(
8282
if not await message_handler(message.body):
8383
await _safe_nack(message_handler, max_retries_upon_error, message)
8484
except Exception: # pylint: disable=broad-exception-caught
85-
await _safe_nack(message_handler, max_retries_upon_error, message)
85+
_logger.exception("Exception raised when handling message")
86+
with log_catch(_logger, reraise=False):
87+
await _safe_nack(message_handler, max_retries_upon_error, message)
8688

8789

8890
@dataclass

services/api-server/src/simcore_service_api_server/services/log_streaming.py

Lines changed: 19 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55

66
from models_library.rabbitmq_messages import LoggerRabbitMessage
77
from models_library.users import UserID
8-
from pydantic import NonNegativeInt, ValidationError
8+
from pydantic import NonNegativeInt
9+
from servicelib.logging_utils import log_catch
910
from servicelib.rabbitmq import RabbitMQClient
1011

1112
from ..models.schemas.jobs import JobID, JobLog
@@ -31,7 +32,7 @@ class LogStreamerRegistionConflict(LogDistributionBaseException):
3132
class LogDistributor:
3233
def __init__(self, rabbitmq_client: RabbitMQClient):
3334
self._rabbit_client = rabbitmq_client
34-
self._log_streamers: dict[JobID, Queue] = {}
35+
self._log_streamers: dict[JobID, Queue[JobLog]] = {}
3536
self._queue_name: str
3637

3738
async def setup(self):
@@ -53,34 +54,24 @@ async def __aexit__(self, exc_type, exc, tb):
5354
await self.teardown()
5455

5556
async def _distribute_logs(self, data: bytes):
56-
try:
57-
got = LoggerRabbitMessage.parse_raw(
58-
data
59-
) # rabbitmq client safe_nacks the message if this deserialization fails
60-
except ValidationError as e:
61-
_logger.debug(
62-
"Could not parse log message from RabbitMQ in LogDistributor._distribute_logs"
57+
with log_catch(_logger, reraise=False):
58+
got = LoggerRabbitMessage.parse_raw(data)
59+
item = JobLog(
60+
job_id=got.project_id,
61+
node_id=got.node_id,
62+
log_level=got.log_level,
63+
messages=got.messages,
6364
)
64-
raise e
65-
_logger.debug(
66-
"LogDistributor._distribute_logs received message message from RabbitMQ: %s",
67-
got.json(),
68-
)
69-
item = JobLog(
70-
job_id=got.project_id,
71-
node_id=got.node_id,
72-
log_level=got.log_level,
73-
messages=got.messages,
74-
)
75-
queue = self._log_streamers.get(item.job_id)
76-
if queue is None:
77-
raise LogStreamerNotRegistered(
78-
f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
79-
)
80-
await queue.put(item)
81-
return True
65+
queue = self._log_streamers.get(item.job_id)
66+
if queue is None:
67+
raise LogStreamerNotRegistered(
68+
f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
69+
)
70+
await queue.put(item)
71+
return True
72+
return False
8273

83-
async def register(self, job_id: JobID, queue: Queue):
74+
async def register(self, job_id: JobID, queue: Queue[JobLog]):
8475
if job_id in self._log_streamers:
8576
raise LogStreamerRegistionConflict(
8677
f"A stream was already connected to {job_id=}. Only a single stream can be connected at the time"

services/api-server/tests/unit/test_services_rabbitmq.py

Lines changed: 58 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import AsyncIterable, Callable
1212
from contextlib import asynccontextmanager
1313
from datetime import datetime, timedelta
14-
from typing import Final, Iterable
14+
from typing import Final, Iterable, Literal
1515
from unittest.mock import AsyncMock
1616

1717
import httpx
@@ -24,9 +24,9 @@
2424
from models_library.projects import ProjectID
2525
from models_library.projects_nodes_io import NodeID
2626
from models_library.projects_state import RunningState
27-
from models_library.rabbitmq_messages import LoggerRabbitMessage
27+
from models_library.rabbitmq_messages import LoggerRabbitMessage, RabbitMessageBase
2828
from models_library.users import UserID
29-
from pydantic import parse_obj_as
29+
from pydantic import ValidationError, parse_obj_as
3030
from pytest_mock import MockerFixture, MockFixture
3131
from pytest_simcore.helpers.utils_envs import (
3232
EnvVarsDict,
@@ -170,15 +170,23 @@ def produce_logs(
170170
create_rabbitmq_client: Callable[[str], RabbitMQClient],
171171
user_id: UserID,
172172
):
173-
async def _go(name, project_id_=None, node_id_=None, messages_=None, level_=None):
173+
async def _go(
174+
name,
175+
project_id_=None,
176+
node_id_=None,
177+
messages_=None,
178+
level_=None,
179+
log_message: RabbitMessageBase | None = None,
180+
):
174181
rabbitmq_producer = create_rabbitmq_client(f"pytest_producer_{name}")
175-
log_message = LoggerRabbitMessage(
176-
user_id=user_id,
177-
project_id=project_id_ or faker.uuid4(),
178-
node_id=node_id_,
179-
messages=messages_ or [faker.text() for _ in range(10)],
180-
log_level=level_ or logging.INFO,
181-
)
182+
if log_message is None:
183+
log_message = LoggerRabbitMessage(
184+
user_id=user_id,
185+
project_id=project_id_ or faker.uuid4(),
186+
node_id=node_id_,
187+
messages=messages_ or [faker.text() for _ in range(10)],
188+
log_level=level_ or logging.INFO,
189+
)
182190
await rabbitmq_producer.publish(log_message.channel_name, log_message)
183191

184192
return _go
@@ -381,7 +389,6 @@ async def test_log_streamer_with_distributor(
381389
async def _log_publisher():
382390
while not computation_done():
383391
msg: str = faker.text()
384-
await asyncio.sleep(0.2)
385392
await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG)
386393
published_logs.append(msg)
387394

@@ -399,6 +406,45 @@ async def _log_publisher():
399406
assert published_logs == collected_messages
400407

401408

409+
async def test_log_streamer_not_raise_with_distributor(
410+
client: httpx.AsyncClient,
411+
app: FastAPI,
412+
user_id,
413+
project_id: ProjectID,
414+
node_id: NodeID,
415+
produce_logs: Callable,
416+
log_streamer_with_distributor: LogStreamer,
417+
faker: Faker,
418+
computation_done: Callable[[], bool],
419+
):
420+
class InvalidLoggerRabbitMessage(LoggerRabbitMessage):
421+
channel_name: Literal["simcore.services.logs.v2"] = "simcore.services.logs.v2"
422+
node_id: NodeID | None
423+
messages: int
424+
log_level: int = logging.INFO
425+
426+
def routing_key(self) -> str:
427+
return f"{self.project_id}.{self.log_level}"
428+
429+
log_rabbit_message = InvalidLoggerRabbitMessage(
430+
user_id=user_id,
431+
project_id=project_id,
432+
node_id=node_id,
433+
messages=100,
434+
log_level=logging.INFO,
435+
)
436+
with pytest.raises(ValidationError):
437+
LoggerRabbitMessage.parse_obj(log_rabbit_message.dict())
438+
439+
await produce_logs("expected", log_message=log_rabbit_message)
440+
441+
ii: int = 0
442+
async for log in log_streamer_with_distributor.log_generator():
443+
_ = JobLog.parse_raw(log)
444+
ii += 1
445+
assert ii == 0
446+
447+
402448
async def test_log_generator(mocker: MockFixture, faker: Faker):
403449
mocker.patch(
404450
"simcore_service_api_server.services.log_streaming.LogStreamer._project_done",

0 commit comments

Comments
 (0)