diff --git a/packages/dask-task-models-library/src/dask_task_models_library/models.py b/packages/dask-task-models-library/src/dask_task_models_library/models.py index 2e0bbed7535b..2f701440d1d5 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/models.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/models.py @@ -1,4 +1,72 @@ -from typing import TypeAlias +from typing import Final, Literal, TypeAlias + +from dask.typing import Key +from distributed.scheduler import TaskStateState as SchedulerTaskState +from distributed.worker_state_machine import TaskStateState as WorkerTaskState +from models_library.projects_state import RunningState +from pydantic import BaseModel DaskJobID: TypeAlias = str DaskResources: TypeAlias = dict[str, int | float] + +TASK_LIFE_CYCLE_EVENT: Final[str] = "task-lifecycle-{key}" +TASK_RUNNING_PROGRESS_EVENT: Final[str] = "task-progress-{key}" +_SCHEDULER_TASK_STATE_TO_RUNNING_STATE: Final[ + dict[SchedulerTaskState, RunningState] +] = { + "released": RunningState.NOT_STARTED, # Known but not actively computing or in memory + "waiting": RunningState.PENDING, # On track to be computed, waiting on dependencies to arrive in memory + "no-worker": RunningState.WAITING_FOR_RESOURCES, # Ready to be computed, but no appropriate worker exists (for example because of resource restrictions, or because no worker is connected at all). + "queued": RunningState.WAITING_FOR_RESOURCES, # Ready to be computed, but all workers are already full. + "processing": RunningState.PENDING, # All dependencies are available and the task is assigned to a worker for compute (the scheduler doesn’t know whether it’s in a worker queue or actively being computed). + "memory": RunningState.SUCCESS, # In memory on one or more workers + "erred": RunningState.FAILED, # Task computation, or one of its dependencies, has encountered an error + "forgotten": RunningState.UNKNOWN, # Task is no longer needed by any client or dependent task, so it disappears from the scheduler as well. As soon as a task reaches this state, it is immediately dereferenced from the scheduler. +} + +_WORKER_TASK_STATE_TO_RUNNING_STATE: Final[dict[WorkerTaskState, RunningState]] = { + "cancelled": RunningState.ABORTED, # The scheduler asked to forget about this task, but it’s technically impossible at the moment. See Task cancellation. The task can be found in whatever collections it was in its previous state. + "constrained": RunningState.PENDING, # Like ready, but the user specified resource constraints for this task. The task can be found in the WorkerState.constrained queue. + "error": RunningState.FAILED, # Task execution failed + "executing": RunningState.STARTED, # The task is currently being computed on a thread. It can be found in the WorkerState.executing set and in the distributed.worker.Worker.active_threads dict. + "fetch": RunningState.PENDING, # This task is in memory on one or more peer workers, but not on this worker. Its data is queued to be transferred over the network, either because it’s a dependency of a task in waiting state, or because the Active Memory Manager requested it to be replicated here. The task can be found in the WorkerState.data_needed heap. + "flight": RunningState.PENDING, # The task data is currently being transferred over the network from another worker. The task can be found in the WorkerState.in_flight_tasks and WorkerState.in_flight_workers collections. 
+ "forgotten": RunningState.UNKNOWN, # The scheduler asked this worker to forget about the task, and there are neither dependents nor dependencies on the same worker. + "long-running": RunningState.STARTED, # Like executing, but the user code called distributed.secede() so the task no longer counts towards the maximum number of concurrent tasks. It can be found in the WorkerState.long_running set and in the distributed.worker.Worker.active_threads dict. + "memory": RunningState.SUCCESS, # Task execution completed, or the task was successfully transferred from another worker, and is now held in either WorkerState.data or WorkerState.actors. + "missing": RunningState.PENDING, # Like fetch, but all peer workers that were listed by the scheduler are either unreachable or have responded they don’t actually have the task data. The worker will periodically ask the scheduler if it knows of additional replicas; when it does, the task will transition again to fetch. The task can be found in the WorkerState.missing_dep_flight set. + "ready": RunningState.PENDING, # The task is ready to be computed; all of its dependencies are in memory on the current worker and it’s waiting for an available thread. The task can be found in the WorkerState.ready heap. + "released": RunningState.PENDING, # Known but not actively computing or in memory. A task can stay in this state when the scheduler asked to forget it, but it has dependent tasks on the same worker. + "rescheduled": RunningState.PENDING, # The task just raised the Reschedule exception. This is a transitory state, which is not stored permanently. + "resumed": RunningState.PENDING, # The task was recovered from cancelled state. See Task cancellation. The task can be found in whatever collections it was in its previous state. + "waiting": RunningState.PENDING, # The scheduler has added the task to the worker queue. All of its dependencies are in memory somewhere on the cluster, but not all of them are in memory on the current worker, so they need to be fetched. 
+} + + +class TaskLifeCycleState(BaseModel): + key: str + source: Literal["scheduler", "worker"] + worker: str | None + state: RunningState + + @classmethod + def from_scheduler_task_state( + cls, key: Key, worker: str | None, task_state: SchedulerTaskState + ) -> "TaskLifeCycleState": + return cls( + key=f"{key!r}", + source="scheduler", + worker=worker, + state=_SCHEDULER_TASK_STATE_TO_RUNNING_STATE[task_state], + ) + + @classmethod + def from_worker_task_state( + cls, key: Key, worker: str | None, task_state: WorkerTaskState + ) -> "TaskLifeCycleState": + return cls( + key=f"{key!r}", + source="worker", + worker=worker, + state=_WORKER_TASK_STATE_TO_RUNNING_STATE[task_state], + ) diff --git a/packages/dask-task-models-library/src/dask_task_models_library/plugins/__init__.py b/packages/dask-task-models-library/src/dask_task_models_library/plugins/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_scheduler_plugin.py b/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_scheduler_plugin.py new file mode 100644 index 000000000000..69fbc35e21ec --- /dev/null +++ b/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_scheduler_plugin.py @@ -0,0 +1,54 @@ +# pylint: disable=unused-argument +import logging +from typing import Any + +import click +from dask.typing import Key +from distributed import Scheduler, SchedulerPlugin +from distributed.scheduler import TaskStateState + +from ..models import TASK_LIFE_CYCLE_EVENT, TaskLifeCycleState + +_logger = logging.getLogger(__name__) + + +class TaskLifecycleSchedulerPlugin(SchedulerPlugin): + def __init__(self) -> None: + self.scheduler = None + _logger.info("initialized TaskLifecycleSchedulerPlugin") + + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler # type: ignore[assignment] + _logger.info("started TaskLifecycleSchedulerPlugin") + + def transition( + self, + key: Key, + start: TaskStateState, + finish: TaskStateState, + *args: Any, # noqa: ARG002 + stimulus_id: str, + **kwargs: Any, + ): + _logger.debug( + "Task %s transition from %s to %s due to %s", + key, + start, + finish, + stimulus_id, + ) + + assert self.scheduler # nosec + + self.scheduler.log_event( + TASK_LIFE_CYCLE_EVENT.format(key=key), + TaskLifeCycleState.from_scheduler_task_state( + key, kwargs.get("worker"), finish + ).model_dump(mode="json"), + ) + + +@click.command() +def dask_setup(scheduler): + plugin = TaskLifecycleSchedulerPlugin() + scheduler.add_plugin(plugin) diff --git a/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py b/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py new file mode 100644 index 000000000000..ebc6aabcad85 --- /dev/null +++ b/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py @@ -0,0 +1,48 @@ +import logging +from collections.abc import Awaitable +from typing import Any + +import click +from dask.typing import Key +from distributed import WorkerPlugin +from distributed.worker import Worker +from distributed.worker_state_machine import TaskStateState + +from ..models import TASK_LIFE_CYCLE_EVENT, TaskLifeCycleState + +_logger = logging.getLogger(__name__) + + +class TaskLifecycleWorkerPlugin(WorkerPlugin): + def __init__(self) -> None: + self._worker = None + 
_logger.info("TaskLifecycleWorkerPlugin initialized") + + def setup(self, worker: Worker) -> Awaitable[None]: + async def _() -> None: + self._worker = worker # type: ignore[assignment] + _logger.info("TaskLifecycleWorkerPlugin setup completed") + + return _() + + def transition( + self, + key: Key, + start: TaskStateState, + finish: TaskStateState, + **kwargs: Any, + ): + _logger.info("Task '%s' transition from %s to %s", key, start, finish) + assert self._worker # nosec + self._worker.log_event( + TASK_LIFE_CYCLE_EVENT.format(key=key), + TaskLifeCycleState.from_worker_task_state( + key, kwargs.get("worker"), finish + ).model_dump(mode="json"), + ) + + +@click.command() +async def dask_setup(worker: Worker) -> None: + plugin = TaskLifecycleWorkerPlugin() + await worker.plugin_add(plugin) diff --git a/packages/pytest-simcore/src/pytest_simcore/dask_scheduler.py b/packages/pytest-simcore/src/pytest_simcore/dask_scheduler.py index 03c63c07d348..54fdd0212231 100644 --- a/packages/pytest-simcore/src/pytest_simcore/dask_scheduler.py +++ b/packages/pytest-simcore/src/pytest_simcore/dask_scheduler.py @@ -19,6 +19,9 @@ def dask_workers_config() -> dict[str, Any]: "options": { "nthreads": 2, "resources": {"CPU": 2, "RAM": 48e9}, + "preload": ( + "dask_task_models_library.plugins.task_life_cycle_worker_plugin", + ), }, }, "gpu-worker": { @@ -30,6 +33,9 @@ def dask_workers_config() -> dict[str, Any]: "GPU": 1, "RAM": 48e9, }, + "preload": ( + "dask_task_models_library.plugins.task_life_cycle_worker_plugin", + ), }, }, "large-ram-worker": { @@ -40,6 +46,9 @@ def dask_workers_config() -> dict[str, Any]: "CPU": 8, "RAM": 768e9, }, + "preload": ( + "dask_task_models_library.plugins.task_life_cycle_worker_plugin", + ), }, }, } @@ -54,6 +63,9 @@ def dask_scheduler_config( "options": { "port": unused_tcp_port_factory(), "dashboard_address": f":{unused_tcp_port_factory()}", + "preload": ( + "dask_task_models_library.plugins.task_life_cycle_scheduler_plugin", + ), }, } diff --git a/scripts/maintenance/computational-clusters/autoscaled_monitor/ssh.py b/scripts/maintenance/computational-clusters/autoscaled_monitor/ssh.py index 229dab0c3f40..88486c374286 100644 --- a/scripts/maintenance/computational-clusters/autoscaled_monitor/ssh.py +++ b/scripts/maintenance/computational-clusters/autoscaled_monitor/ssh.py @@ -265,7 +265,11 @@ def _needs_manual_intervention( user_id=containers[0].user_id, project_id=containers[0].project_id, created_at=containers[0].created_at, - needs_manual_intervention=_needs_manual_intervention(containers), + needs_manual_intervention=_needs_manual_intervention(containers) + and ( + (arrow.utcnow().datetime - containers[0].created_at) + > datetime.timedelta(minutes=2) + ), containers=[c.name for c in containers], service_name=containers[0].service_name, service_version=containers[0].service_version, diff --git a/services/clusters-keeper/requirements/ci.txt b/services/clusters-keeper/requirements/ci.txt index 7bb3a4afb29f..22fd83a2698c 100644 --- a/services/clusters-keeper/requirements/ci.txt +++ b/services/clusters-keeper/requirements/ci.txt @@ -14,6 +14,7 @@ # installs this repo's packages simcore-aws-library @ ../../packages/aws-library simcore-common-library @ ../../packages/common-library +simcore-dask-task-models-library @ ../../packages/dask-task-models-library simcore-models-library @ ../../packages/models-library pytest-simcore @ ../../packages/pytest-simcore simcore-service-library[fastapi] @ ../../packages/service-library diff --git a/services/clusters-keeper/requirements/dev.txt 
b/services/clusters-keeper/requirements/dev.txt index faf4378c83db..004ee6c6241c 100644 --- a/services/clusters-keeper/requirements/dev.txt +++ b/services/clusters-keeper/requirements/dev.txt @@ -18,6 +18,7 @@ --editable ../../packages/pytest-simcore --editable ../../packages/service-library[fastapi] --editable ../../packages/settings-library +--editable ../../packages/dask-task-models-library # installs current package --editable . diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_plugin.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_worker_plugin.py similarity index 98% rename from services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_plugin.py rename to services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_worker_plugin.py index 554988aa0b2f..ba4936284d7b 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_plugin.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_worker_plugin.py @@ -24,7 +24,7 @@ class RabbitMQPlugin(distributed.WorkerPlugin): """Dask Worker Plugin for RabbitMQ integration""" - name = "rabbitmq_plugin" + name = "rabbitmq_worker_plugin" _main_thread_loop: AbstractEventLoop | None = None _client: RabbitMQClient | None = None _settings: RabbitSettings | None = None @@ -60,7 +60,7 @@ async def _() -> None: if threading.current_thread() is not threading.main_thread(): _logger.warning( - "RabbitMQ client plugin setup is not in the main thread! Beware! if in pytest it's ok." + "RabbitMQ client plugin setup is not in the main thread! TIP: if in pytest it's ok." ) with log_context( @@ -98,7 +98,7 @@ async def _() -> None: ) else: _logger.warning( - "RabbitMQ client plugin setup is not the main thread!" + "RabbitMQ client plugin setup is not in the main thread! TIP: if in pytest it's ok." ) # Cancel the message processor task diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/scheduler.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/scheduler.py index 4127fca2528b..0813da017416 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/scheduler.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/scheduler.py @@ -1,6 +1,9 @@ import logging import distributed +from dask_task_models_library.plugins.task_life_cycle_scheduler_plugin import ( + TaskLifecycleSchedulerPlugin, +) from servicelib.logging_utils import log_context from ._meta import print_dask_scheduler_banner @@ -19,9 +22,13 @@ async def dask_setup(scheduler: distributed.Scheduler) -> None: with log_context(_logger, logging.INFO, "Launch dask scheduler"): _logger.info("app settings: %s", settings.model_dump_json(indent=1)) + + scheduler.add_plugin(TaskLifecycleSchedulerPlugin()) print_dask_scheduler_banner() -async def dask_teardown(_worker: distributed.Worker) -> None: - with log_context(_logger, logging.INFO, "Tear down dask scheduler"): +async def dask_teardown(scheduler: distributed.Scheduler) -> None: + with log_context( + _logger, logging.INFO, f"Tear down dask scheduler at {scheduler.address}" + ): ...
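For orientation (not part of the patch): the new lifecycle plugins are wired in through Dask preloads, exactly as the pytest fixtures above do for the scheduler and the workers. Below is a minimal sketch of a local cluster that loads both plugins, assuming only the module paths that appear in this diff:

import distributed

with distributed.LocalCluster(
    n_workers=1,
    preload="dask_task_models_library.plugins.task_life_cycle_worker_plugin",
    scheduler_kwargs={
        "preload": "dask_task_models_library.plugins.task_life_cycle_scheduler_plugin"
    },
) as cluster, distributed.Client(cluster) as client:
    # with the preloads active, every submitted task emits task-lifecycle-{key} events
    future = client.submit(sum, [1, 2, 3])
    assert future.result(timeout=30) == 6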
diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py index 1b4b40c8d348..60e84053ed0f 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py @@ -13,13 +13,14 @@ ) from dask_task_models_library.container_tasks.io import TaskCancelEventName from dask_task_models_library.container_tasks.protocol import TaskOwner +from dask_task_models_library.models import TASK_RUNNING_PROGRESS_EVENT from distributed.worker import get_worker from distributed.worker_state_machine import TaskState from models_library.progress_bar import ProgressReport from models_library.rabbitmq_messages import LoggerRabbitMessage from servicelib.logging_utils import LogLevelInt, LogMessageStr, log_catch, log_context -from ..rabbitmq_plugin import get_rabbitmq_client +from ..rabbitmq_worker_plugin import get_rabbitmq_client _logger = logging.getLogger(__name__) @@ -178,4 +179,10 @@ def publish_event( log_catch(_logger, reraise=False), log_context(_logger, logging.DEBUG, msg=f"publishing {event=}"), ): - worker.log_event(TaskProgressEvent.topic_name(), event.model_dump_json()) + worker.log_event( + [ + TaskProgressEvent.topic_name(), + TASK_RUNNING_PROGRESS_EVENT.format(key=event.job_id), + ], + event.model_dump_json(), + ) diff --git a/services/dask-sidecar/src/simcore_service_dask_sidecar/worker.py b/services/dask-sidecar/src/simcore_service_dask_sidecar/worker.py index de3ecd6d66cc..abaedd698ed6 100644 --- a/services/dask-sidecar/src/simcore_service_dask_sidecar/worker.py +++ b/services/dask-sidecar/src/simcore_service_dask_sidecar/worker.py @@ -11,12 +11,15 @@ ContainerTaskParameters, LogFileUploadURL, ) +from dask_task_models_library.plugins.task_life_cycle_worker_plugin import ( + TaskLifecycleWorkerPlugin, +) from servicelib.logging_utils import log_context from settings_library.s3 import S3Settings from ._meta import print_dask_sidecar_banner from .computational_sidecar.core import ComputationalSidecar -from .rabbitmq_plugin import RabbitMQPlugin +from .rabbitmq_worker_plugin import RabbitMQPlugin from .settings import ApplicationSettings from .utils.dask import ( TaskPublisher, @@ -76,14 +79,21 @@ async def dask_setup(worker: distributed.Worker) -> None: RabbitMQPlugin(settings.DASK_SIDECAR_RABBITMQ), catch_errors=False ) except Exception: - await worker.close() + await worker.close(reason="failed to add RabbitMQ plugin") raise + try: + await worker.plugin_add(TaskLifecycleWorkerPlugin(), catch_errors=False) + except Exception: + await worker.close(reason="failed to add TaskLifecycleWorkerPlugin") + raise print_dask_sidecar_banner() async def dask_teardown(worker: distributed.Worker) -> None: - with log_context(_logger, logging.INFO, f"tear down dask {worker.address}"): + with log_context( + _logger, logging.INFO, f"tear down dask worker at {worker.address}" + ): ... 
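For reference (not part of the patch): with the plugins loaded, any Dask client can reconstruct a task's current state from the per-key event topic; this is what test_tasks_life_cycle.py below and DaskClient.get_tasks_status do. A minimal sketch for a synchronous client (an asynchronous client would await get_events):

import distributed
from dask_task_models_library.models import TASK_LIFE_CYCLE_EVENT, TaskLifeCycleState
from models_library.projects_state import RunningState


def latest_task_state(client: distributed.Client, job_id: str) -> RunningState:
    # events arrive as (timestamp, payload) tuples; payload is TaskLifeCycleState.model_dump(mode="json")
    events = client.get_events(TASK_LIFE_CYCLE_EVENT.format(key=job_id))
    if not events:
        return RunningState.UNKNOWN
    # the most recent event carries the current state
    return TaskLifeCycleState.model_validate(events[-1][1]).state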
diff --git a/services/dask-sidecar/tests/unit/conftest.py b/services/dask-sidecar/tests/unit/conftest.py index 944e142c1298..2e3fb246f884 100644 --- a/services/dask-sidecar/tests/unit/conftest.py +++ b/services/dask-sidecar/tests/unit/conftest.py @@ -117,6 +117,7 @@ def local_cluster(app_environment: EnvVarsDict) -> Iterator[distributed.LocalClu with distributed.LocalCluster( worker_class=distributed.Worker, resources={"CPU": 10, "GPU": 10}, + scheduler_kwargs={"preload": "simcore_service_dask_sidecar.scheduler"}, preload="simcore_service_dask_sidecar.worker", ) as cluster: assert cluster diff --git a/services/dask-sidecar/tests/unit/test_worker.py b/services/dask-sidecar/tests/unit/test_computational_sidecar_tasks.py similarity index 99% rename from services/dask-sidecar/tests/unit/test_worker.py rename to services/dask-sidecar/tests/unit/test_computational_sidecar_tasks.py index 235b61e3d15b..a0756a0bc308 100644 --- a/services/dask-sidecar/tests/unit/test_worker.py +++ b/services/dask-sidecar/tests/unit/test_computational_sidecar_tasks.py @@ -113,7 +113,8 @@ def dask_subsystem_mock( ) # mock dask rabbitmq plugin mock_dask_rabbitmq_plugin = mocker.patch( - "simcore_service_dask_sidecar.rabbitmq_plugin.RabbitMQPlugin", autospec=True + "simcore_service_dask_sidecar.rabbitmq_worker_plugin.RabbitMQPlugin", + autospec=True, ) mock_rabbitmq_client = create_rabbitmq_client("pytest_dask_sidecar_logs_publisher") mock_dask_rabbitmq_plugin.get_client.return_value = mock_rabbitmq_client @@ -505,7 +506,7 @@ async def subscribe_and_process(a_mock: mock.AsyncMock): ready_event.set() # Wait until the test is done - while not shutdown_event.is_set(): + while not shutdown_event.is_set(): # noqa: ASYNC110 await asyncio.sleep(0.1) # Cleanup diff --git a/services/dask-sidecar/tests/unit/test_rabbitmq_plugin.py b/services/dask-sidecar/tests/unit/test_rabbitmq_plugin.py index de632c818ec8..6d2b56cf57ee 100644 --- a/services/dask-sidecar/tests/unit/test_rabbitmq_plugin.py +++ b/services/dask-sidecar/tests/unit/test_rabbitmq_plugin.py @@ -3,7 +3,12 @@ # pylint: disable=unused-variable # pylint: disable=no-member +import asyncio +from unittest import mock + import distributed +import pytest +from pytest_mock import MockerFixture # Selection of core and tool services started in this swarm fixture (integration) pytest_simcore_core_services_selection = [ @@ -14,3 +19,19 @@ def test_rabbitmq_plugin_initializes(dask_client: distributed.Client): ... 
+ + +@pytest.fixture +def erroring_rabbitmq_plugin(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + "simcore_service_dask_sidecar.rabbitmq_worker_plugin.RabbitMQPlugin", + autospec=True, + side_effect=RuntimeError("Pytest: RabbitMQ plugin initialization failed"), + ) + + +async def test_dask_worker_closes_if_plugin_fails_on_start( + erroring_rabbitmq_plugin: mock.Mock, + local_cluster: distributed.LocalCluster, +): + await asyncio.sleep(10) diff --git a/services/dask-sidecar/tests/unit/test_tasks_life_cycle.py b/services/dask-sidecar/tests/unit/test_tasks_life_cycle.py new file mode 100644 index 000000000000..2b8b077cbe08 --- /dev/null +++ b/services/dask-sidecar/tests/unit/test_tasks_life_cycle.py @@ -0,0 +1,82 @@ +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument +# pylint: disable=unused-variable +# pylint: disable=no-member + +import time +from collections.abc import Iterable + +import distributed +import pytest +from dask_task_models_library.models import TASK_LIFE_CYCLE_EVENT, TaskLifeCycleState +from models_library.projects_state import RunningState +from tenacity import Retrying, stop_after_delay, wait_fixed + +pytest_simcore_core_services_selection = [ + "rabbit", +] + + +def test_task_state_lifecycle(local_cluster: distributed.LocalCluster) -> None: + def _some_task() -> int: + time.sleep(1) + return 2 + + def _some_failing_task() -> None: + time.sleep(1) + msg = "Some error" + raise RuntimeError(msg) + + local_cluster.scale(0) + for attempt in Retrying( + stop=stop_after_delay(10), wait=wait_fixed(1), reraise=True + ): + with attempt: + assert len(local_cluster.workers) == 0 + with distributed.Client(local_cluster) as dask_client: + # submit the task and wait until it goes into WAITING_FOR_RESOURCES + future = dask_client.submit(_some_task, resources={"CPU": 1}) + for attempt in Retrying( + stop=stop_after_delay(10), wait=wait_fixed(1), reraise=True + ): + with attempt: + events = dask_client.get_events( + TASK_LIFE_CYCLE_EVENT.format(key=future.key) + ) + assert isinstance(events, tuple) + assert len(events) >= 2 + parsed_events = [ + TaskLifeCycleState.model_validate(event[1]) for event in events + ] + assert parsed_events[0].state is RunningState.PENDING + assert parsed_events[-1].state is RunningState.WAITING_FOR_RESOURCES + + # now add a worker and wait for it to take the task + local_cluster.scale(1) + + # we basically wait for the tasks to finish + assert future.result(timeout=15) == 2 + + events = dask_client.get_events(TASK_LIFE_CYCLE_EVENT.format(key=future.key)) + assert isinstance(events, tuple) + parsed_events = [ + TaskLifeCycleState.model_validate(event[1]) for event in events + ] + assert parsed_events[0].state is RunningState.PENDING + assert RunningState.STARTED in {event.state for event in parsed_events} + assert RunningState.FAILED not in {event.state for event in parsed_events} + assert parsed_events[-1].state is RunningState.SUCCESS + + future = dask_client.submit(_some_failing_task) + with pytest.raises(RuntimeError): + future.result(timeout=10) + events = dask_client.get_events(TASK_LIFE_CYCLE_EVENT.format(key=future.key)) + assert isinstance(events, Iterable) + parsed_events = [ + TaskLifeCycleState.model_validate(event[1]) for event in events + ] + assert parsed_events[0].state is RunningState.PENDING + assert RunningState.STARTED in {event.state for event in parsed_events} + assert RunningState.FAILED in {event.state for event in parsed_events} + assert RunningState.SUCCESS not in {event.state for event in 
parsed_events} + assert parsed_events[-1].state is RunningState.FAILED diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py index 8a7809ab7d65..9f096911030a 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py @@ -17,6 +17,7 @@ import contextlib import logging +from datetime import timedelta from typing import Annotated, Any, Final import networkx as nx @@ -84,7 +85,7 @@ from ..dependencies.rut_client import get_rut_client from .computations_tasks import analyze_pipeline -_PIPELINE_ABORT_TIMEOUT_S: Final[int] = 10 +_PIPELINE_ABORT_TIMEOUT_S: Final[timedelta] = timedelta(seconds=30) _logger = logging.getLogger(__name__) @@ -121,7 +122,7 @@ async def _check_pipeline_startable( ): raise HTTPException( status_code=status.HTTP_406_NOT_ACCEPTABLE, - detail=f"Project {computation.project_id} cannot run since it contains deprecated tasks {jsonable_encoder( deprecated_tasks)}", + detail=f"Project {computation.project_id} cannot run since it contains deprecated tasks {jsonable_encoder(deprecated_tasks)}", ) @@ -624,7 +625,7 @@ def return_last_value(retry_state: Any) -> Any: return retry_state.outcome.result() @retry( - stop=stop_after_delay(_PIPELINE_ABORT_TIMEOUT_S), + stop=stop_after_delay(_PIPELINE_ABORT_TIMEOUT_S.total_seconds()), wait=wait_random(0, 2), retry_error_callback=return_last_value, retry=retry_if_result(lambda result: result is False), @@ -643,7 +644,7 @@ async def check_pipeline_stopped() -> bool: # wait for the pipeline to be stopped if not await check_pipeline_stopped(): _logger.error( - "pipeline %s could not be stopped properly after %ss", + "pipeline %s could not be stopped properly after %s", project_id, _PIPELINE_ABORT_TIMEOUT_S, ) diff --git a/services/director-v2/src/simcore_service_director_v2/models/dask_subsystem.py b/services/director-v2/src/simcore_service_director_v2/models/dask_subsystem.py deleted file mode 100644 index 34a270cba88c..000000000000 --- a/services/director-v2/src/simcore_service_director_v2/models/dask_subsystem.py +++ /dev/null @@ -1,13 +0,0 @@ -from enum import Enum - - -# NOTE: mypy fails with src/simcore_service_director_v2/modules/dask_client.py:101:5: error: Dict entry 0 has incompatible type "str": "auto"; expected "Any": "DaskClientTaskState" [dict-item] -# when using StrAutoEnum -class DaskClientTaskState(str, Enum): - PENDING = "PENDING" - NO_WORKER = "NO_WORKER" - PENDING_OR_STARTED = "PENDING_OR_STARTED" - LOST = "LOST" - ERRED = "ERRED" - ABORTED = "ABORTED" - SUCCESS = "SUCCESS" diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py index 49465084d5d7..7473eec54fd0 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py @@ -116,11 +116,11 @@ class SortedTasks: async def _triage_changed_tasks( - changed_tasks: list[tuple[_Previous, _Current]], + changed_tasks_or_executing: list[tuple[_Previous, _Current]], ) -> SortedTasks: started_tasks = [ current - for previous, current in changed_tasks + for previous, current in changed_tasks_or_executing if current.state in RUNNING_STATES or ( previous.state in 
WAITING_FOR_START_STATES @@ -130,17 +130,21 @@ async def _triage_changed_tasks( # NOTE: some tasks can be both started and completed since we might have the time they were running completed_tasks = [ - current for _, current in changed_tasks if current.state in COMPLETED_STATES + current + for _, current in changed_tasks_or_executing + if current.state in COMPLETED_STATES ] waiting_for_resources_tasks = [ current - for previous, current in changed_tasks + for previous, current in changed_tasks_or_executing if current.state in WAITING_FOR_START_STATES ] lost_or_momentarily_lost_tasks = [ - current for _, current in changed_tasks if current.state is RunningState.UNKNOWN + current + for _, current in changed_tasks_or_executing + if current.state is RunningState.UNKNOWN ] if lost_or_momentarily_lost_tasks: _logger.warning( @@ -321,21 +325,30 @@ async def _get_changed_tasks_from_backend( user_id: UserID, processing_tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB, - ) -> list[tuple[_Previous, _Current]]: + ) -> tuple[list[tuple[_Previous, _Current]], list[CompTaskAtDB]]: tasks_backend_status = await self._get_tasks_status( user_id, processing_tasks, comp_run ) - return [ - ( - task, - task.model_copy(update={"state": backend_state}), - ) - for task, backend_state in zip( - processing_tasks, tasks_backend_status, strict=True - ) - if task.state is not backend_state - ] + return ( + [ + ( + task, + task.model_copy(update={"state": backend_state}), + ) + for task, backend_state in zip( + processing_tasks, tasks_backend_status, strict=True + ) + if task.state is not backend_state + ], + [ + task + for task, backend_state in zip( + processing_tasks, tasks_backend_status, strict=True + ) + if task.state is backend_state is RunningState.STARTED + ], + ) async def _process_started_tasks( self, @@ -476,7 +489,10 @@ async def _update_states_from_comp_backend( return # get the tasks which state actually changed since last check - tasks_with_changed_states = await self._get_changed_tasks_from_backend( + ( + tasks_with_changed_states, + executing_tasks, + ) = await self._get_changed_tasks_from_backend( user_id, tasks_inprocess, comp_run ) # NOTE: typical states a task goes through @@ -488,9 +504,10 @@ async def _update_states_from_comp_backend( # now process the tasks if sorted_tasks.started: # NOTE: the dask-scheduler cannot differentiate between tasks that are effectively computing and - # tasks that are only queued and accepted by a dask-worker. + # tasks that are only queued and accepted by a dask-worker. We use dask plugins to report on tasks states + # states are published to log_event, and we directly publish into RabbitMQ the sidecar and services logs. 
# tasks_started should therefore be mostly empty but for cases where - # - dask Pub/Sub mechanism failed, the tasks goes from PENDING -> SUCCESS/FAILED/ABORTED without STARTED + # - dask log_event/subscribe_topic mechanism failed, the tasks goes from PENDING -> SUCCESS/FAILED/ABORTED without STARTED # - the task finished so fast that the STARTED state was skipped between 2 runs of the dv-2 comp scheduler await self._process_started_tasks( sorted_tasks.started, @@ -511,6 +528,9 @@ async def _update_states_from_comp_backend( if sorted_tasks.waiting: await self._process_waiting_tasks(sorted_tasks.waiting) + if executing_tasks: + await self._process_executing_tasks(user_id, executing_tasks, comp_run) + @abstractmethod async def _start_tasks( self, @@ -545,6 +565,15 @@ async def _process_completed_tasks( ) -> None: """process tasks from the 3rd party backend""" + @abstractmethod + async def _process_executing_tasks( + self, + user_id: UserID, + tasks: list[CompTaskAtDB], + comp_run: CompRunsAtDB, + ) -> None: + """process executing tasks from the 3rd party backend""" + async def apply( self, *, diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py index ab31bcd13679..c3c21fa0b8fa 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py @@ -23,6 +23,9 @@ from models_library.users import UserID from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE from servicelib.logging_utils import log_catch +from simcore_service_director_v2.modules.comp_scheduler._utils import ( + WAITING_FOR_START_STATES, +) from ...core.errors import ( ComputationalBackendNotConnectedError, @@ -31,7 +34,6 @@ ) from ...models.comp_runs import CompRunsAtDB, Iteration, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB -from ...models.dask_subsystem import DaskClientTaskState from ...utils.dask import ( clean_task_output_and_log_files_if_invalid, parse_output_data, @@ -45,25 +47,15 @@ from ..clusters_keeper import get_or_create_on_demand_cluster from ..dask_client import DaskClient, PublishedComputationTask from ..dask_clients_pool import DaskClientsPool -from ..db.repositories.comp_runs import CompRunsRepository +from ..db.repositories.comp_runs import ( + CompRunsRepository, +) from ..db.repositories.comp_tasks import CompTasksRepository from ._scheduler_base import BaseCompScheduler _logger = logging.getLogger(__name__) -_DASK_CLIENT_TASK_STATE_TO_RUNNING_STATE_MAP: dict[ - DaskClientTaskState, RunningState -] = { - DaskClientTaskState.PENDING: RunningState.PENDING, - DaskClientTaskState.NO_WORKER: RunningState.WAITING_FOR_RESOURCES, - DaskClientTaskState.LOST: RunningState.UNKNOWN, - DaskClientTaskState.ERRED: RunningState.FAILED, - DaskClientTaskState.ABORTED: RunningState.ABORTED, - DaskClientTaskState.SUCCESS: RunningState.SUCCESS, -} - - @asynccontextmanager async def _cluster_dask_client( user_id: UserID, @@ -159,32 +151,68 @@ async def _get_tasks_status( use_on_demand_clusters=comp_run.use_on_demand_clusters, run_metadata=comp_run.metadata, ) as client: - tasks_statuses = await client.get_tasks_status( - [f"{t.job_id}" for t in tasks] - ) - # process dask states - running_states: list[RunningState] = [] - for dask_task_state, task in zip(tasks_statuses, tasks, strict=True): - if 
dask_task_state is DaskClientTaskState.PENDING_OR_STARTED: - running_states += [ - ( - RunningState.STARTED - if task.progress is not None - else RunningState.PENDING - ) - ] - else: - running_states += [ - _DASK_CLIENT_TASK_STATE_TO_RUNNING_STATE_MAP.get( - dask_task_state, RunningState.UNKNOWN - ) - ] - return running_states + return await client.get_tasks_status([f"{t.job_id}" for t in tasks]) except ComputationalBackendOnDemandNotReadyError: _logger.info("The on demand computational backend is not ready yet...") return [RunningState.WAITING_FOR_CLUSTER] * len(tasks) + async def _process_executing_tasks( + self, + user_id: UserID, + tasks: list[CompTaskAtDB], + comp_run: CompRunsAtDB, + ) -> None: + task_progresses = [] + try: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + run_metadata=comp_run.metadata, + ) as client: + task_progresses = await client.get_tasks_progress( + [f"{t.job_id}" for t in tasks], + ) + for task_progress_event in task_progresses: + if task_progress_event: + await CompTasksRepository( + self.db_engine + ).update_project_task_progress( + task_progress_event.task_owner.project_id, + task_progress_event.task_owner.node_id, + task_progress_event.progress, + ) + + except ComputationalBackendOnDemandNotReadyError: + _logger.info("The on demand computational backend is not ready yet...") + + comp_tasks_repo = CompTasksRepository(self.db_engine) + await asyncio.gather( + *( + comp_tasks_repo.update_project_task_progress( + t.task_owner.project_id, + t.task_owner.node_id, + t.progress, + ) + for t in task_progresses + if t + ) + ) + await asyncio.gather( + *( + publish_service_progress( + self.rabbitmq_client, + user_id=t.task_owner.user_id, + project_id=t.task_owner.project_id, + node_id=t.task_owner.node_id, + progress=t.progress, + ) + for t in task_progresses + if t + ) + ) + async def _stop_tasks( self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> None: @@ -356,7 +384,7 @@ async def _task_progress_change_handler( node_id = task_progress_event.task_owner.node_id comp_tasks_repo = CompTasksRepository(self.db_engine) task = await comp_tasks_repo.get_task(project_id, node_id) - if task.progress is None: + if task.state in WAITING_FOR_START_STATES: task.state = RunningState.STARTED task.progress = task_progress_event.progress run = await CompRunsRepository(self.db_engine).get(user_id, project_id) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py index a5dcfc4fd7bc..7468f4bd63ad 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py @@ -8,15 +8,15 @@ """ +import asyncio import logging import traceback -from collections.abc import Callable +from collections.abc import Callable, Iterable from copy import deepcopy from dataclasses import dataclass from http.client import HTTPException from typing import Any, Final, cast -import dask.typing import distributed from aiohttp import ClientResponseError from common_library.json_serialization import json_dumps @@ -38,16 +38,22 @@ TaskOwner, ) from dask_task_models_library.container_tasks.utils import generate_dask_job_id -from dask_task_models_library.models import DaskJobID, DaskResources +from dask_task_models_library.models import ( + TASK_LIFE_CYCLE_EVENT, + TASK_RUNNING_PROGRESS_EVENT, + DaskJobID, + 
DaskResources, + TaskLifeCycleState, +) from dask_task_models_library.resource_constraints import ( create_ec2_resource_constraint_key, ) -from distributed.scheduler import TaskStateState as DaskSchedulerTaskState from fastapi import FastAPI from models_library.api_schemas_directorv2.clusters import ClusterDetails, Scheduler from models_library.clusters import ClusterAuthentication, ClusterTypeInModel from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID +from models_library.projects_state import RunningState from models_library.resource_tracker import HardwareInfo from models_library.services import ServiceRunID from models_library.users import UserID @@ -71,12 +77,12 @@ from ..core.settings import AppSettings, ComputationalBackendSettings from ..models.comp_runs import RunMetadataDict from ..models.comp_tasks import Image -from ..models.dask_subsystem import DaskClientTaskState from ..modules.storage import StorageClient from ..utils import dask as dask_utils from ..utils.dask_client_utils import ( DaskSubSystem, TaskHandlers, + UnixTimestamp, connect_to_dask_scheduler, ) from .db import get_db_engine @@ -84,25 +90,6 @@ _logger = logging.getLogger(__name__) -# NOTE: processing does not mean the task is currently being computed, it means -# the task was accepted by a worker, but might still be queud in it -# see https://distributed.dask.org/en/stable/scheduling-state.html#task-state - - -_DASK_TASK_STATUS_DASK_CLIENT_TASK_STATE_MAP: dict[ - DaskSchedulerTaskState, DaskClientTaskState -] = { - "queued": DaskClientTaskState.PENDING, - "released": DaskClientTaskState.PENDING, - "waiting": DaskClientTaskState.PENDING, - "no-worker": DaskClientTaskState.NO_WORKER, - "processing": DaskClientTaskState.PENDING_OR_STARTED, - "memory": DaskClientTaskState.SUCCESS, - "erred": DaskClientTaskState.ERRED, - "forgotten": DaskClientTaskState.LOST, -} - - _DASK_DEFAULT_TIMEOUT_S: Final[int] = 5 @@ -211,7 +198,7 @@ def _comp_sidecar_fct( ) -> TaskOutputData: """This function is serialized by the Dask client and sent over to the Dask sidecar(s) Therefore, (screaming here) DO NOT MOVE THAT IMPORT ANYWHERE ELSE EVER!!""" - from simcore_service_dask_sidecar.worker import ( # type: ignore[import-not-found] # this runs inside the dask-sidecar + from simcore_service_dask_sidecar.worker import ( # type: ignore[import-not-found] # this runs inside the dask-sidecar run_computational_sidecar, ) @@ -419,35 +406,47 @@ async def send_computation_tasks( return list_of_node_id_to_job_id - async def get_tasks_status(self, job_ids: list[str]) -> list[DaskClientTaskState]: + async def get_tasks_progress( + self, job_ids: list[str] + ) -> list[TaskProgressEvent | None]: dask_utils.check_scheduler_is_still_the_same( self.backend.scheduler_id, self.backend.client ) dask_utils.check_communication_with_scheduler_is_open(self.backend.client) dask_utils.check_scheduler_status(self.backend.client) - # try to get the task from the scheduler - def _get_pipeline_statuses( - dask_scheduler: distributed.Scheduler, - ) -> dict[dask.typing.Key, DaskSchedulerTaskState | None]: - statuses: dict[dask.typing.Key, DaskSchedulerTaskState | None] = ( - dask_scheduler.get_task_status(keys=job_ids) + async def _get_task_progress(job_id: str) -> TaskProgressEvent | None: + dask_events: tuple[tuple[UnixTimestamp, str], ...] 
= ( + await self.backend.client.get_events( + TASK_RUNNING_PROGRESS_EVENT.format(key=job_id) + ) ) - return statuses + if not dask_events: + return None + # we are interested in the last event + return TaskProgressEvent.model_validate_json(dask_events[-1][1]) - task_statuses: dict[dask.typing.Key, DaskSchedulerTaskState | None] = ( - await self.backend.client.run_on_scheduler(_get_pipeline_statuses) - ) - assert isinstance(task_statuses, dict) # nosec + return await asyncio.gather(*(_get_task_progress(job_id) for job_id in job_ids)) - _logger.debug("found dask task statuses: %s", f"{task_statuses=}") + async def get_tasks_status(self, job_ids: Iterable[str]) -> list[RunningState]: + dask_utils.check_scheduler_is_still_the_same( + self.backend.scheduler_id, self.backend.client + ) + dask_utils.check_communication_with_scheduler_is_open(self.backend.client) + dask_utils.check_scheduler_status(self.backend.client) - running_states: list[DaskClientTaskState] = [] - for job_id in job_ids: - dask_status = cast( - DaskSchedulerTaskState | None, task_statuses.get(job_id, "lost") + async def _get_task_state(job_id: str) -> RunningState: + dask_events: tuple[tuple[UnixTimestamp, str], ...] = ( + await self.backend.client.get_events( + TASK_LIFE_CYCLE_EVENT.format(key=job_id) + ) ) - if dask_status == "erred": + if not dask_events: + return RunningState.UNKNOWN + # we are interested in the last event + parsed_event = TaskLifeCycleState.model_validate(dask_events[-1][1]) + + if parsed_event.state == RunningState.FAILED: try: # find out if this was a cancellation var = distributed.Variable(job_id, client=self.backend.client) @@ -458,34 +457,26 @@ def _get_pipeline_statuses( assert isinstance(exception, Exception) # nosec if isinstance(exception, TaskCancelledError): - running_states.append(DaskClientTaskState.ABORTED) - else: - assert exception # nosec - _logger.warning( - "Task %s completed in error:\n%s\nTrace:\n%s", - job_id, - exception, - "".join(traceback.format_exception(exception)), - ) - running_states.append(DaskClientTaskState.ERRED) + return RunningState.ABORTED + assert exception # nosec + _logger.warning( + "Task %s completed in error:\n%s\nTrace:\n%s", + job_id, + exception, + "".join(traceback.format_exception(exception)), + ) + return RunningState.FAILED except TimeoutError: _logger.warning( "Task %s could not be retrieved from dask-scheduler, it is lost\n" "TIP:If the task was unpublished this can happen, or if the dask-scheduler was restarted.", job_id, ) - running_states.append(DaskClientTaskState.LOST) - - elif dask_status is None: - running_states.append(DaskClientTaskState.LOST) - else: - running_states.append( - _DASK_TASK_STATUS_DASK_CLIENT_TASK_STATE_MAP.get( - dask_status, DaskClientTaskState.LOST - ) - ) + return RunningState.UNKNOWN + + return parsed_event.state - return running_states + return await asyncio.gather(*(_get_task_state(job_id) for job_id in job_ids)) async def abort_computation_task(self, job_id: str) -> None: # Dask future may be cancelled, but only a future that was not already taken by @@ -536,7 +527,7 @@ async def release_task_result(self, job_id: str) -> None: # the variable was effectively deleted. # This is annoying as one can re-create the variable without error. 
var = distributed.Variable(job_id, client=self.backend.client) - var.delete() + await asyncio.get_event_loop().run_in_executor(None, var.delete) # first check if the key exists await dask_utils.wrap_client_async_routine( self.backend.client.get_dataset(name=job_id) diff --git a/services/director-v2/tests/unit/test_modules_dask_client.py b/services/director-v2/tests/unit/test_modules_dask_client.py index 370cc967190d..909c3c238de7 100644 --- a/services/director-v2/tests/unit/test_modules_dask_client.py +++ b/services/director-v2/tests/unit/test_modules_dask_client.py @@ -6,6 +6,7 @@ # pylint: disable=reimported import asyncio import functools +import logging import traceback from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine from dataclasses import dataclass @@ -43,11 +44,13 @@ from models_library.docker import to_simcore_runtime_docker_label_key from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID +from models_library.projects_state import RunningState from models_library.resource_tracker import HardwareInfo from models_library.services_types import ServiceRunID from models_library.users import UserID from pydantic import AnyUrl, ByteSize, TypeAdapter from pytest_mock.plugin import MockerFixture +from pytest_simcore.helpers.logging_tools import log_context from pytest_simcore.helpers.typing_env import EnvVarsDict from settings_library.s3 import S3Settings from simcore_sdk.node_ports_v2 import FileLinkType @@ -60,7 +63,6 @@ ) from simcore_service_director_v2.models.comp_runs import RunMetadataDict from simcore_service_director_v2.models.comp_tasks import Image -from simcore_service_director_v2.models.dask_subsystem import DaskClientTaskState from simcore_service_director_v2.modules.dask_client import DaskClient, TaskHandlers from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type @@ -90,7 +92,7 @@ async def _assert_wait_for_cb_call(mocked_fct, timeout: int | None = None): async def _assert_wait_for_task_status( job_id: str, dask_client: DaskClient, - expected_status: DaskClientTaskState, + expected_status: RunningState, timeout: int | None = None, # noqa: ASYNC109 ): async for attempt in AsyncRetrying( @@ -105,11 +107,11 @@ async def _assert_wait_for_task_status( f"Attempt={attempt.retry_state.attempt_number}" ) got = (await dask_client.get_tasks_status([job_id]))[0] - assert isinstance(got, DaskClientTaskState) + assert isinstance(got, RunningState) print(f"{got=} vs {expected_status=}") - if got is DaskClientTaskState.ERRED and expected_status not in [ - DaskClientTaskState.ERRED, - DaskClientTaskState.LOST, + if got is RunningState.FAILED and expected_status not in [ + RunningState.FAILED, + RunningState.UNKNOWN, ]: try: # we can fail fast here @@ -140,41 +142,48 @@ def _minimal_dask_config( @pytest.fixture async def create_dask_client_from_scheduler( _minimal_dask_config: None, - dask_spec_local_cluster: SpecCluster, + dask_spec_local_cluster: distributed.SpecCluster, minimal_app: FastAPI, tasks_file_link_type: FileLinkType, ) -> AsyncIterator[Callable[[], Awaitable[DaskClient]]]: created_clients = [] async def factory() -> DaskClient: - client = await DaskClient.create( - app=minimal_app, - settings=minimal_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, - endpoint=TypeAdapter(AnyUrl).validate_python( - dask_spec_local_cluster.scheduler_address - ), - authentication=NoAuthentication(), - tasks_file_link_type=tasks_file_link_type, - cluster_type=ClusterTypeInModel.ON_PREMISE, 
- ) - assert client - assert client.app == minimal_app - assert ( - client.settings - == minimal_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND - ) + with log_context( + logging.INFO, + f"Create director-v2 DaskClient to {dask_spec_local_cluster.scheduler_address}", + ) as ctx: + client = await DaskClient.create( + app=minimal_app, + settings=minimal_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, + endpoint=TypeAdapter(AnyUrl).validate_python( + dask_spec_local_cluster.scheduler_address + ), + authentication=NoAuthentication(), + tasks_file_link_type=tasks_file_link_type, + cluster_type=ClusterTypeInModel.ON_PREMISE, + ) + assert client + assert client.app == minimal_app + assert ( + client.settings + == minimal_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND + ) + + assert client.backend.client + scheduler_infos = client.backend.client.scheduler_info() # type: ignore + ctx.logger.info( + "%s", + f"--> Connected to scheduler via client {client=} to scheduler {scheduler_infos=}", + ) - assert client.backend.client - scheduler_infos = client.backend.client.scheduler_info() # type: ignore - print( - f"--> Connected to scheduler via client {client=} to scheduler {scheduler_infos=}" - ) created_clients.append(client) return client yield factory - await asyncio.gather(*[client.delete() for client in created_clients]) - print(f"<-- Disconnected scheduler clients {created_clients=}") + + with log_context(logging.INFO, "Disconnect scheduler clients"): + await asyncio.gather(*[client.delete() for client in created_clients]) @pytest.fixture(params=["create_dask_client_from_scheduler"]) @@ -516,7 +525,7 @@ def fake_sidecar_fct( await _assert_wait_for_task_status( published_computation_task.job_id, dask_client, - expected_status=DaskClientTaskState.PENDING_OR_STARTED, + expected_status=RunningState.STARTED, ) # using the event we let the remote fct continue @@ -530,7 +539,7 @@ def fake_sidecar_fct( await _assert_wait_for_task_status( published_computation_task.job_id, dask_client, - expected_status=DaskClientTaskState.SUCCESS, + expected_status=RunningState.SUCCESS, ) # check the results @@ -544,7 +553,7 @@ def fake_sidecar_fct( await _assert_wait_for_task_status( published_computation_task.job_id, dask_client, - expected_status=DaskClientTaskState.LOST, + expected_status=RunningState.UNKNOWN, timeout=60, ) @@ -609,7 +618,7 @@ def fake_sidecar_fct( await _assert_wait_for_task_status( published_computation_task[0].job_id, dask_client, - expected_status=DaskClientTaskState.SUCCESS, + expected_status=RunningState.SUCCESS, ) assert published_computation_task[0].node_id in image_params.fake_tasks # creating a new future shows that it is not done???? 
@@ -704,7 +713,7 @@ def fake_remote_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        DaskClientTaskState.PENDING_OR_STARTED,
+        RunningState.STARTED,
     )

     # we wait to be sure the remote fct is started
@@ -721,20 +730,25 @@ def fake_remote_fct(

     await _assert_wait_for_cb_call(mocked_user_completed_cb)
     await _assert_wait_for_task_status(
-        published_computation_task[0].job_id, dask_client, DaskClientTaskState.ABORTED
+        published_computation_task[0].job_id, dask_client, RunningState.ABORTED
     )

     # getting the results should throw the cancellation error
     with pytest.raises(TaskCancelledError):
         await dask_client.get_task_result(published_computation_task[0].job_id)

-    # after releasing the results, the task shall be UNKNOWN
     await dask_client.release_task_result(published_computation_task[0].job_id)
+    # after releasing the results, the task shall be UNKNOWN
+    _ALLOW_TIME_FOR_LOCAL_DASK_SCHEDULER_TO_UPDATE_TIMEOUT_S = 5
+    await asyncio.sleep(_ALLOW_TIME_FOR_LOCAL_DASK_SCHEDULER_TO_UPDATE_TIMEOUT_S)
     # NOTE: this change of status takes a very long time to happen and is not relied upon so we skip it since it
     # makes the test fail a lot for no gain (it's kept here in case it ever becomes an issue)
-    # await _assert_wait_for_task_status(
-    #    job_id, dask_client, RunningState.UNKNOWN, timeout=120
-    # )
+    await _assert_wait_for_task_status(
+        published_computation_task[0].job_id,
+        dask_client,
+        RunningState.UNKNOWN,
+        timeout=10,
+    )


 async def test_failed_task_returns_exceptions(
@@ -784,7 +798,7 @@ def fake_failing_sidecar_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        expected_status=DaskClientTaskState.ERRED,
+        expected_status=RunningState.FAILED,
     )
     with pytest.raises(
         ValueError,
@@ -1047,7 +1061,7 @@ def fake_remote_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        DaskClientTaskState.PENDING_OR_STARTED,
+        RunningState.STARTED,
     )

     # let the remote fct run through now
@@ -1057,7 +1071,7 @@ def fake_remote_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        DaskClientTaskState.ERRED if fail_remote_fct else DaskClientTaskState.SUCCESS,
+        RunningState.FAILED if fail_remote_fct else RunningState.SUCCESS,
     )
     # release the task results
     await dask_client.release_task_result(published_computation_task[0].job_id)
@@ -1069,7 +1083,7 @@ def fake_remote_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        DaskClientTaskState.LOST,
+        RunningState.UNKNOWN,
         timeout=60,
     )

@@ -1209,7 +1223,7 @@ def fake_sidecar_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        expected_status=DaskClientTaskState.PENDING_OR_STARTED,
+        expected_status=RunningState.STARTED,
     )

     # check we have one worker using the resources
@@ -1240,7 +1254,7 @@ def fake_sidecar_fct(
     await _assert_wait_for_task_status(
         published_computation_task[0].job_id,
         dask_client,
-        expected_status=DaskClientTaskState.SUCCESS,
+        expected_status=RunningState.SUCCESS,
     )

     # check the resources are released
diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py
index 80666d7f6f03..717bc3381caa 100644
--- a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py
+++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py
@@ -60,7 +60,6 @@ from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB
 from simcore_service_director_v2.models.comp_runs import CompRunsAtDB, RunMetadataDict
 from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB, Image
-from simcore_service_director_v2.models.dask_subsystem import DaskClientTaskState
 from simcore_service_director_v2.modules.comp_scheduler._manager import (
     run_new_pipeline,
     stop_pipeline,
 )
@@ -206,8 +205,8 @@ async def _assert_publish_in_dask_backend(
     for p in expected_pending_tasks:
         published_tasks.remove(p)

-    async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]:
-        return [DaskClientTaskState.PENDING for job_id in job_ids]
+    async def _return_tasks_pending(job_ids: list[str]) -> list[RunningState]:
+        return [RunningState.PENDING for job_id in job_ids]

     mocked_dask_client.get_tasks_status.side_effect = _return_tasks_pending
     assert published_project.project.prj_owner
@@ -445,17 +444,16 @@ async def test_proper_pipeline_is_scheduled(  # noqa: PLR0915
     )

     # -------------------------------------------------------------------------------
-    # 2.1. the dask-worker might be taking the task, until we get a progress we do not know
-    # whether it effectively started or it is still queued in the worker process
+    # 2.1. the dask-worker takes the task
     exp_started_task = expected_pending_tasks[0]
     expected_pending_tasks.remove(exp_started_task)

-    async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.PENDING_OR_STARTED
+                RunningState.STARTED
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
@@ -469,7 +467,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
     await assert_comp_runs(
         sqlalchemy_async_engine,
         expected_total=1,
-        expected_state=RunningState.PENDING,
+        expected_state=RunningState.STARTED,
         where_statement=and_(
             comp_runs.c.user_id == published_project.project.prj_owner,
             comp_runs.c.project_uuid == f"{published_project.project.uuid}",
@@ -478,8 +476,14 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
     await assert_comp_tasks(
         sqlalchemy_async_engine,
         project_uuid=published_project.project.uuid,
-        task_ids=[exp_started_task.node_id]
-        + [p.node_id for p in expected_pending_tasks],
+        task_ids=[exp_started_task.node_id],
+        expected_state=RunningState.STARTED,
+        expected_progress=None,
+    )
+    await assert_comp_tasks(
+        sqlalchemy_async_engine,
+        project_uuid=published_project.project.uuid,
+        task_ids=[p.node_id for p in expected_pending_tasks],
         expected_state=RunningState.PENDING,
         expected_progress=None,
     )
@@ -572,12 +576,12 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta

     # -------------------------------------------------------------------------------
     # 4. the dask-worker completed the task successfully
-    async def _return_1st_task_success(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_1st_task_success(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.SUCCESS
+                RunningState.SUCCESS
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
@@ -679,12 +683,12 @@ async def _return_random_task_result(job_id) -> TaskOutputData:
     # 6. the dask-worker starts processing a task
     exp_started_task = next_pending_task

-    async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_2nd_task_running(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.PENDING_OR_STARTED
+                RunningState.STARTED
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
@@ -743,12 +747,12 @@ async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskSta

     # -------------------------------------------------------------------------------
     # 7. the task fails
-    async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_2nd_task_failed(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.ERRED
+                RunningState.FAILED
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
@@ -805,12 +809,12 @@ async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskStat
     # 8. the last task shall succeed
     exp_started_task = expected_pending_tasks[0]

-    async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_3rd_task_success(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.SUCCESS
+                RunningState.SUCCESS
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
@@ -917,12 +921,12 @@ async def with_started_project(
     exp_started_task = expected_pending_tasks[0]
     expected_pending_tasks.remove(exp_started_task)

-    async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.PENDING_OR_STARTED
+                RunningState.STARTED
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
@@ -939,7 +943,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
     await assert_comp_runs(
         sqlalchemy_async_engine,
         expected_total=1,
-        expected_state=RunningState.PENDING,
+        expected_state=RunningState.STARTED,
         where_statement=and_(
             comp_runs.c.user_id == published_project.project.prj_owner,
             comp_runs.c.project_uuid == f"{published_project.project.uuid}",
@@ -948,8 +952,14 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta
     await assert_comp_tasks(
         sqlalchemy_async_engine,
         project_uuid=published_project.project.uuid,
-        task_ids=[exp_started_task.node_id]
-        + [p.node_id for p in expected_pending_tasks],
+        task_ids=[exp_started_task.node_id],
+        expected_state=RunningState.STARTED,
+        expected_progress=None,
+    )
+    await assert_comp_tasks(
+        sqlalchemy_async_engine,
+        project_uuid=published_project.project.uuid,
+        task_ids=[p.node_id for p in expected_pending_tasks],
         expected_state=RunningState.PENDING,
         expected_progress=None,
     )
@@ -1308,7 +1318,7 @@ async def test_handling_of_disconnected_scheduler_dask(

 @dataclass(frozen=True, kw_only=True)
 class RebootState:
-    dask_task_status: DaskClientTaskState
+    dask_task_status: RunningState
     task_result: Exception | TaskOutputData
     expected_task_state_group1: RunningState
     expected_task_progress_group1: float
@@ -1322,7 +1332,7 @@ class RebootState:
     [
         pytest.param(
             RebootState(
-                dask_task_status=DaskClientTaskState.LOST,
+                dask_task_status=RunningState.UNKNOWN,
                 task_result=ComputationalBackendTaskNotFoundError(job_id="fake_job_id"),
                 expected_task_state_group1=RunningState.FAILED,
                 expected_task_progress_group1=1,
@@ -1334,7 +1344,7 @@ class RebootState:
         ),
         pytest.param(
             RebootState(
-                dask_task_status=DaskClientTaskState.ABORTED,
+                dask_task_status=RunningState.ABORTED,
                 task_result=TaskCancelledError(job_id="fake_job_id"),
                 expected_task_state_group1=RunningState.ABORTED,
                 expected_task_progress_group1=1,
@@ -1346,7 +1356,7 @@ class RebootState:
         ),
         pytest.param(
             RebootState(
-                dask_task_status=DaskClientTaskState.ERRED,
+                dask_task_status=RunningState.FAILED,
                 task_result=ValueError("some error during the call"),
                 expected_task_state_group1=RunningState.FAILED,
                 expected_task_progress_group1=1,
@@ -1358,7 +1368,7 @@ class RebootState:
         ),
         pytest.param(
             RebootState(
-                dask_task_status=DaskClientTaskState.PENDING_OR_STARTED,
+                dask_task_status=RunningState.STARTED,
                 task_result=ComputationalBackendTaskResultsNotReadyError(
                     job_id="fake_job_id"
                 ),
@@ -1372,7 +1382,7 @@ class RebootState:
         ),
         pytest.param(
             RebootState(
-                dask_task_status=DaskClientTaskState.SUCCESS,
+                dask_task_status=RunningState.SUCCESS,
                 task_result=TaskOutputData.model_validate({"whatever_output": 123}),
                 expected_task_state_group1=RunningState.SUCCESS,
                 expected_task_progress_group1=1,
@@ -1399,7 +1409,7 @@ async def test_handling_scheduled_tasks_after_director_reboots(
     shall continue scheduling correctly. Even though the task might have continued
     to run in the dask-scheduler."""

-    async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def mocked_get_tasks_status(job_ids: list[str]) -> list[RunningState]:
         return [reboot_state.dask_task_status for j in job_ids]

     mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status
@@ -1514,8 +1524,8 @@ async def test_handling_cancellation_of_jobs_after_reboot(
     )

     # the backend shall report the tasks as running
-    async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]:
-        return [DaskClientTaskState.PENDING_OR_STARTED for j in job_ids]
+    async def mocked_get_tasks_status(job_ids: list[str]) -> list[RunningState]:
+        return [RunningState.STARTED for j in job_ids]

     mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status
     # Running the scheduler, should actually cancel the run now
@@ -1559,8 +1569,8 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskStat
     # the backend shall now report the tasks as aborted
     async def mocked_get_tasks_status_aborted(
         job_ids: list[str],
-    ) -> list[DaskClientTaskState]:
-        return [DaskClientTaskState.ABORTED for j in job_ids]
+    ) -> list[RunningState]:
+        return [RunningState.ABORTED for j in job_ids]

     mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status_aborted

@@ -1641,12 +1651,12 @@ async def test_running_pipeline_triggers_heartbeat(
     exp_started_task = expected_pending_tasks[0]
     expected_pending_tasks.remove(exp_started_task)

-    async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]:
+    async def _return_1st_task_running(job_ids: list[str]) -> list[RunningState]:
         return [
             (
-                DaskClientTaskState.PENDING_OR_STARTED
+                RunningState.STARTED
                 if job_id == exp_started_task.job_id
-                else DaskClientTaskState.PENDING
+                else RunningState.PENDING
             )
             for job_id in job_ids
         ]
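
For orientation: throughout these tests the expected state is checked by awaiting _assert_wait_for_task_status(job_id, dask_client, expected_status, timeout=...), which polls dask_client.get_tasks_status, and after this change that call reports plain RunningState values. Below is a minimal sketch of such a polling helper under those assumptions; the actual helper is defined elsewhere in the test module and may differ in detail.

import asyncio

from models_library.projects_state import RunningState


async def _assert_wait_for_task_status(
    job_id: str,
    dask_client,
    expected_status: RunningState,
    timeout: int = 30,
) -> None:
    # poll the dask client until the task reaches the expected RunningState
    deadline = asyncio.get_running_loop().time() + timeout
    while True:
        # get_tasks_status returns one RunningState per requested job id
        current = (await dask_client.get_tasks_status([job_id]))[0]
        if current is expected_status:
            return
        if asyncio.get_running_loop().time() >= deadline:
            msg = f"task {job_id} stayed in {current}, expected {expected_status}"
            raise AssertionError(msg)
        await asyncio.sleep(1)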