Skip to content

Commit 33b4faa

Browse files
committed
dask-sidecar is free of Pub/Sub
1 parent 4cebb72 commit 33b4faa

File tree

2 files changed

+32
-45
lines changed

2 files changed

+32
-45
lines changed

services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import contextlib
33
import logging
44
from collections.abc import AsyncIterator
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
66
from typing import Final
77

88
import distributed
@@ -63,23 +63,23 @@ def get_current_task_resources() -> dict[str, float]:
6363
@dataclass(slots=True, kw_only=True)
6464
class TaskPublisher:
6565
task_owner: TaskOwner
66-
progress: distributed.Pub = field(init=False)
6766
_last_published_progress_value: float = -1
6867

69-
def __post_init__(self) -> None:
70-
self.progress = distributed.Pub(TaskProgressEvent.topic_name())
71-
7268
def publish_progress(self, report: ProgressReport) -> None:
7369
rounded_value = round(report.percent_value, ndigits=2)
7470
if rounded_value > self._last_published_progress_value:
75-
with log_catch(logger=_logger, reraise=False):
71+
with (
72+
log_catch(logger=_logger, reraise=False),
73+
log_context(
74+
_logger, logging.DEBUG, msg=f"publish progress {rounded_value=}"
75+
),
76+
):
7677
publish_event(
7778
TaskProgressEvent.from_dask_worker(
7879
progress=rounded_value, task_owner=self.task_owner
7980
),
8081
)
8182
self._last_published_progress_value = rounded_value
82-
_logger.debug("PROGRESS: %s", rounded_value)
8383

8484
async def publish_logs(
8585
self,

services/dask-sidecar/tests/unit/test_worker.py

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,6 @@ def dask_subsystem_mock(
106106
"pytest_jobid"
107107
)
108108
# mock dask event publishing
109-
dask_utils_publish_event_mock = mocker.patch(
110-
"simcore_service_dask_sidecar.utils.dask.distributed.Pub",
111-
autospec=True,
112-
)
113-
mocker.patch(
114-
"simcore_service_dask_sidecar.utils.dask.distributed.Sub",
115-
autospec=True,
116-
)
117109
mocker.patch(
118110
"simcore_service_dask_sidecar.utils.dask.is_current_task_aborted",
119111
autospec=True,
@@ -138,7 +130,6 @@ def dask_subsystem_mock(
138130
return {
139131
"dask_client": dask_client_mock,
140132
"dask_task_state": dask_task_mock,
141-
"dask_event_publish": dask_utils_publish_event_mock,
142133
}
143134

144135

@@ -562,10 +553,6 @@ def test_run_computational_sidecar_real_fct(
562553
sleeper_task.service_key,
563554
sleeper_task.service_version,
564555
)
565-
for event in [TaskProgressEvent]:
566-
dask_subsystem_mock["dask_event_publish"].assert_any_call(
567-
name=event.topic_name()
568-
)
569556
assert log_rabbit_client_parser.called
570557

571558
# check that the task produces expected logs
@@ -641,8 +628,26 @@ def test_run_multiple_computational_sidecar_dask(
641628

642629

643630
@pytest.fixture
644-
def progress_sub(dask_client: distributed.Client) -> distributed.Sub:
645-
return distributed.Sub(TaskProgressEvent.topic_name(), client=dask_client)
631+
def progress_event_handler(dask_client: distributed.Client) -> mock.Mock:
632+
mocked_parser = mock.Mock()
633+
dask_client.subscribe_topic(TaskProgressEvent.topic_name(), mocked_parser)
634+
return mocked_parser
635+
636+
637+
def _assert_parse_progresses_from_progress_event_handler(
638+
progress_event_handler: mock.Mock,
639+
) -> list[float]:
640+
assert progress_event_handler.called
641+
worker_progresses = [
642+
TaskProgressEvent.model_validate_json(msg.args[0][1]).progress
643+
for msg in progress_event_handler.call_args_list
644+
]
645+
assert worker_progresses == sorted(
646+
set(worker_progresses)
647+
), "ordering of progress values incorrectly sorted!"
648+
assert worker_progresses[0] == 0, "missing/incorrect initial progress value"
649+
assert worker_progresses[-1] == 1, "missing/incorrect final progress value"
650+
return worker_progresses
646651

647652

648653
@pytest.mark.parametrize(
@@ -651,7 +656,7 @@ def progress_sub(dask_client: distributed.Client) -> distributed.Sub:
651656
async def test_run_computational_sidecar_dask(
652657
app_environment: EnvVarsDict,
653658
sleeper_task: ServiceExampleParam,
654-
progress_sub: distributed.Sub,
659+
progress_event_handler: mock.Mock,
655660
mocked_get_image_labels: mock.Mock,
656661
s3_settings: S3Settings,
657662
log_rabbit_client_parser: mock.AsyncMock,
@@ -670,16 +675,8 @@ async def test_run_computational_sidecar_dask(
670675
assert isinstance(output_data, TaskOutputData)
671676

672677
# check that the task produces expected logs
673-
worker_progresses = [
674-
TaskProgressEvent.model_validate_json(msg).progress
675-
for msg in progress_sub.buffer
676-
]
677-
# check ordering
678-
assert worker_progresses == sorted(
679-
set(worker_progresses)
680-
), "ordering of progress values incorrectly sorted!"
681-
assert worker_progresses[0] == 0, "missing/incorrect initial progress value"
682-
assert worker_progresses[-1] == 1, "missing/incorrect final progress value"
678+
_assert_parse_progresses_from_progress_event_handler(progress_event_handler)
679+
683680
async for attempt in AsyncRetrying(
684681
wait=wait_fixed(1),
685682
stop=stop_after_delay(30),
@@ -731,7 +728,7 @@ async def test_run_computational_sidecar_dask(
731728
async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub(
732729
dask_client: distributed.Client,
733730
sidecar_task: Callable[..., ServiceExampleParam],
734-
progress_sub: distributed.Sub,
731+
progress_event_handler: mock.Mock,
735732
mocked_get_image_labels: mock.Mock,
736733
log_rabbit_client_parser: mock.AsyncMock,
737734
):
@@ -757,17 +754,7 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub
757754
assert isinstance(output_data, TaskOutputData)
758755

759756
# check that the task produces expected logs
760-
worker_progresses = [
761-
TaskProgressEvent.model_validate_json(msg).progress
762-
for msg in progress_sub.buffer
763-
]
764-
# check length
765-
assert len(worker_progresses) == len(
766-
set(worker_progresses)
767-
), "there are duplicate progresses!"
768-
assert sorted(worker_progresses) == worker_progresses, "ordering issue"
769-
assert worker_progresses[0] == 0, "missing/incorrect initial progress value"
770-
assert worker_progresses[-1] == 1, "missing/incorrect final progress value"
757+
_assert_parse_progresses_from_progress_event_handler(progress_event_handler)
771758

772759
async for attempt in AsyncRetrying(
773760
wait=wait_fixed(1),

0 commit comments

Comments
 (0)