Skip to content

Commit 17a8abf

Browse files
authored
♻️Dask-sidecar: remove dask Pub/Sub (#7660)
1 parent 50a4875 commit 17a8abf

File tree

10 files changed

+136
-232
lines changed

10 files changed

+136
-232
lines changed

services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import contextlib
33
import logging
44
from collections.abc import AsyncIterator
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
66
from typing import Final
77

88
import distributed
@@ -17,7 +17,7 @@
1717
from distributed.worker_state_machine import TaskState
1818
from models_library.progress_bar import ProgressReport
1919
from models_library.rabbitmq_messages import LoggerRabbitMessage
20-
from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch
20+
from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch, log_context
2121

2222
from ..rabbitmq_plugin import get_rabbitmq_client
2323

@@ -63,24 +63,23 @@ def get_current_task_resources() -> dict[str, float]:
6363
@dataclass(slots=True, kw_only=True)
6464
class TaskPublisher:
6565
task_owner: TaskOwner
66-
progress: distributed.Pub = field(init=False)
6766
_last_published_progress_value: float = -1
6867

69-
def __post_init__(self) -> None:
70-
self.progress = distributed.Pub(TaskProgressEvent.topic_name())
71-
7268
def publish_progress(self, report: ProgressReport) -> None:
7369
rounded_value = round(report.percent_value, ndigits=2)
7470
if rounded_value > self._last_published_progress_value:
75-
with log_catch(logger=_logger, reraise=False):
71+
with (
72+
log_catch(logger=_logger, reraise=False),
73+
log_context(
74+
_logger, logging.DEBUG, msg=f"publish progress {rounded_value=}"
75+
),
76+
):
7677
publish_event(
77-
self.progress,
7878
TaskProgressEvent.from_dask_worker(
7979
progress=rounded_value, task_owner=self.task_owner
8080
),
8181
)
8282
self._last_published_progress_value = rounded_value
83-
_logger.debug("PROGRESS: %s", rounded_value)
8483

8584
async def publish_logs(
8685
self,
@@ -169,7 +168,14 @@ async def periodicaly_check_if_aborted(task_name: str) -> None:
169168
await periodically_checking_task
170169

171170

172-
def publish_event(dask_pub: distributed.Pub, event: BaseTaskEvent) -> None:
171+
def publish_event(
172+
event: BaseTaskEvent,
173+
) -> None:
173174
"""never reraises, only CancellationError"""
174-
with log_catch(_logger, reraise=False):
175-
dask_pub.put(event.model_dump_json())
175+
worker = get_worker()
176+
_logger.debug("current worker %s", f"{worker=}")
177+
with (
178+
log_catch(_logger, reraise=False),
179+
log_context(_logger, logging.DEBUG, msg=f"publishing {event=}"),
180+
):
181+
worker.log_event(TaskProgressEvent.topic_name(), event.model_dump_json())

services/dask-sidecar/tests/unit/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def app_environment(
9999
model_dump_with_secrets(rabbit_service, show_secrets=True)
100100
),
101101
"SC_BOOT_MODE": "debug",
102-
"SIDECAR_LOGLEVEL": "DEBUG",
102+
"DASK_SIDECAR_LOGLEVEL": "DEBUG",
103103
"SIDECAR_COMP_SERVICES_SHARED_VOLUME_NAME": "simcore_computational_shared_data",
104104
"SIDECAR_COMP_SERVICES_SHARED_FOLDER": f"{shared_data_folder}",
105105
},

services/dask-sidecar/tests/unit/test_utils_dask.py

Lines changed: 62 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66

77
import asyncio
88
import concurrent.futures
9+
import logging
910
import time
10-
from collections.abc import AsyncIterator, Callable, Coroutine
1111
from typing import Any
12+
from unittest import mock
1213

1314
import distributed
1415
import pytest
16+
from common_library.async_tools import maybe_await
1517
from dask_task_models_library.container_tasks.errors import TaskCancelledError
1618
from dask_task_models_library.container_tasks.events import TaskProgressEvent
1719
from dask_task_models_library.container_tasks.io import TaskCancelEventName
1820
from dask_task_models_library.container_tasks.protocol import TaskOwner
21+
from pytest_simcore.helpers.logging_tools import log_context
1922
from simcore_service_dask_sidecar.utils.dask import (
2023
_DEFAULT_MAX_RESOURCES,
2124
TaskPublisher,
@@ -24,7 +27,7 @@
2427
monitor_task_abortion,
2528
publish_event,
2629
)
27-
from tenacity.asyncio import AsyncRetrying
30+
from tenacity import Retrying
2831
from tenacity.retry import retry_if_exception_type
2932
from tenacity.stop import stop_after_delay
3033
from tenacity.wait import wait_fixed
@@ -37,128 +40,85 @@
3740
]
3841

3942

40-
def test_publish_event(
41-
dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
43+
@pytest.fixture(params=["sync-dask-client", "async-dask-client"])
44+
def dask_client_multi(
45+
request: pytest.FixtureRequest,
46+
dask_client: distributed.Client,
47+
async_dask_client: distributed.Client,
48+
) -> distributed.Client:
49+
if request.param == "sync-dask-client":
50+
return dask_client
51+
return async_dask_client
52+
53+
54+
@pytest.mark.parametrize(
55+
"handler", [mock.Mock(), mock.AsyncMock()], ids=["sync-handler", "async-handler"]
56+
)
57+
async def test_publish_event(
58+
dask_client_multi: distributed.Client,
59+
job_id: str,
60+
task_owner: TaskOwner,
61+
handler: mock.Mock | mock.AsyncMock,
4262
):
43-
dask_pub = distributed.Pub("some_topic", client=dask_client)
44-
dask_sub = distributed.Sub("some_topic", client=dask_client)
4563
event_to_publish = TaskProgressEvent(
4664
job_id=job_id,
4765
msg="the log",
4866
progress=1,
4967
task_owner=task_owner,
5068
)
51-
publish_event(dask_pub=dask_pub, event=event_to_publish)
52-
53-
# NOTE: this tests runs a sync dask client,
54-
# and the CI seems to have sometimes difficulties having this run in a reasonable time
55-
# hence the long time out
56-
message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
57-
assert message is not None
58-
assert isinstance(message, str)
59-
received_task_log_event = TaskProgressEvent.model_validate_json(message)
60-
assert received_task_log_event == event_to_publish
61-
62-
63-
async def test_publish_event_async(
64-
async_dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
65-
):
66-
dask_pub = distributed.Pub("some_topic", client=async_dask_client)
67-
dask_sub = distributed.Sub("some_topic", client=async_dask_client)
68-
event_to_publish = TaskProgressEvent(
69-
job_id=job_id, msg="the log", progress=2, task_owner=task_owner
70-
)
71-
publish_event(dask_pub=dask_pub, event=event_to_publish)
72-
73-
# NOTE: this tests runs a sync dask client,
74-
# and the CI seems to have sometimes difficulties having this run in a reasonable time
75-
# hence the long time out
76-
message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
77-
assert isinstance(message, Coroutine)
78-
message = await message
79-
assert message is not None
80-
received_task_log_event = TaskProgressEvent.model_validate_json(message)
81-
assert received_task_log_event == event_to_publish
8269

70+
# NOTE: only 1 handler per topic is allowed at a time
71+
dask_client_multi.subscribe_topic(TaskProgressEvent.topic_name(), handler)
8372

84-
@pytest.fixture
85-
async def asyncio_task() -> AsyncIterator[Callable[[Coroutine], asyncio.Task]]:
86-
created_tasks = []
73+
def _worker_task() -> int:
74+
with log_context(logging.INFO, "_worker_task"):
8775

88-
def _creator(coro: Coroutine) -> asyncio.Task:
89-
task = asyncio.create_task(coro, name="pytest_asyncio_task")
90-
created_tasks.append(task)
91-
return task
76+
async def _() -> int:
77+
with log_context(logging.INFO, "_worker_task_async"):
78+
publish_event(event_to_publish)
79+
return 2
9280

93-
yield _creator
94-
for task in created_tasks:
95-
task.cancel()
81+
return asyncio.run(_())
9682

97-
await asyncio.gather(*created_tasks, return_exceptions=True)
83+
future = dask_client_multi.submit(_worker_task)
84+
assert await maybe_await(future.result(timeout=DASK_TESTING_TIMEOUT_S)) == 2
9885

99-
100-
async def test_publish_event_async_using_task(
101-
async_dask_client: distributed.Client,
102-
asyncio_task: Callable[[Coroutine], asyncio.Task],
103-
job_id: str,
104-
task_owner: TaskOwner,
105-
):
106-
dask_pub = distributed.Pub("some_topic", client=async_dask_client)
107-
dask_sub = distributed.Sub("some_topic", client=async_dask_client)
108-
NUMBER_OF_MESSAGES = 1000
109-
received_messages = []
110-
111-
async def _dask_sub_consumer_task(sub: distributed.Sub) -> None:
112-
print("--> starting consumer task")
113-
async for dask_event in sub:
114-
print(f"received {dask_event}")
115-
received_messages.append(dask_event)
116-
print("<-- finished consumer task")
117-
118-
consumer_task = asyncio_task(_dask_sub_consumer_task(dask_sub))
119-
assert consumer_task
120-
121-
async def _dask_publisher_task(pub: distributed.Pub) -> None:
122-
print("--> starting publisher task")
123-
for _ in range(NUMBER_OF_MESSAGES):
124-
event_to_publish = TaskProgressEvent(
125-
job_id=job_id,
126-
progress=0.5,
127-
task_owner=task_owner,
128-
)
129-
publish_event(dask_pub=pub, event=event_to_publish)
130-
print("<-- finished publisher task")
131-
132-
publisher_task = asyncio_task(_dask_publisher_task(dask_pub))
133-
assert publisher_task
134-
135-
async for attempt in AsyncRetrying(
136-
retry=retry_if_exception_type(AssertionError),
137-
stop=stop_after_delay(DASK_TESTING_TIMEOUT_S),
138-
wait=wait_fixed(0.01),
86+
for attempt in Retrying(
87+
wait=wait_fixed(0.2),
88+
stop=stop_after_delay(15),
13989
reraise=True,
90+
retry=retry_if_exception_type(AssertionError),
14091
):
14192
with attempt:
142-
print(
143-
f"checking number of received messages...currently {len(received_messages)}"
93+
events = await maybe_await(
94+
dask_client_multi.get_events(TaskProgressEvent.topic_name())
14495
)
145-
assert len(received_messages) == NUMBER_OF_MESSAGES
146-
print("all expected messages received")
96+
assert events is not None, "No events received"
97+
assert isinstance(events, tuple)
14798

99+
handler.assert_called_with(events[-1])
148100

149-
def _wait_for_task_to_start() -> None:
150-
start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
101+
assert isinstance(events, tuple)
102+
assert len(events) == 1
103+
assert isinstance(events[0], tuple)
104+
received_task_log_event = TaskProgressEvent.model_validate_json(events[0][1])
105+
assert received_task_log_event == event_to_publish
106+
107+
108+
def _wait_for_task_to_start(dask_client: distributed.Client) -> None:
109+
start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
151110
start_event.wait(timeout=DASK_TESTING_TIMEOUT_S)
152111

153112

154-
def _notify_task_is_started_and_ready() -> None:
155-
start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
113+
def _notify_task_is_started_and_ready(dask_client: distributed.Client) -> None:
114+
start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
156115
start_event.set()
157116

158117

159118
def _some_long_running_task() -> int:
160119
assert is_current_task_aborted() is False
161-
_notify_task_is_started_and_ready()
120+
dask_client = distributed.get_worker().client
121+
_notify_task_is_started_and_ready(dask_client)
162122

163123
for i in range(300):
164124
print("running iteration", i)
@@ -176,7 +136,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
176136
not work in distributed mode where an Event is necessary."""
177137
# NOTE: this works because the cluster is in the same machine
178138
future = dask_client.submit(_some_long_running_task)
179-
_wait_for_task_to_start()
139+
_wait_for_task_to_start(dask_client)
180140
future.cancel()
181141
assert future.cancelled()
182142
with pytest.raises(concurrent.futures.CancelledError):
@@ -186,7 +146,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
186146
def test_task_is_aborted_using_event(dask_client: distributed.Client):
187147
job_id = "myfake_job_id"
188148
future = dask_client.submit(_some_long_running_task, key=job_id)
189-
_wait_for_task_to_start()
149+
_wait_for_task_to_start(dask_client)
190150

191151
dask_event = distributed.Event(TaskCancelEventName.format(job_id))
192152
dask_event.set()
@@ -203,7 +163,8 @@ def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int:
203163

204164
async def _long_running_task_async() -> int:
205165
task_publishers = TaskPublisher(task_owner=task_owner)
206-
_notify_task_is_started_and_ready()
166+
worker = distributed.get_worker()
167+
_notify_task_is_started_and_ready(worker.client)
207168
current_task = asyncio.current_task()
208169
assert current_task
209170
async with monitor_task_abortion(
@@ -229,7 +190,7 @@ def test_monitor_task_abortion(
229190
future = dask_client.submit(
230191
_some_long_running_task_with_monitoring, task_owner=task_owner, key=job_id
231192
)
232-
_wait_for_task_to_start()
193+
_wait_for_task_to_start(dask_client)
233194
# trigger cancellation
234195
dask_event = distributed.Event(TaskCancelEventName.format(job_id))
235196
dask_event.set()

0 commit comments

Comments
 (0)