
Commit 37cf6d5

🐛 Deregister log streamer via starlette background task (#7626)
1 parent 975c581 commit 37cf6d5

3 files changed: 25 additions and 15 deletions

services/api-server/src/simcore_service_api_server/api/routes/solvers_jobs_getters.py

Lines changed: 4 additions & 2 deletions
@@ -3,6 +3,7 @@
 import logging
 from collections import deque
 from collections.abc import Callable
+from functools import partial
 from typing import Annotated, Any, Union
 from uuid import UUID
 
@@ -16,9 +17,9 @@
 from models_library.wallets import ZERO_CREDITS
 from pydantic import HttpUrl, NonNegativeInt
 from pydantic.types import PositiveInt
-from servicelib.fastapi.requests_decorators import cancel_on_disconnect
 from servicelib.logging_utils import log_context
 from sqlalchemy.ext.asyncio import AsyncEngine
+from starlette.background import BackgroundTask
 
 from ..._service_solvers import SolverService
 from ...exceptions.custom_errors import InsufficientCreditsError, MissingWalletError
@@ -478,7 +479,6 @@ async def get_job_pricing_unit(
     response_class=LogStreamingResponse,
     responses=_LOGSTREAM_STATUS_CODES,
 )
-@cancel_on_disconnect
 async def get_log_stream(
     request: Request,
     solver_key: SolverKeyId,
@@ -505,8 +505,10 @@ async def get_log_stream(
         log_distributor=log_distributor,
         log_check_timeout=log_check_timeout,
     )
+    await log_distributor.register(job_id, log_streamer.queue)
     return LogStreamingResponse(
         log_streamer.log_generator(),
+        background=BackgroundTask(partial(log_distributor.deregister, job_id)),
     )

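This first change inverts the cleanup flow: the route registers the streamer's queue before returning the response, and deregistration is handed to a starlette BackgroundTask that runs once the streaming response finishes, instead of relying on the removed @cancel_on_disconnect decorator and cleanup inside the generator. Below is a minimal, self-contained sketch of that pattern; the names (registry, deregister, stream_logs) are illustrative stand-ins, not the project's LogDistributor/LogStreamer API.

```python
import asyncio
from functools import partial

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()

# Hypothetical in-memory registry standing in for the project's LogDistributor.
registry: dict[str, asyncio.Queue[str]] = {}


async def deregister(job_id: str) -> None:
    # Runs after the response has finished streaming (or the client disconnected).
    registry.pop(job_id, None)


@app.get("/jobs/{job_id}/logs")
async def stream_logs(job_id: str) -> StreamingResponse:
    queue: asyncio.Queue[str] = asyncio.Queue()
    registry[job_id] = queue  # register *before* the response starts streaming

    async def log_generator():
        while True:
            try:
                line = await asyncio.wait_for(queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                return  # no logs within the timeout: end the stream
            yield line + "\n"

    return StreamingResponse(
        log_generator(),
        media_type="application/x-ndjson",
        background=BackgroundTask(partial(deregister, job_id)),
    )
```

BackgroundTask(func, *args, **kwargs) also binds arguments directly, so the functools.partial used in the diff is simply an equivalent way of attaching job_id to the deregistration call.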
services/api-server/src/simcore_service_api_server/services_http/log_streaming.py

Lines changed: 5 additions & 7 deletions
@@ -69,6 +69,7 @@ async def _distribute_logs(self, data: bytes):
         return False
 
     async def register(self, job_id: JobID, queue: Queue[JobLog]):
+        _logger.debug("Registering log streamer for job_id=%s", job_id)
         if job_id in self._log_streamers:
             raise LogStreamerRegistrationConflictError(job_id=job_id)
         self._log_streamers[job_id] = queue
@@ -77,13 +78,14 @@ async def register(self, job_id: JobID, queue: Queue[JobLog]):
         )
 
     async def deregister(self, job_id: JobID):
+        _logger.debug("Deregistering log streamer for job_id=%s", job_id)
         if job_id not in self._log_streamers:
             msg = f"No stream was connected to {job_id}."
             raise LogStreamerNotRegisteredError(details=msg, job_id=job_id)
         await self._rabbit_client.remove_topics(
             LoggerRabbitMessage.get_channel_name(), topics=[f"{job_id}.*"]
         )
-        del self._log_streamers[job_id]
+        self._log_streamers.pop(job_id)
 
     @property
     def iter_log_queue_sizes(self) -> Iterator[tuple[JobID, int]]:
@@ -103,7 +105,7 @@ def __init__(
     ):
         self._user_id = user_id
         self._director2_api = director2_api
-        self._queue: Queue[JobLog] = Queue()
+        self.queue: Queue[JobLog] = Queue()
         self._job_id: JobID = job_id
         self._log_distributor: LogDistributor = log_distributor
         self._log_check_timeout: NonNegativeInt = log_check_timeout
@@ -116,12 +118,11 @@ async def _project_done(self) -> bool:
 
     async def log_generator(self) -> AsyncIterable[str]:
         try:
-            await self._log_distributor.register(self._job_id, self._queue)
             done: bool = False
             while not done:
                 try:
                     log: JobLog = await asyncio.wait_for(
-                        self._queue.get(), timeout=self._log_check_timeout
+                        self.queue.get(), timeout=self._log_check_timeout
                     )
                     yield log.model_dump_json() + _NEW_LINE
                 except TimeoutError:
@@ -145,6 +146,3 @@ async def log_generator(self) -> AsyncIterable[str]:
                 )
             )
            yield ErrorGet(errors=[error_msg]).model_dump_json() + _NEW_LINE
-
-        finally:
-            await self._log_distributor.deregister(self._job_id)
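The streamer-side counterpart of that inversion: the queue becomes a public attribute so the route can register it with the LogDistributor, and log_generator no longer registers or deregisters anything, it only drains the queue. A stripped-down sketch of the resulting shape (illustrative only, omitting the job id, the distributor reference, and the completion check the real class keeps):

```python
import asyncio
from collections.abc import AsyncIterable


class MiniLogStreamer:
    """Simplified stand-in for LogStreamer after this commit (illustrative)."""

    def __init__(self, log_check_timeout: float = 5.0) -> None:
        # Public: the caller registers this queue with the log distributor.
        self.queue: asyncio.Queue[str] = asyncio.Queue()
        self._log_check_timeout = log_check_timeout

    async def log_generator(self) -> AsyncIterable[str]:
        # Only consumes; it no longer registers or deregisters itself.
        while True:
            try:
                line = await asyncio.wait_for(
                    self.queue.get(), timeout=self._log_check_timeout
                )
            except asyncio.TimeoutError:
                return
            yield line + "\n"
```

With the finally clause gone, deregistration no longer depends on the async generator being closed; it depends only on starlette running the response's background task.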

services/api-server/tests/unit/test_services_rabbitmq.py

Lines changed: 16 additions & 6 deletions
@@ -361,6 +361,7 @@ async def test_log_streamer_with_distributor(
     project_id: ProjectID,
     node_id: NodeID,
     produce_logs: Callable,
+    log_distributor: LogDistributor,
     log_streamer_with_distributor: LogStreamer,
     faker: Faker,
     computation_done: Callable[[], bool],
@@ -375,12 +376,21 @@ async def _log_publisher():
 
     publish_task = asyncio.create_task(_log_publisher())
 
+    @asynccontextmanager
+    async def registered_log_streamer():
+        await log_distributor.register(project_id, log_streamer_with_distributor.queue)
+        try:
+            yield
+        finally:
+            await log_distributor.deregister(project_id)
+
     collected_messages: list[str] = []
-    async for log in log_streamer_with_distributor.log_generator():
-        job_log: JobLog = JobLog.model_validate_json(log)
-        assert len(job_log.messages) == 1
-        assert job_log.job_id == project_id
-        collected_messages.append(job_log.messages[0])
+    async with registered_log_streamer():
+        async for log in log_streamer_with_distributor.log_generator():
+            job_log: JobLog = JobLog.model_validate_json(log)
+            assert len(job_log.messages) == 1
+            assert job_log.job_id == project_id
+            collected_messages.append(job_log.messages[0])
 
     if not publish_task.done():
         publish_task.cancel()
@@ -451,7 +461,7 @@ async def test_log_generator(mocker: MockFixture, faker: Faker):
         msg = faker.text()
         published_logs.append(msg)
         job_log.messages = [msg]
-        await log_streamer._queue.put(job_log)
+        await log_streamer.queue.put(job_log)
 
     collected_logs: list[str] = []
     async for log in log_streamer.log_generator():
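Since the streamer no longer registers itself, the test now performs the registration explicitly and wraps the register/deregister pair in an async context manager, so the queue is always deregistered even if an assertion fails mid-stream. The same pairing can be written generically; the sketch below uses a plain dict as a stand-in registry and made-up names rather than the test's fixtures.

```python
import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def registered(registry: dict[str, asyncio.Queue], job_id: str, queue: asyncio.Queue):
    # Pair registration with guaranteed deregistration.
    registry[job_id] = queue
    try:
        yield queue
    finally:
        registry.pop(job_id, None)


async def main() -> None:
    registry: dict[str, asyncio.Queue] = {}
    async with registered(registry, "job-1", asyncio.Queue()) as queue:
        await queue.put("hello")
        assert await queue.get() == "hello"
    assert "job-1" not in registry  # deregistered on exit


asyncio.run(main())
```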
