diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py index feab5e4d632..1b4b40c8d34 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py @@ -2,7 +2,7 @@ import contextlib import logging from collections.abc import AsyncIterator -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Final import distributed @@ -17,7 +17,7 @@ from distributed.worker_state_machine import TaskState from models_library.progress_bar import ProgressReport from models_library.rabbitmq_messages import LoggerRabbitMessage -from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch +from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch, log_context from ..rabbitmq_plugin import get_rabbitmq_client @@ -63,24 +63,23 @@ def get_current_task_resources() -> dict[str, float]: @dataclass(slots=True, kw_only=True) class TaskPublisher: task_owner: TaskOwner - progress: distributed.Pub = field(init=False) _last_published_progress_value: float = -1 - def __post_init__(self) -> None: - self.progress = distributed.Pub(TaskProgressEvent.topic_name()) - def publish_progress(self, report: ProgressReport) -> None: rounded_value = round(report.percent_value, ndigits=2) if rounded_value > self._last_published_progress_value: - with log_catch(logger=_logger, reraise=False): + with ( + log_catch(logger=_logger, reraise=False), + log_context( + _logger, logging.DEBUG, msg=f"publish progress {rounded_value=}" + ), + ): publish_event( - self.progress, TaskProgressEvent.from_dask_worker( progress=rounded_value, task_owner=self.task_owner ), ) self._last_published_progress_value = rounded_value - _logger.debug("PROGRESS: %s", rounded_value) async def publish_logs( self, @@ -169,7 +168,14 @@ async def periodicaly_check_if_aborted(task_name: str) -> None: await periodically_checking_task -def publish_event(dask_pub: distributed.Pub, event: BaseTaskEvent) -> None: +def publish_event( + event: BaseTaskEvent, +) -> None: """never reraises, only CancellationError""" - with log_catch(_logger, reraise=False): - dask_pub.put(event.model_dump_json()) + worker = get_worker() + _logger.debug("current worker %s", f"{worker=}") + with ( + log_catch(_logger, reraise=False), + log_context(_logger, logging.DEBUG, msg=f"publishing {event=}"), + ): + worker.log_event(TaskProgressEvent.topic_name(), event.model_dump_json()) diff --git a/services/dask-sidecar/tests/unit/conftest.py b/services/dask-sidecar/tests/unit/conftest.py index 1fe06628ccc..944e142c129 100644 --- a/services/dask-sidecar/tests/unit/conftest.py +++ b/services/dask-sidecar/tests/unit/conftest.py @@ -99,7 +99,7 @@ def app_environment( model_dump_with_secrets(rabbit_service, show_secrets=True) ), "SC_BOOT_MODE": "debug", - "SIDECAR_LOGLEVEL": "DEBUG", + "DASK_SIDECAR_LOGLEVEL": "DEBUG", "SIDECAR_COMP_SERVICES_SHARED_VOLUME_NAME": "simcore_computational_shared_data", "SIDECAR_COMP_SERVICES_SHARED_FOLDER": f"{shared_data_folder}", }, diff --git a/services/dask-sidecar/tests/unit/test_utils_dask.py b/services/dask-sidecar/tests/unit/test_utils_dask.py index 9a1f6c7d18a..0e8193558be 100644 --- a/services/dask-sidecar/tests/unit/test_utils_dask.py +++ b/services/dask-sidecar/tests/unit/test_utils_dask.py @@ -6,16 +6,19 @@ import asyncio import concurrent.futures +import logging import time -from collections.abc import 
AsyncIterator, Callable, Coroutine from typing import Any +from unittest import mock import distributed import pytest +from common_library.async_tools import maybe_await from dask_task_models_library.container_tasks.errors import TaskCancelledError from dask_task_models_library.container_tasks.events import TaskProgressEvent from dask_task_models_library.container_tasks.io import TaskCancelEventName from dask_task_models_library.container_tasks.protocol import TaskOwner +from pytest_simcore.helpers.logging_tools import log_context from simcore_service_dask_sidecar.utils.dask import ( _DEFAULT_MAX_RESOURCES, TaskPublisher, @@ -24,7 +27,7 @@ monitor_task_abortion, publish_event, ) -from tenacity.asyncio import AsyncRetrying +from tenacity import Retrying from tenacity.retry import retry_if_exception_type from tenacity.stop import stop_after_delay from tenacity.wait import wait_fixed @@ -37,128 +40,85 @@ ] -def test_publish_event( - dask_client: distributed.Client, job_id: str, task_owner: TaskOwner +@pytest.fixture(params=["sync-dask-client", "async-dask-client"]) +def dask_client_multi( + request: pytest.FixtureRequest, + dask_client: distributed.Client, + async_dask_client: distributed.Client, +) -> distributed.Client: + if request.param == "sync-dask-client": + return dask_client + return async_dask_client + + +@pytest.mark.parametrize( + "handler", [mock.Mock(), mock.AsyncMock()], ids=["sync-handler", "async-handler"] +) +async def test_publish_event( + dask_client_multi: distributed.Client, + job_id: str, + task_owner: TaskOwner, + handler: mock.Mock | mock.AsyncMock, ): - dask_pub = distributed.Pub("some_topic", client=dask_client) - dask_sub = distributed.Sub("some_topic", client=dask_client) event_to_publish = TaskProgressEvent( job_id=job_id, msg="the log", progress=1, task_owner=task_owner, ) - publish_event(dask_pub=dask_pub, event=event_to_publish) - - # NOTE: this tests runs a sync dask client, - # and the CI seems to have sometimes difficulties having this run in a reasonable time - # hence the long time out - message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S) - assert message is not None - assert isinstance(message, str) - received_task_log_event = TaskProgressEvent.model_validate_json(message) - assert received_task_log_event == event_to_publish - - -async def test_publish_event_async( - async_dask_client: distributed.Client, job_id: str, task_owner: TaskOwner -): - dask_pub = distributed.Pub("some_topic", client=async_dask_client) - dask_sub = distributed.Sub("some_topic", client=async_dask_client) - event_to_publish = TaskProgressEvent( - job_id=job_id, msg="the log", progress=2, task_owner=task_owner - ) - publish_event(dask_pub=dask_pub, event=event_to_publish) - - # NOTE: this tests runs a sync dask client, - # and the CI seems to have sometimes difficulties having this run in a reasonable time - # hence the long time out - message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S) - assert isinstance(message, Coroutine) - message = await message - assert message is not None - received_task_log_event = TaskProgressEvent.model_validate_json(message) - assert received_task_log_event == event_to_publish + # NOTE: only 1 handler per topic is allowed at a time + dask_client_multi.subscribe_topic(TaskProgressEvent.topic_name(), handler) -@pytest.fixture -async def asyncio_task() -> AsyncIterator[Callable[[Coroutine], asyncio.Task]]: - created_tasks = [] + def _worker_task() -> int: + with log_context(logging.INFO, "_worker_task"): - def _creator(coro: Coroutine) -> 
asyncio.Task: - task = asyncio.create_task(coro, name="pytest_asyncio_task") - created_tasks.append(task) - return task + async def _() -> int: + with log_context(logging.INFO, "_worker_task_async"): + publish_event(event_to_publish) + return 2 - yield _creator - for task in created_tasks: - task.cancel() + return asyncio.run(_()) - await asyncio.gather(*created_tasks, return_exceptions=True) + future = dask_client_multi.submit(_worker_task) + assert await maybe_await(future.result(timeout=DASK_TESTING_TIMEOUT_S)) == 2 - -async def test_publish_event_async_using_task( - async_dask_client: distributed.Client, - asyncio_task: Callable[[Coroutine], asyncio.Task], - job_id: str, - task_owner: TaskOwner, -): - dask_pub = distributed.Pub("some_topic", client=async_dask_client) - dask_sub = distributed.Sub("some_topic", client=async_dask_client) - NUMBER_OF_MESSAGES = 1000 - received_messages = [] - - async def _dask_sub_consumer_task(sub: distributed.Sub) -> None: - print("--> starting consumer task") - async for dask_event in sub: - print(f"received {dask_event}") - received_messages.append(dask_event) - print("<-- finished consumer task") - - consumer_task = asyncio_task(_dask_sub_consumer_task(dask_sub)) - assert consumer_task - - async def _dask_publisher_task(pub: distributed.Pub) -> None: - print("--> starting publisher task") - for _ in range(NUMBER_OF_MESSAGES): - event_to_publish = TaskProgressEvent( - job_id=job_id, - progress=0.5, - task_owner=task_owner, - ) - publish_event(dask_pub=pub, event=event_to_publish) - print("<-- finished publisher task") - - publisher_task = asyncio_task(_dask_publisher_task(dask_pub)) - assert publisher_task - - async for attempt in AsyncRetrying( - retry=retry_if_exception_type(AssertionError), - stop=stop_after_delay(DASK_TESTING_TIMEOUT_S), - wait=wait_fixed(0.01), + for attempt in Retrying( + wait=wait_fixed(0.2), + stop=stop_after_delay(15), reraise=True, + retry=retry_if_exception_type(AssertionError), ): with attempt: - print( - f"checking number of received messages...currently {len(received_messages)}" + events = await maybe_await( + dask_client_multi.get_events(TaskProgressEvent.topic_name()) ) - assert len(received_messages) == NUMBER_OF_MESSAGES - print("all expected messages received") + assert events is not None, "No events received" + assert isinstance(events, tuple) + handler.assert_called_with(events[-1]) -def _wait_for_task_to_start() -> None: - start_event = distributed.Event(DASK_TASK_STARTED_EVENT) + assert isinstance(events, tuple) + assert len(events) == 1 + assert isinstance(events[0], tuple) + received_task_log_event = TaskProgressEvent.model_validate_json(events[0][1]) + assert received_task_log_event == event_to_publish + + +def _wait_for_task_to_start(dask_client: distributed.Client) -> None: + start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client) start_event.wait(timeout=DASK_TESTING_TIMEOUT_S) -def _notify_task_is_started_and_ready() -> None: - start_event = distributed.Event(DASK_TASK_STARTED_EVENT) +def _notify_task_is_started_and_ready(dask_client: distributed.Client) -> None: + start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client) start_event.set() def _some_long_running_task() -> int: assert is_current_task_aborted() is False - _notify_task_is_started_and_ready() + dask_client = distributed.get_worker().client + _notify_task_is_started_and_ready(dask_client) for i in range(300): print("running iteration", i) @@ -176,7 +136,7 @@ def test_task_is_aborted(dask_client: distributed.Client): 
not work in distributed mode where an Event is necessary.""" # NOTE: this works because the cluster is in the same machine future = dask_client.submit(_some_long_running_task) - _wait_for_task_to_start() + _wait_for_task_to_start(dask_client) future.cancel() assert future.cancelled() with pytest.raises(concurrent.futures.CancelledError): @@ -186,7 +146,7 @@ def test_task_is_aborted(dask_client: distributed.Client): def test_task_is_aborted_using_event(dask_client: distributed.Client): job_id = "myfake_job_id" future = dask_client.submit(_some_long_running_task, key=job_id) - _wait_for_task_to_start() + _wait_for_task_to_start(dask_client) dask_event = distributed.Event(TaskCancelEventName.format(job_id)) dask_event.set() @@ -203,7 +163,8 @@ def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int: async def _long_running_task_async() -> int: task_publishers = TaskPublisher(task_owner=task_owner) - _notify_task_is_started_and_ready() + worker = distributed.get_worker() + _notify_task_is_started_and_ready(worker.client) current_task = asyncio.current_task() assert current_task async with monitor_task_abortion( @@ -229,7 +190,7 @@ def test_monitor_task_abortion( future = dask_client.submit( _some_long_running_task_with_monitoring, task_owner=task_owner, key=job_id ) - _wait_for_task_to_start() + _wait_for_task_to_start(dask_client) # trigger cancellation dask_event = distributed.Event(TaskCancelEventName.format(job_id)) dask_event.set() diff --git a/services/dask-sidecar/tests/unit/test_worker.py b/services/dask-sidecar/tests/unit/test_worker.py index d826270fca4..235b61e3d15 100644 --- a/services/dask-sidecar/tests/unit/test_worker.py +++ b/services/dask-sidecar/tests/unit/test_worker.py @@ -106,14 +106,6 @@ def dask_subsystem_mock( "pytest_jobid" ) # mock dask event publishing - dask_utils_publish_event_mock = mocker.patch( - "simcore_service_dask_sidecar.utils.dask.distributed.Pub", - autospec=True, - ) - mocker.patch( - "simcore_service_dask_sidecar.utils.dask.distributed.Sub", - autospec=True, - ) mocker.patch( "simcore_service_dask_sidecar.utils.dask.is_current_task_aborted", autospec=True, @@ -138,7 +130,6 @@ def dask_subsystem_mock( return { "dask_client": dask_client_mock, "dask_task_state": dask_task_mock, - "dask_event_publish": dask_utils_publish_event_mock, } @@ -562,10 +553,6 @@ def test_run_computational_sidecar_real_fct( sleeper_task.service_key, sleeper_task.service_version, ) - for event in [TaskProgressEvent]: - dask_subsystem_mock["dask_event_publish"].assert_any_call( - name=event.topic_name() - ) assert log_rabbit_client_parser.called # check that the task produces expected logs @@ -641,8 +628,26 @@ def test_run_multiple_computational_sidecar_dask( @pytest.fixture -def progress_sub(dask_client: distributed.Client) -> distributed.Sub: - return distributed.Sub(TaskProgressEvent.topic_name(), client=dask_client) +def progress_event_handler(dask_client: distributed.Client) -> mock.Mock: + mocked_parser = mock.Mock() + dask_client.subscribe_topic(TaskProgressEvent.topic_name(), mocked_parser) + return mocked_parser + + +def _assert_parse_progresses_from_progress_event_handler( + progress_event_handler: mock.Mock, +) -> list[float]: + assert progress_event_handler.called + worker_progresses = [ + TaskProgressEvent.model_validate_json(msg.args[0][1]).progress + for msg in progress_event_handler.call_args_list + ] + assert worker_progresses == sorted( + set(worker_progresses) + ), "ordering of progress values incorrectly sorted!" 
+ assert worker_progresses[0] == 0, "missing/incorrect initial progress value" + assert worker_progresses[-1] == 1, "missing/incorrect final progress value" + return worker_progresses @pytest.mark.parametrize( @@ -651,7 +656,7 @@ def progress_sub(dask_client: distributed.Client) -> distributed.Sub: async def test_run_computational_sidecar_dask( app_environment: EnvVarsDict, sleeper_task: ServiceExampleParam, - progress_sub: distributed.Sub, + progress_event_handler: mock.Mock, mocked_get_image_labels: mock.Mock, s3_settings: S3Settings, log_rabbit_client_parser: mock.AsyncMock, @@ -670,16 +675,8 @@ async def test_run_computational_sidecar_dask( assert isinstance(output_data, TaskOutputData) # check that the task produces expected logs - worker_progresses = [ - TaskProgressEvent.model_validate_json(msg).progress - for msg in progress_sub.buffer - ] - # check ordering - assert worker_progresses == sorted( - set(worker_progresses) - ), "ordering of progress values incorrectly sorted!" - assert worker_progresses[0] == 0, "missing/incorrect initial progress value" - assert worker_progresses[-1] == 1, "missing/incorrect final progress value" + _assert_parse_progresses_from_progress_event_handler(progress_event_handler) + async for attempt in AsyncRetrying( wait=wait_fixed(1), stop=stop_after_delay(30), @@ -731,7 +728,7 @@ async def test_run_computational_sidecar_dask( async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub( dask_client: distributed.Client, sidecar_task: Callable[..., ServiceExampleParam], - progress_sub: distributed.Sub, + progress_event_handler: mock.Mock, mocked_get_image_labels: mock.Mock, log_rabbit_client_parser: mock.AsyncMock, ): @@ -757,17 +754,7 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub assert isinstance(output_data, TaskOutputData) # check that the task produces expected logs - worker_progresses = [ - TaskProgressEvent.model_validate_json(msg).progress - for msg in progress_sub.buffer - ] - # check length - assert len(worker_progresses) == len( - set(worker_progresses) - ), "there are duplicate progresses!" 
- assert sorted(worker_progresses) == worker_progresses, "ordering issue" - assert worker_progresses[0] == 0, "missing/incorrect initial progress value" - assert worker_progresses[-1] == 1, "missing/incorrect final progress value" + _assert_parse_progresses_from_progress_event_handler(progress_event_handler) async for attempt in AsyncRetrying( wait=wait_fixed(1), diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py index 77a4c807e6e..ab31bcd1367 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py @@ -4,6 +4,7 @@ from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import Any import arrow from dask_task_models_library.container_tasks.errors import TaskCancelledError @@ -35,7 +36,7 @@ clean_task_output_and_log_files_if_invalid, parse_output_data, ) -from ...utils.dask_client_utils import TaskHandlers +from ...utils.dask_client_utils import TaskHandlers, UnixTimestamp from ...utils.rabbitmq import ( publish_service_progress, publish_service_resource_tracking_stopped, @@ -344,9 +345,11 @@ async def _process_task_result( optional_stopped=arrow.utcnow().datetime, ) - async def _task_progress_change_handler(self, event: str) -> None: + async def _task_progress_change_handler( + self, event: tuple[UnixTimestamp, Any] + ) -> None: with log_catch(_logger, reraise=False): - task_progress_event = TaskProgressEvent.model_validate_json(event) + task_progress_event = TaskProgressEvent.model_validate_json(event[1]) _logger.debug("received task progress update: %s", task_progress_event) user_id = task_progress_event.task_owner.user_id project_id = task_progress_event.task_owner.project_id diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py index 6ac294238e8..a5dcfc4fd7b 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py @@ -8,12 +8,11 @@ """ -import asyncio import logging import traceback from collections.abc import Callable from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass from http.client import HTTPException from typing import Any, Final, cast @@ -23,6 +22,7 @@ from common_library.json_serialization import json_dumps from dask_task_models_library.container_tasks.docker import DockerBasicAuth from dask_task_models_library.container_tasks.errors import TaskCancelledError +from dask_task_models_library.container_tasks.events import TaskProgressEvent from dask_task_models_library.container_tasks.io import ( TaskCancelEventName, TaskInputData, @@ -53,7 +53,7 @@ from models_library.users import UserID from pydantic import TypeAdapter, ValidationError from pydantic.networks import AnyUrl -from servicelib.logging_utils import log_catch +from servicelib.logging_utils import log_catch, log_context from settings_library.s3 import S3Settings from simcore_sdk.node_ports_common.exceptions import NodeportsException from simcore_sdk.node_ports_v2 import FileLinkType @@ -123,8 +123,6 @@ class DaskClient: tasks_file_link_type: FileLinkType 
cluster_type: ClusterTypeInModel - _subscribed_tasks: list[asyncio.Task] = field(default_factory=list) - @classmethod async def create( cls, @@ -177,24 +175,15 @@ async def create( raise ValueError(err_msg) async def delete(self) -> None: - _logger.debug("closing dask client...") - for task in self._subscribed_tasks: - task.cancel() - await asyncio.gather(*self._subscribed_tasks, return_exceptions=True) - await self.backend.close() - _logger.info("dask client properly closed") + with log_context(_logger, logging.INFO, msg="close dask client"): + await self.backend.close() def register_handlers(self, task_handlers: TaskHandlers) -> None: _event_consumer_map = [ - (self.backend.progress_sub, task_handlers.task_progress_handler), - ] - self._subscribed_tasks = [ - asyncio.create_task( - dask_utils.dask_sub_consumer_task(dask_sub, handler), - name=f"{dask_sub.name}_dask_sub_consumer_task", - ) - for dask_sub, handler in _event_consumer_map + (TaskProgressEvent.topic_name(), task_handlers.task_progress_handler), ] + for topic_name, handler in _event_consumer_map: + self.backend.client.subscribe_topic(topic_name, handler) async def _publish_in_dask( # noqa: PLR0913 # pylint: disable=too-many-arguments self, diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dask.py b/services/director-v2/src/simcore_service_director_v2/utils/dask.py index e774cdb7cc9..29391141dfc 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dask.py +++ b/services/director-v2/src/simcore_service_director_v2/utils/dask.py @@ -1,8 +1,7 @@ -import asyncio import collections import logging -from collections.abc import Awaitable, Callable, Coroutine, Generator -from typing import Any, Final, NoReturn, ParamSpec, TypeVar, cast +from collections.abc import Coroutine, Generator +from typing import Any, ParamSpec, TypeVar, cast import distributed from common_library.json_serialization import json_dumps @@ -30,7 +29,6 @@ from models_library.users import UserID from models_library.wallets import WalletID from pydantic import AnyUrl, ByteSize, TypeAdapter, ValidationError -from servicelib.logging_utils import log_catch, log_context from simcore_sdk import node_ports_v2 from simcore_sdk.node_ports_common.exceptions import ( NodeportsException, @@ -438,40 +436,6 @@ async def clean_task_output_and_log_files_if_invalid( ) -async def _dask_sub_consumer( - dask_sub: distributed.Sub, - handler: Callable[[str], Awaitable[None]], -) -> None: - async for dask_event in dask_sub: - _logger.debug( - "received dask event '%s' of topic %s", - dask_event, - dask_sub.name, - ) - await handler(dask_event) - - -_REST_TIMEOUT_S: Final[int] = 1 - - -async def dask_sub_consumer_task( - dask_sub: distributed.Sub, - handler: Callable[[str], Awaitable[None]], -) -> NoReturn: - while True: - with ( - log_catch(_logger, reraise=False), - log_context( - _logger, - level=logging.DEBUG, - msg=f"dask sub task for topic {dask_sub.name}", - ), - ): - await _dask_sub_consumer(dask_sub, handler) - # we sleep a bit before restarting - await asyncio.sleep(_REST_TIMEOUT_S) - - def from_node_reqs_to_dask_resources( node_reqs: NodeRequirements, ) -> dict[str, int | float]: diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py b/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py index 34e11952d31..a692b3abf16 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py +++ 
b/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py @@ -2,22 +2,22 @@ import os import socket from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field +from dataclasses import dataclass +from typing import Any, TypeAlias import distributed -from dask_task_models_library.container_tasks.events import ( - TaskProgressEvent, -) from models_library.clusters import ClusterAuthentication, TLSAuthentication from pydantic import AnyUrl from ..core.errors import ConfigurationError from .dask import wrap_client_async_routine +UnixTimestamp: TypeAlias = float + @dataclass class TaskHandlers: - task_progress_handler: Callable[[str], Awaitable[None]] + task_progress_handler: Callable[[tuple[UnixTimestamp, Any]], Awaitable[None]] logger = logging.getLogger(__name__) @@ -27,12 +27,6 @@ class TaskHandlers: class DaskSubSystem: client: distributed.Client scheduler_id: str - progress_sub: distributed.Sub = field(init=False) - - def __post_init__(self) -> None: - self.progress_sub = distributed.Sub( - TaskProgressEvent.topic_name(), client=self.client - ) async def close(self) -> None: # NOTE: if the Sub are deleted before closing the connection, diff --git a/services/director-v2/tests/unit/test_modules_dask_client.py b/services/director-v2/tests/unit/test_modules_dask_client.py index c52647b7047..370cc967190 100644 --- a/services/director-v2/tests/unit/test_modules_dask_client.py +++ b/services/director-v2/tests/unit/test_modules_dask_client.py @@ -163,7 +163,6 @@ async def factory() -> DaskClient: client.settings == minimal_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND ) - assert not client._subscribed_tasks # noqa: SLF001 assert client.backend.client scheduler_infos = client.backend.client.scheduler_info() # type: ignore @@ -181,7 +180,7 @@ async def factory() -> DaskClient: @pytest.fixture(params=["create_dask_client_from_scheduler"]) async def dask_client( create_dask_client_from_scheduler: Callable[[], Awaitable[DaskClient]], - request, + request: pytest.FixtureRequest, ) -> DaskClient: client: DaskClient = await { "create_dask_client_from_scheduler": create_dask_client_from_scheduler, @@ -402,7 +401,7 @@ def comp_run_metadata(faker: Faker) -> RunMetadataDict: return RunMetadataDict( product_name=faker.pystr(), simcore_user_agent=faker.pystr(), - ) | cast(dict[str, str], faker.pydict(allowed_types=(str,))) + ) | cast(RunMetadataDict, faker.pydict(allowed_types=(str,))) @pytest.fixture @@ -1102,8 +1101,7 @@ def fake_remote_fct( log_file_url: LogFileUploadURL, s3_settings: S3Settings | None, ) -> TaskOutputData: - progress_pub = distributed.Pub(TaskProgressEvent.topic_name()) - progress_pub.put("my name is progress") + get_worker().log_event(TaskProgressEvent.topic_name(), "my name is progress") # tell the client we are done published_event = Event(name=_DASK_START_EVENT) published_event.set() @@ -1147,7 +1145,7 @@ def fake_remote_fct( ) # we should have received data in our TaskHandlers fake_task_handlers.task_progress_handler.assert_called_with( - "my name is progress" + (mock.ANY, "my name is progress") ) await _assert_wait_for_cb_call(mocked_user_completed_cb) diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py index f84ff112186..80666d7f6f0 100644 --- a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py +++ 
b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py @@ -16,6 +16,7 @@ from typing import Any, cast from unittest import mock +import arrow import pytest from _helpers import ( PublishedProject, @@ -401,11 +402,10 @@ async def _trigger_progress_event( ), ) await cast(DaskScheduler, scheduler)._task_progress_change_handler( # noqa: SLF001 - event.model_dump_json() + (arrow.utcnow().timestamp(), event.model_dump_json()) ) -@pytest.mark.acceptance_test() async def test_proper_pipeline_is_scheduled( # noqa: PLR0915 with_disabled_auto_scheduling: mock.Mock, with_disabled_scheduler_publisher: mock.Mock, @@ -1191,7 +1191,9 @@ async def test_task_progress_triggers( ) await cast( # noqa: SLF001 DaskScheduler, scheduler_api - )._task_progress_change_handler(progress_event.model_dump_json()) + )._task_progress_change_handler( + (arrow.utcnow().timestamp(), progress_event.model_dump_json()) + ) # NOTE: not sure whether it should switch to STARTED.. it would make sense await assert_comp_tasks( sqlalchemy_async_engine,
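
Reference sketch (not part of the diff): the change above replaces distributed.Pub / distributed.Sub with Dask's structured events, i.e. Worker.log_event on the publishing side and Client.subscribe_topic / Client.get_events on the consuming side, which is why handlers now receive (unix_timestamp, message) tuples instead of bare strings. The snippet below is a minimal, self-contained illustration of that flow under those assumptions; the topic name, payload, and function names are made up for the example and do not appear in this PR.

# illustrative_events_flow.py -- minimal sketch of the Pub/Sub -> structured-events migration
import time

import distributed


def _produce_on_worker() -> int:
    # replaces `distributed.Pub(topic).put(payload)`
    worker = distributed.get_worker()
    worker.log_event("example-topic", "some-serialized-payload")
    return 42


def _on_event(event: tuple[float, str]) -> None:
    # subscribe_topic handlers receive (unix_timestamp, message), not the bare message
    timestamp, message = event
    print(f"{timestamp=} {message=}")


if __name__ == "__main__":
    with distributed.LocalCluster(n_workers=1) as cluster, distributed.Client(cluster) as client:
        # replaces consuming a `distributed.Sub`
        client.subscribe_topic("example-topic", _on_event)
        assert client.submit(_produce_on_worker).result(timeout=30) == 42
        # give the event a moment to reach the scheduler, then query it back,
        # the same way the updated tests call get_events() and inspect (timestamp, msg) tuples
        time.sleep(1)
        print(client.get_events("example-topic"))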