@@ -2,7 +2,7 @@
import contextlib
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Final

import distributed
@@ -17,7 +17,7 @@
from distributed.worker_state_machine import TaskState
from models_library.progress_bar import ProgressReport
from models_library.rabbitmq_messages import LoggerRabbitMessage
from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch
from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch, log_context

from ..rabbitmq_plugin import get_rabbitmq_client

@@ -63,24 +63,23 @@ def get_current_task_resources() -> dict[str, float]:
@dataclass(slots=True, kw_only=True)
class TaskPublisher:
task_owner: TaskOwner
progress: distributed.Pub = field(init=False)
_last_published_progress_value: float = -1

def __post_init__(self) -> None:
self.progress = distributed.Pub(TaskProgressEvent.topic_name())

def publish_progress(self, report: ProgressReport) -> None:
rounded_value = round(report.percent_value, ndigits=2)
if rounded_value > self._last_published_progress_value:
with log_catch(logger=_logger, reraise=False):
with (
log_catch(logger=_logger, reraise=False),
log_context(
_logger, logging.DEBUG, msg=f"publish progress {rounded_value=}"
),
):
publish_event(
self.progress,
TaskProgressEvent.from_dask_worker(
progress=rounded_value, task_owner=self.task_owner
),
)
self._last_published_progress_value = rounded_value
_logger.debug("PROGRESS: %s", rounded_value)

async def publish_logs(
self,
@@ -169,7 +168,14 @@ async def periodicaly_check_if_aborted(task_name: str) -> None:
await periodically_checking_task


def publish_event(dask_pub: distributed.Pub, event: BaseTaskEvent) -> None:
def publish_event(
event: BaseTaskEvent,
) -> None:
"""never reraises, only CancellationError"""
with log_catch(_logger, reraise=False):
dask_pub.put(event.model_dump_json())
worker = get_worker()
_logger.debug("current worker %s", f"{worker=}")
with (
log_catch(_logger, reraise=False),
log_context(_logger, logging.DEBUG, msg=f"publishing {event=}"),
):
worker.log_event(TaskProgressEvent.topic_name(), event.model_dump_json())
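
Replacing distributed.Pub with the worker's structured-events API means the publisher no longer carries a connection object: Worker.log_event hands the payload to the scheduler, and clients read it back with Client.get_events or a Client.subscribe_topic callback. A minimal sketch of the round trip on a local cluster (topic name and payload are illustrative, not part of this PR):

import distributed

def _task() -> int:
    # inside a running task, get_worker() returns the hosting worker
    worker = distributed.get_worker()
    worker.log_event("some-topic", '{"progress": 0.5}')
    return 0

if __name__ == "__main__":
    with distributed.Client(processes=False) as client:
        assert client.submit(_task).result(timeout=30) == 0
        # events come back as (timestamp, message) tuples, oldest first
        for _timestamp, message in client.get_events("some-topic"):
            print(message)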
services/dask-sidecar/tests/unit/conftest.py (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ def app_environment(
model_dump_with_secrets(rabbit_service, show_secrets=True)
),
"SC_BOOT_MODE": "debug",
"SIDECAR_LOGLEVEL": "DEBUG",
"DASK_SIDECAR_LOGLEVEL": "DEBUG",
"SIDECAR_COMP_SERVICES_SHARED_VOLUME_NAME": "simcore_computational_shared_data",
"SIDECAR_COMP_SERVICES_SHARED_FOLDER": f"{shared_data_folder}",
},
services/dask-sidecar/tests/unit/test_utils_dask.py (163 changes: 62 additions & 101 deletions)
@@ -6,16 +6,19 @@

import asyncio
import concurrent.futures
import logging
import time
from collections.abc import AsyncIterator, Callable, Coroutine
from typing import Any
from unittest import mock

import distributed
import pytest
from common_library.async_tools import maybe_await
from dask_task_models_library.container_tasks.errors import TaskCancelledError
from dask_task_models_library.container_tasks.events import TaskProgressEvent
from dask_task_models_library.container_tasks.io import TaskCancelEventName
from dask_task_models_library.container_tasks.protocol import TaskOwner
from pytest_simcore.helpers.logging_tools import log_context
from simcore_service_dask_sidecar.utils.dask import (
_DEFAULT_MAX_RESOURCES,
TaskPublisher,
@@ -24,7 +27,7 @@
monitor_task_abortion,
publish_event,
)
from tenacity.asyncio import AsyncRetrying
from tenacity import Retrying
from tenacity.retry import retry_if_exception_type
from tenacity.stop import stop_after_delay
from tenacity.wait import wait_fixed
@@ -37,128 +40,85 @@
]


def test_publish_event(
dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
@pytest.fixture(params=["sync-dask-client", "async-dask-client"])
def dask_client_multi(
request: pytest.FixtureRequest,
dask_client: distributed.Client,
async_dask_client: distributed.Client,
) -> distributed.Client:
if request.param == "sync-dask-client":
return dask_client
return async_dask_client


@pytest.mark.parametrize(
"handler", [mock.Mock(), mock.AsyncMock()], ids=["sync-handler", "async-handler"]
)
async def test_publish_event(
dask_client_multi: distributed.Client,
job_id: str,
task_owner: TaskOwner,
handler: mock.Mock | mock.AsyncMock,
):
dask_pub = distributed.Pub("some_topic", client=dask_client)
dask_sub = distributed.Sub("some_topic", client=dask_client)
event_to_publish = TaskProgressEvent(
job_id=job_id,
msg="the log",
progress=1,
task_owner=task_owner,
)
publish_event(dask_pub=dask_pub, event=event_to_publish)

# NOTE: this test runs a sync dask client,
# and the CI sometimes struggles to run it in a reasonable time,
# hence the long timeout
message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
assert message is not None
assert isinstance(message, str)
received_task_log_event = TaskProgressEvent.model_validate_json(message)
assert received_task_log_event == event_to_publish


async def test_publish_event_async(
async_dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
):
dask_pub = distributed.Pub("some_topic", client=async_dask_client)
dask_sub = distributed.Sub("some_topic", client=async_dask_client)
event_to_publish = TaskProgressEvent(
job_id=job_id, msg="the log", progress=2, task_owner=task_owner
)
publish_event(dask_pub=dask_pub, event=event_to_publish)

# NOTE: this test runs a sync dask client,
# and the CI sometimes struggles to run it in a reasonable time,
# hence the long timeout
message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
assert isinstance(message, Coroutine)
message = await message
assert message is not None
received_task_log_event = TaskProgressEvent.model_validate_json(message)
assert received_task_log_event == event_to_publish

# NOTE: only 1 handler per topic is allowed at a time
dask_client_multi.subscribe_topic(TaskProgressEvent.topic_name(), handler)
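
The handler parametrization exercises both call styles subscribe_topic accepts: a plain callable and a coroutine function, each invoked once per event with a (timestamp, message) tuple. A hand-written equivalent of the mocked handler might look like this (sketch, not part of the PR):

import json

def progress_handler(event: tuple) -> None:
    # events arrive as (timestamp, message); message is the JSON dump
    _timestamp, message = event
    print("received progress event:", json.loads(message))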

@pytest.fixture
async def asyncio_task() -> AsyncIterator[Callable[[Coroutine], asyncio.Task]]:
created_tasks = []
def _worker_task() -> int:
with log_context(logging.INFO, "_worker_task"):

def _creator(coro: Coroutine) -> asyncio.Task:
task = asyncio.create_task(coro, name="pytest_asyncio_task")
created_tasks.append(task)
return task
async def _() -> int:
with log_context(logging.INFO, "_worker_task_async"):
publish_event(event_to_publish)
return 2

yield _creator
for task in created_tasks:
task.cancel()
return asyncio.run(_())

await asyncio.gather(*created_tasks, return_exceptions=True)
future = dask_client_multi.submit(_worker_task)
assert await maybe_await(future.result(timeout=DASK_TESTING_TIMEOUT_S)) == 2
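
maybe_await is what lets one test body serve both fixture flavors: the sync client's future.result(...) returns the value directly, while the async client returns an awaitable. A sketch of the helper's presumed behavior (the real implementation lives in common_library.async_tools):

import inspect
from typing import Any

async def maybe_await(value: Any) -> Any:
    # await only if the client handed back an awaitable, else pass through
    if inspect.isawaitable(value):
        return await value
    return value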


async def test_publish_event_async_using_task(
async_dask_client: distributed.Client,
asyncio_task: Callable[[Coroutine], asyncio.Task],
job_id: str,
task_owner: TaskOwner,
):
dask_pub = distributed.Pub("some_topic", client=async_dask_client)
dask_sub = distributed.Sub("some_topic", client=async_dask_client)
NUMBER_OF_MESSAGES = 1000
received_messages = []

async def _dask_sub_consumer_task(sub: distributed.Sub) -> None:
print("--> starting consumer task")
async for dask_event in sub:
print(f"received {dask_event}")
received_messages.append(dask_event)
print("<-- finished consumer task")

consumer_task = asyncio_task(_dask_sub_consumer_task(dask_sub))
assert consumer_task

async def _dask_publisher_task(pub: distributed.Pub) -> None:
print("--> starting publisher task")
for _ in range(NUMBER_OF_MESSAGES):
event_to_publish = TaskProgressEvent(
job_id=job_id,
progress=0.5,
task_owner=task_owner,
)
publish_event(dask_pub=pub, event=event_to_publish)
print("<-- finished publisher task")

publisher_task = asyncio_task(_dask_publisher_task(dask_pub))
assert publisher_task

async for attempt in AsyncRetrying(
retry=retry_if_exception_type(AssertionError),
stop=stop_after_delay(DASK_TESTING_TIMEOUT_S),
wait=wait_fixed(0.01),
for attempt in Retrying(
wait=wait_fixed(0.2),
stop=stop_after_delay(15),
reraise=True,
retry=retry_if_exception_type(AssertionError),
):
with attempt:
print(
f"checking number of received messages...currently {len(received_messages)}"
events = await maybe_await(
dask_client_multi.get_events(TaskProgressEvent.topic_name())
)
assert len(received_messages) == NUMBER_OF_MESSAGES
print("all expected messages received")
assert events is not None, "No events received"
assert isinstance(events, tuple)

handler.assert_called_with(events[-1])

def _wait_for_task_to_start() -> None:
start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
assert isinstance(events, tuple)
assert len(events) == 1
assert isinstance(events[0], tuple)
received_task_log_event = TaskProgressEvent.model_validate_json(events[0][1])
assert received_task_log_event == event_to_publish


def _wait_for_task_to_start(dask_client: distributed.Client) -> None:
start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
start_event.wait(timeout=DASK_TESTING_TIMEOUT_S)


def _notify_task_is_started_and_ready() -> None:
start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
def _notify_task_is_started_and_ready(dask_client: distributed.Client) -> None:
start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
start_event.set()
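
distributed.Event serves as a cluster-wide rendezvous here: the task sets a named event from inside the worker, and the test process waits on the same name before asserting anything. A minimal sketch of the pattern, assuming a local cluster:

import distributed

def _task() -> str:
    # signal readiness from inside the worker
    distributed.Event("task-started", distributed.get_worker().client).set()
    return "done"

if __name__ == "__main__":
    with distributed.Client(processes=False) as client:
        future = client.submit(_task)
        distributed.Event("task-started", client).wait(timeout=30)
        assert future.result(timeout=30) == "done"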


def _some_long_running_task() -> int:
assert is_current_task_aborted() is False
_notify_task_is_started_and_ready()
dask_client = distributed.get_worker().client
_notify_task_is_started_and_ready(dask_client)

for i in range(300):
print("running iteration", i)
@@ -176,7 +136,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
not work in distributed mode where an Event is necessary."""
# NOTE: this works because the cluster is in the same machine
future = dask_client.submit(_some_long_running_task)
_wait_for_task_to_start()
_wait_for_task_to_start(dask_client)
future.cancel()
assert future.cancelled()
with pytest.raises(concurrent.futures.CancelledError):
@@ -186,7 +146,7 @@ def test_task_is_aborted_using_event(dask_client: distributed.Client):
def test_task_is_aborted_using_event(dask_client: distributed.Client):
job_id = "myfake_job_id"
future = dask_client.submit(_some_long_running_task, key=job_id)
_wait_for_task_to_start()
_wait_for_task_to_start(dask_client)

dask_event = distributed.Event(TaskCancelEventName.format(job_id))
dask_event.set()
@@ -203,7 +163,8 @@ def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int:

async def _long_running_task_async() -> int:
task_publishers = TaskPublisher(task_owner=task_owner)
_notify_task_is_started_and_ready()
worker = distributed.get_worker()
_notify_task_is_started_and_ready(worker.client)
current_task = asyncio.current_task()
assert current_task
async with monitor_task_abortion(
@@ -229,7 +190,7 @@ def test_monitor_task_abortion(
future = dask_client.submit(
_some_long_running_task_with_monitoring, task_owner=task_owner, key=job_id
)
_wait_for_task_to_start()
_wait_for_task_to_start(dask_client)
# trigger cancellation
dask_event = distributed.Event(TaskCancelEventName.format(job_id))
dask_event.set()