diff --git a/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py b/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py index ebc6aabcad85..ad7135d94d0c 100644 --- a/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py +++ b/packages/dask-task-models-library/src/dask_task_models_library/plugins/task_life_cycle_worker_plugin.py @@ -34,6 +34,7 @@ def transition( ): _logger.info("Task '%s' transition from %s to %s", key, start, finish) assert self._worker # nosec + assert isinstance(self._worker, Worker) # nosec self._worker.log_event( TASK_LIFE_CYCLE_EVENT.format(key=key), TaskLifeCycleState.from_worker_task_state( diff --git a/packages/pytest-simcore/src/pytest_simcore/pydantic_models.py b/packages/pytest-simcore/src/pytest_simcore/pydantic_models.py index e8691a10724c..8266de0947bb 100644 --- a/packages/pytest-simcore/src/pytest_simcore/pydantic_models.py +++ b/packages/pytest-simcore/src/pytest_simcore/pydantic_models.py @@ -97,7 +97,6 @@ def _is_model_cls(obj) -> bool: assert inspect.ismodule(module) for model_name, model_cls in inspect.getmembers(module, _is_model_cls): - yield from iter_model_examples_in_class(model_cls, model_name) @@ -172,7 +171,7 @@ def model_cls_examples(model_cls: type[BaseModel]) -> dict[str, dict[str, Any]]: """ warnings.warn( "The 'model_cls_examples' fixture is deprecated and will be removed in a future version. " - "Please use 'iter_model_example_in_class' or 'iter_model_examples_in_module' as an alternative.", + "Please use 'iter_model_examples_in_class' or 'iter_model_examples_in_module' as an alternative.", DeprecationWarning, stacklevel=2, ) diff --git a/services/director-v2/src/simcore_service_director_v2/core/settings.py b/services/director-v2/src/simcore_service_director_v2/core/settings.py index 28208ec34ff9..7ec136b65d51 100644 --- a/services/director-v2/src/simcore_service_director_v2/core/settings.py +++ b/services/director-v2/src/simcore_service_director_v2/core/settings.py @@ -99,6 +99,14 @@ class ComputationalBackendSettings(BaseCustomSettings): ), ] = datetime.timedelta(minutes=10) + COMPUTATIONAL_BACKEND_MAX_WAITING_FOR_RETRIEVING_RESULTS: Annotated[ + datetime.timedelta, + Field( + description="maximum time the computational scheduler waits until retrieving results from the computational backend is failed" + "(default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formatting)." + ), + ] = datetime.timedelta(minutes=10) + @cached_property def default_cluster(self) -> BaseCluster: return BaseCluster( diff --git a/services/director-v2/src/simcore_service_director_v2/models/comp_run_snapshot_tasks.py b/services/director-v2/src/simcore_service_director_v2/models/comp_run_snapshot_tasks.py index 435ba460d017..66945da11d1c 100644 --- a/services/director-v2/src/simcore_service_director_v2/models/comp_run_snapshot_tasks.py +++ b/services/director-v2/src/simcore_service_director_v2/models/comp_run_snapshot_tasks.py @@ -1,3 +1,4 @@ +from contextlib import suppress from datetime import datetime from typing import Annotated, Any @@ -5,8 +6,16 @@ from models_library.projects_nodes_io import NodeID from models_library.projects_state import RunningState from models_library.resource_tracker import HardwareInfo -from pydantic import BaseModel, BeforeValidator, ConfigDict, PositiveInt +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + PositiveInt, + field_validator, +) +from simcore_postgres_database.models.comp_pipeline import StateType +from ..utils.db import DB_TO_RUNNING_STATE from .comp_tasks import BaseCompTaskAtDB, Image @@ -100,3 +109,15 @@ class CompRunSnapshotTaskDBGet(BaseModel): started_at: datetime | None ended_at: datetime | None iteration: PositiveInt + + @field_validator("state", mode="before") + @classmethod + def convert_result_from_state_type_enum_if_needed(cls, v): + if isinstance(v, str): + # try to convert to a StateType, if it fails the validations will continue + # and pydantic will try to convert it to a RunninState later on + with suppress(ValueError): + v = StateType(v) + if isinstance(v, StateType): + return RunningState(DB_TO_RUNNING_STATE[StateType(v)]) + return v diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py index 1fb4564ab246..2638271d1a0f 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py @@ -23,12 +23,15 @@ from models_library.users import UserID from pydantic import PositiveInt from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE +from servicelib.logging_errors import create_troubleshootting_log_kwargs from servicelib.logging_utils import log_catch, log_context +from servicelib.utils import limited_as_completed from ...core.errors import ( ComputationalBackendNotConnectedError, ComputationalBackendOnDemandNotReadyError, - TaskSchedulingError, + ComputationalBackendTaskResultsNotReadyError, + PortsValidationError, ) from ...models.comp_runs import CompRunsAtDB, Iteration, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB @@ -49,6 +52,9 @@ CompRunsRepository, ) from ..db.repositories.comp_tasks import CompTasksRepository +from ._constants import ( + MAX_CONCURRENT_PIPELINE_SCHEDULING, +) from ._scheduler_base import BaseCompScheduler from ._utils import ( WAITING_FOR_START_STATES, @@ -57,6 +63,8 @@ _logger = logging.getLogger(__name__) _DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{project_id}:{run_id}" +_TASK_RETRIEVAL_ERROR_TYPE: Final[str] = "task-result-retrieval-timeout" +_TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time" @asynccontextmanager @@ -289,39 +297,173 @@ async def _process_completed_tasks( iteration: Iteration, comp_run: CompRunsAtDB, ) -> None: - try: - async with _cluster_dask_client( - user_id, - self, - use_on_demand_clusters=comp_run.use_on_demand_clusters, - project_id=comp_run.project_uuid, - run_id=comp_run.run_id, - run_metadata=comp_run.metadata, - ) as client: - tasks_results = await asyncio.gather( - *[client.get_task_result(t.job_id or "undefined") for t in tasks], - return_exceptions=True, - ) - await asyncio.gather( - *[ + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + project_id=comp_run.project_uuid, + run_id=comp_run.run_id, + run_metadata=comp_run.metadata, + ) as client: + tasks_results = await asyncio.gather( + *[client.get_task_result(t.job_id or "undefined") for t in tasks], + return_exceptions=True, + ) + async for future in limited_as_completed( + ( self._process_task_result( task, result, comp_run.metadata, iteration, comp_run.run_id ) for task, result in zip(tasks, tasks_results, strict=True) - ] + ), + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, + ): + with log_catch(_logger, reraise=False): + task_can_be_cleaned, job_id = await future + if task_can_be_cleaned: + await client.release_task_result(job_id) + + async def _handle_successful_run( + self, + task: CompTaskAtDB, + result: TaskOutputData, + log_error_context: dict[str, Any], + ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]: + assert task.job_id # nosec + try: + await parse_output_data( + self.db_engine, + task.job_id, + result, ) - finally: - async with _cluster_dask_client( - user_id, - self, - use_on_demand_clusters=comp_run.use_on_demand_clusters, - project_id=comp_run.project_uuid, - run_id=comp_run.run_id, - run_metadata=comp_run.metadata, - ) as client: - await asyncio.gather( - *[client.release_task_result(t.job_id) for t in tasks if t.job_id] + return RunningState.SUCCESS, SimcorePlatformStatus.OK, [], True + except PortsValidationError as err: + _logger.exception( + **create_troubleshootting_log_kwargs( + "Unexpected error while parsing output data, comp_tasks/comp_pipeline is not in sync with what was started", + error=err, + error_context=log_error_context, ) + ) + # NOTE: simcore platform state is still OK as the task ran fine, the issue is likely due to the service labels + return RunningState.FAILED, SimcorePlatformStatus.OK, err.get_errors(), True + + async def _handle_computational_retrieval_error( + self, + task: CompTaskAtDB, + user_id: UserID, + result: ComputationalBackendTaskResultsNotReadyError, + log_error_context: dict[str, Any], + ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]: + assert task.job_id # nosec + _logger.warning( + **create_troubleshootting_log_kwargs( + f"Retrieval of task {task.job_id} result timed-out", + error=result, + error_context=log_error_context, + tip="This can happen if the computational backend is overloaded with requests. It will be automatically retried again.", + ) + ) + task_errors: list[ErrorDict] = [] + check_time = arrow.utcnow() + if task.errors: + for error in task.errors: + if error["type"] == _TASK_RETRIEVAL_ERROR_TYPE: + # already had a timeout error, let's keep it + task_errors.append(error) + assert "ctx" in error # nosec + assert ( + _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY in error["ctx"] + ) # nosec + check_time = arrow.get( + error["ctx"][_TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY] + ) + break + if not task_errors: + # first time we have this error + task_errors.append( + ErrorDict( + loc=(f"{task.project_id}", f"{task.node_id}"), + msg=f"{result}", + type=_TASK_RETRIEVAL_ERROR_TYPE, + ctx={ + _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: f"{check_time}", + "user_id": user_id, + "project_id": f"{task.project_id}", + "node_id": f"{task.node_id}", + "job_id": task.job_id, + }, + ) + ) + + # if the task has been running for too long, we consider it failed + elapsed_time = arrow.utcnow() - check_time + if ( + elapsed_time + > self.settings.COMPUTATIONAL_BACKEND_MAX_WAITING_FOR_RETRIEVING_RESULTS + ): + return RunningState.FAILED, SimcorePlatformStatus.BAD, task_errors, True + # state is kept as STARTED so it will be retried + return RunningState.STARTED, SimcorePlatformStatus.BAD, task_errors, False + + @staticmethod + async def _handle_computational_backend_not_connected_error( + task: CompTaskAtDB, + result: ComputationalBackendNotConnectedError, + log_error_context: dict[str, Any], + ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]: + assert task.job_id # nosec + _logger.warning( + **create_troubleshootting_log_kwargs( + f"Computational backend disconnected when retrieving task {task.job_id} result", + error=result, + error_context=log_error_context, + tip="This can happen if the computational backend is temporarily disconnected. It will be automatically retried again.", + ) + ) + # NOTE: the task will be set to UNKNOWN on the next processing loop + + # state is kept as STARTED so it will be retried + return RunningState.STARTED, SimcorePlatformStatus.BAD, [], False + + @staticmethod + async def _handle_task_error( + task: CompTaskAtDB, + result: BaseException, + log_error_context: dict[str, Any], + ) -> tuple[RunningState, SimcorePlatformStatus, list[ErrorDict], bool]: + assert task.job_id # nosec + + # the task itself failed, check why + if isinstance(result, TaskCancelledError): + _logger.info( + **create_troubleshootting_log_kwargs( + f"Task {task.job_id} was cancelled", + error=result, + error_context=log_error_context, + ) + ) + return RunningState.ABORTED, SimcorePlatformStatus.OK, [], True + + _logger.info( + **create_troubleshootting_log_kwargs( + f"Task {task.job_id} completed with errors", + error=result, + error_context=log_error_context, + ) + ) + return ( + RunningState.FAILED, + SimcorePlatformStatus.OK, + [ + ErrorDict( + loc=(f"{task.project_id}", f"{task.node_id}"), + msg=f"{result}", + type="runtime", + ) + ], + True, + ) async def _process_task_result( self, @@ -330,65 +472,68 @@ async def _process_task_result( run_metadata: RunMetadataDict, iteration: Iteration, run_id: PositiveInt, - ) -> None: + ) -> tuple[bool, str]: + """Returns True and the job ID if the task was successfully processed and can be released from the Dask cluster.""" _logger.debug("received %s result: %s", f"{task=}", f"{result=}") - task_final_state = RunningState.FAILED - simcore_platform_status = SimcorePlatformStatus.OK - errors: list[ErrorDict] = [] - if task.job_id is not None: + assert task.job_id # nosec + ( + _service_key, + _service_version, + user_id, + project_id, + node_id, + ) = parse_dask_job_id(task.job_id) + + assert task.project_id == project_id # nosec + assert task.node_id == node_id # nosec + log_error_context = { + "user_id": user_id, + "project_id": project_id, + "node_id": node_id, + "job_id": task.job_id, + } + + if isinstance(result, TaskOutputData): ( - _service_key, - _service_version, - user_id, - project_id, - node_id, - ) = parse_dask_job_id(task.job_id) - - assert task.project_id == project_id # nosec - assert task.node_id == node_id # nosec - - try: - if isinstance(result, TaskOutputData): - # success! - await parse_output_data( - self.db_engine, - task.job_id, - result, - ) - task_final_state = RunningState.SUCCESS - - else: - if isinstance(result, TaskCancelledError): - task_final_state = RunningState.ABORTED - else: - task_final_state = RunningState.FAILED - errors.append( - { - "loc": ( - f"{task.project_id}", - f"{task.node_id}", - ), - "msg": f"{result}", - "type": "runtime", - } - ) - if isinstance(result, ComputationalBackendNotConnectedError): - simcore_platform_status = SimcorePlatformStatus.BAD - # we need to remove any invalid files in the storage - await clean_task_output_and_log_files_if_invalid( - self.db_engine, user_id, project_id, node_id - ) - except TaskSchedulingError as err: - task_final_state = RunningState.FAILED - simcore_platform_status = SimcorePlatformStatus.BAD - errors = err.get_errors() - _logger.debug( - "Unexpected failure while processing results of %s: %s", - f"{task=}", - f"{errors=}", - ) + task_final_state, + simcore_platform_status, + task_errors, + task_completed, + ) = await self._handle_successful_run(task, result, log_error_context) + + elif isinstance(result, ComputationalBackendTaskResultsNotReadyError): + ( + task_final_state, + simcore_platform_status, + task_errors, + task_completed, + ) = await self._handle_computational_retrieval_error( + task, user_id, result, log_error_context + ) + elif isinstance(result, ComputationalBackendNotConnectedError): + ( + task_final_state, + simcore_platform_status, + task_errors, + task_completed, + ) = await self._handle_computational_backend_not_connected_error( + task, result, log_error_context + ) + else: + ( + task_final_state, + simcore_platform_status, + task_errors, + task_completed, + ) = await self._handle_task_error(task, result, log_error_context) + # we need to remove any invalid files in the storage + await clean_task_output_and_log_files_if_invalid( + self.db_engine, user_id, project_id, node_id + ) + + if task_completed: # resource tracking await publish_service_resource_tracking_stopped( self.rabbitmq_client, @@ -412,12 +557,14 @@ async def _process_task_result( task.project_id, run_id, [task.node_id], - task_final_state, - errors=errors, - optional_progress=1, - optional_stopped=arrow.utcnow().datetime, + task_final_state if task_completed else RunningState.STARTED, + errors=task_errors, + optional_progress=1 if task_completed else None, + optional_stopped=arrow.utcnow().datetime if task_completed else None, ) + return task_completed, task.job_id + async def _task_progress_change_handler( self, event: tuple[UnixTimestamp, Any] ) -> None: diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py index fd09e3db015d..17e1697495cf 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py @@ -10,12 +10,10 @@ import asyncio import logging -import traceback from collections.abc import Callable, Iterable -from copy import deepcopy from dataclasses import dataclass from http.client import HTTPException -from typing import Any, Final, cast +from typing import Final, cast import distributed from aiohttp import ClientResponseError @@ -49,7 +47,6 @@ create_ec2_resource_constraint_key, ) from fastapi import FastAPI -from models_library.api_schemas_directorv2.clusters import ClusterDetails, Scheduler from models_library.clusters import ClusterAuthentication, ClusterTypeInModel from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID @@ -57,9 +54,10 @@ from models_library.resource_tracker import HardwareInfo from models_library.services import ServiceRunID from models_library.users import UserID -from pydantic import TypeAdapter, ValidationError +from pydantic import ValidationError from pydantic.networks import AnyUrl -from servicelib.logging_utils import log_catch, log_context +from servicelib.logging_errors import create_troubleshootting_log_kwargs +from servicelib.logging_utils import log_context from settings_library.s3 import S3Settings from simcore_sdk.node_ports_common.exceptions import NodeportsException from simcore_sdk.node_ports_v2 import FileLinkType @@ -448,28 +446,56 @@ async def _get_task_state(job_id: str) -> RunningState: if parsed_event.state == RunningState.FAILED: try: # find out if this was a cancellation - var = distributed.Variable(job_id, client=self.backend.client) - future: distributed.Future = await var.get( + task_future: distributed.Future = ( + await dask_utils.wrap_client_async_routine( + self.backend.client.get_dataset(name=job_id) + ) + ) + exception = await task_future.exception( timeout=_DASK_DEFAULT_TIMEOUT_S ) - exception = await future.exception(timeout=_DASK_DEFAULT_TIMEOUT_S) assert isinstance(exception, Exception) # nosec - + log_error_context = { + "job_id": job_id, + "dask-scheduler": self.backend.scheduler_id, + } if isinstance(exception, TaskCancelledError): + _logger.info( + **create_troubleshootting_log_kwargs( + f"Task {job_id} was aborted by user", + error=exception, + error_context=log_error_context, + ) + ) return RunningState.ABORTED assert exception # nosec - _logger.warning( - "Task %s completed in error:\n%s\nTrace:\n%s", - job_id, - exception, - "".join(traceback.format_exception(exception)), + _logger.info( + **create_troubleshootting_log_kwargs( + f"Task {job_id} completed with an error", + error=exception, + error_context=log_error_context, + ) ) return RunningState.FAILED - except TimeoutError: + except TimeoutError as exc: + _logger.exception( + **create_troubleshootting_log_kwargs( + f"Task {job_id} exception could not be retrieved due to timeout", + error=exc, + error_context=log_error_context, + tip="The dask-scheduler is probably under load, this should resolve itself later.", + ), + ) + return RunningState.UNKNOWN + except KeyError as exc: + # the task does not exist _logger.warning( - "Task %s could not be retrieved from dask-scheduler, it is lost\n" - "TIP:If the task was unpublished this can happen, or if the dask-scheduler was restarted.", - job_id, + **create_troubleshootting_log_kwargs( + f"Task {job_id} not found. State is UNKNOWN.", + error=exc, + error_context=log_error_context, + tip="If the task is supposed to exist, the dask-schdeler has probably restarted. Check its status.", + ), ) return RunningState.UNKNOWN @@ -503,6 +529,8 @@ async def abort_computation_task(self, job_id: str) -> None: async def get_task_result(self, job_id: str) -> TaskOutputData: _logger.debug("getting result of %s", f"{job_id=}") + dask_utils.check_communication_with_scheduler_is_open(self.backend.client) + dask_utils.check_scheduler_status(self.backend.client) try: task_future: distributed.Future = ( await dask_utils.wrap_client_async_routine( @@ -538,50 +566,3 @@ async def release_task_result(self, job_id: str) -> None: except KeyError: _logger.warning("Unknown task cannot be unpublished: %s", f"{job_id=}") - - async def get_cluster_details(self) -> ClusterDetails: - dask_utils.check_scheduler_is_still_the_same( - self.backend.scheduler_id, self.backend.client - ) - dask_utils.check_communication_with_scheduler_is_open(self.backend.client) - dask_utils.check_scheduler_status(self.backend.client) - scheduler_info = self.backend.client.scheduler_info() - scheduler_status = self.backend.client.status - dashboard_link = self.backend.client.dashboard_link - - def _get_worker_used_resources( - dask_scheduler: distributed.Scheduler, - ) -> dict[str, dict]: - used_resources = {} - for worker_name, worker_state in dask_scheduler.workers.items(): - used_resources[worker_name] = worker_state.used_resources - return used_resources - - with log_catch(_logger, reraise=False): - # NOTE: this runs directly on the dask-scheduler and may rise exceptions - used_resources_per_worker: dict[str, dict[str, Any]] = ( - await dask_utils.wrap_client_async_routine( - self.backend.client.run_on_scheduler(_get_worker_used_resources) - ) - ) - - # let's update the scheduler info, with default to 0s since sometimes - # workers are destroyed/created without us knowing right away - for worker_name, worker_info in scheduler_info.get("workers", {}).items(): - used_resources: dict[str, float] = deepcopy( - worker_info.get("resources", {}) - ) - # reset default values - for res_name in used_resources: - used_resources[res_name] = 0 - # if the scheduler has info, let's override them - used_resources = used_resources_per_worker.get( - worker_name, used_resources - ) - worker_info.update(used_resources=used_resources) - - assert dashboard_link # nosec - return ClusterDetails( - scheduler=Scheduler(status=scheduler_status, **scheduler_info), - dashboard_link=TypeAdapter(AnyUrl).validate_python(dashboard_link), - ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py index 2f471360930e..0a28464bb347 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py @@ -386,11 +386,14 @@ async def list_for_user_and_project_all_iterations( total_count = await conn.scalar(count_query) items = [ - ComputationRunRpcGet.model_validate( - { - **row, - "state": DB_TO_RUNNING_STATE[row["state"]], - } + ComputationRunRpcGet( + project_uuid=row.project_uuid, + iteration=row.iteration, + state=DB_TO_RUNNING_STATE[row.state], + info=row.info, + submitted_at=row.submitted_at, + started_at=row.started_at, + ended_at=row.ended_at, ) async for row in await conn.stream(list_query) ] diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs_snapshot_tasks.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs_snapshot_tasks.py index 23db21a26eb8..53bfb47f0d4f 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs_snapshot_tasks.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs_snapshot_tasks.py @@ -14,7 +14,6 @@ ) from ....models.comp_run_snapshot_tasks import CompRunSnapshotTaskDBGet -from ....utils.db import DB_TO_RUNNING_STATE from ..tables import comp_run_snapshot_tasks, comp_runs from ._base import BaseRepository @@ -22,7 +21,6 @@ class CompRunsSnapshotTasksRepository(BaseRepository): - async def batch_create( self, *, data: list[dict] ) -> None: # list[CompRunSnapshotTaskAtDBGet]: @@ -33,7 +31,6 @@ async def batch_create( return async with transaction_context(self.db_engine) as conn: - try: await conn.execute( comp_run_snapshot_tasks.insert().returning( @@ -111,12 +108,7 @@ async def list_computation_collection_run_tasks( total_count = await conn.scalar(count_query) items = [ - CompRunSnapshotTaskDBGet.model_validate( - { - **row, - "state": DB_TO_RUNNING_STATE[row["state"]], # Convert the state - } - ) + CompRunSnapshotTaskDBGet.model_validate(row, from_attributes=True) async for row in await conn.stream(list_query) ] return cast(int, total_count), items diff --git a/services/director-v2/tests/unit/test_models_comp_runs.py b/services/director-v2/tests/unit/test_models_comp_runs.py index 5505982f2d1c..af71c92c9e61 100644 --- a/services/director-v2/tests/unit/test_models_comp_runs.py +++ b/services/director-v2/tests/unit/test_models_comp_runs.py @@ -8,35 +8,33 @@ import pytest from models_library.projects_state import RunningState from pydantic.main import BaseModel +from pytest_simcore.pydantic_models import ( + assert_validation_model, + iter_model_examples_in_class, +) from simcore_service_director_v2.models.comp_runs import CompRunsAtDB @pytest.mark.parametrize( - "model_cls", - [ - CompRunsAtDB, - ], + "model_cls, example_name, example_data", + iter_model_examples_in_class(CompRunsAtDB), ) def test_computation_run_model_examples( - model_cls: type[BaseModel], model_cls_examples: dict[str, dict[str, Any]] + model_cls: type[BaseModel], example_name: str, example_data: dict[str, Any] ): - for name, example in model_cls_examples.items(): - print(name, ":", pformat(example)) - model_instance = model_cls(**example) - assert model_instance, f"Failed with {name}" + assert_validation_model( + model_cls, example_name=example_name, example_data=example_data + ) @pytest.mark.parametrize( - "model_cls", - [ - CompRunsAtDB, - ], + "model_cls, example_name, example_data", + iter_model_examples_in_class(CompRunsAtDB), ) def test_computation_run_model_with_run_result_value_field( - model_cls: type[BaseModel], model_cls_examples: dict[str, dict[str, Any]] + model_cls: type[BaseModel], example_name: str, example_data: dict[str, Any] ): - for name, example in model_cls_examples.items(): - example["result"] = RunningState.WAITING_FOR_RESOURCES.value - print(name, ":", pformat(example)) - model_instance = model_cls(**example) - assert model_instance, f"Failed with {name}" + example_data["result"] = RunningState.WAITING_FOR_RESOURCES.value + print(example_name, ":", pformat(example_data)) + model_instance = model_cls(**example_data) + assert model_instance, f"Failed with {example_name}" diff --git a/services/director-v2/tests/unit/test_models_comp_tasks.py b/services/director-v2/tests/unit/test_models_comp_tasks.py index 6898acface49..f03e103e728b 100644 --- a/services/director-v2/tests/unit/test_models_comp_tasks.py +++ b/services/director-v2/tests/unit/test_models_comp_tasks.py @@ -2,73 +2,58 @@ # pylint:disable=unused-argument # pylint:disable=redefined-outer-name -from pprint import pformat from typing import Any import pytest from models_library.projects_state import RunningState from pydantic.main import BaseModel +from pytest_simcore.pydantic_models import ( + assert_validation_model, + iter_model_examples_in_class, +) from simcore_postgres_database.models.comp_pipeline import StateType from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB @pytest.mark.parametrize( - "model_cls", - (CompTaskAtDB,), + "model_cls, example_name, example_data", + iter_model_examples_in_class(CompTaskAtDB), ) def test_computation_task_model_examples( - model_cls: type[BaseModel], model_cls_examples: dict[str, dict[str, Any]] -): - for name, example in model_cls_examples.items(): - print(name, ":", pformat(example)) - model_instance = model_cls(**example) - assert model_instance, f"Failed with {name}" - - -@pytest.mark.parametrize( - "model_cls", - [CompTaskAtDB], -) -def test_computation_task_model_export_to_db_model( - model_cls: type[BaseModel], model_cls_examples: dict[str, dict[str, Any]] + model_cls: type[BaseModel], example_name: str, example_data: dict[str, Any] ): - for name, example in model_cls_examples.items(): - print(name, ":", pformat(example)) - model_instance = model_cls(**example) - assert model_instance, f"Failed with {name}" + model_instance = assert_validation_model( + model_cls, example_name=example_name, example_data=example_data + ) - assert isinstance(model_instance, CompTaskAtDB) - db_model = model_instance.to_db_model() + assert isinstance(model_instance, CompTaskAtDB) + db_model = model_instance.to_db_model() - assert isinstance(db_model, dict) - assert StateType(db_model["state"]) + assert isinstance(db_model, dict) + assert StateType(db_model["state"]) @pytest.mark.parametrize( - "model_cls", - [CompTaskAtDB], + "model_cls, example_name, example_data", + iter_model_examples_in_class(CompTaskAtDB), ) def test_computation_task_model_with_running_state_value_field( - model_cls: type[BaseModel], model_cls_examples: dict[str, dict[str, Any]] + model_cls: type[BaseModel], example_name: str, example_data: dict[str, Any] ): - for name, example in model_cls_examples.items(): - example["state"] = RunningState.WAITING_FOR_RESOURCES.value - print(name, ":", pformat(example)) - model_instance = model_cls(**example) - assert model_instance, f"Failed with {name}" + example_data["state"] = RunningState.WAITING_FOR_RESOURCES.value + model_instance = model_cls(**example_data) + assert model_instance, f"Failed with {example_name}" @pytest.mark.parametrize( - "model_cls", - [CompTaskAtDB], + "model_cls, example_name, example_data", + iter_model_examples_in_class(CompTaskAtDB), ) def test_computation_task_model_with_wrong_default_value_field( - model_cls: type[BaseModel], model_cls_examples: dict[str, dict[str, Any]] + model_cls: type[BaseModel], example_name: str, example_data: dict[str, Any] ): - for name, example in model_cls_examples.items(): - for output_schema in example.get("schema", {}).get("outputs", {}).values(): - output_schema["defaultValue"] = None + for output_schema in example_data.get("schema", {}).get("outputs", {}).values(): + output_schema["defaultValue"] = None - print(name, ":", pformat(example)) - model_instance = model_cls(**example) - assert model_instance, f"Failed with {name}" + model_instance = model_cls(**example_data) + assert model_instance, f"Failed with {example_name}" diff --git a/services/director-v2/tests/unit/test_modules_dask_client.py b/services/director-v2/tests/unit/test_modules_dask_client.py index 902c98b1e7e6..d9b205b5cba9 100644 --- a/services/director-v2/tests/unit/test_modules_dask_client.py +++ b/services/director-v2/tests/unit/test_modules_dask_client.py @@ -1162,108 +1162,3 @@ def fake_remote_fct( (mock.ANY, "my name is progress") ) await _assert_wait_for_cb_call(mocked_user_completed_cb) - - -async def test_get_cluster_details( - dask_client: DaskClient, - user_id: UserID, - project_id: ProjectID, - image_params: ImageParams, - _mocked_node_ports: None, - mocked_user_completed_cb: mock.AsyncMock, - mocked_storage_service_api: respx.MockRouter, - comp_run_metadata: RunMetadataDict, - empty_hardware_info: HardwareInfo, - faker: Faker, - resource_tracking_run_id: ServiceRunID, -): - cluster_details = await dask_client.get_cluster_details() - assert cluster_details - - _DASK_EVENT_NAME = faker.pystr() - - # send a fct that uses resources - def fake_sidecar_fct( - task_parameters: ContainerTaskParameters, - docker_auth: DockerBasicAuth, - log_file_url: LogFileUploadURL, - s3_settings: S3Settings | None, - expected_annotations, - ) -> TaskOutputData: - # get the task data - worker = get_worker() - task = worker.state.tasks.get(worker.get_current_task()) - assert task is not None - assert task.annotations == expected_annotations - assert task_parameters.command == ["run"] - event = distributed.Event(_DASK_EVENT_NAME) - event.wait(timeout=25) - - return TaskOutputData.model_validate({"some_output_key": 123}) - - # NOTE: We pass another fct so it can run in our localy created dask cluster - published_computation_task = await dask_client.send_computation_tasks( - user_id=user_id, - project_id=project_id, - tasks=image_params.fake_tasks, - callback=mocked_user_completed_cb, - remote_fct=functools.partial( - fake_sidecar_fct, expected_annotations=image_params.expected_annotations - ), - metadata=comp_run_metadata, - hardware_info=empty_hardware_info, - resource_tracking_run_id=resource_tracking_run_id, - ) - assert published_computation_task - assert len(published_computation_task) == 1 - - assert published_computation_task[0].node_id in image_params.fake_tasks - - # check status goes to PENDING/STARTED - await _assert_wait_for_task_status( - published_computation_task[0].job_id, - dask_client, - expected_status=RunningState.STARTED, - ) - - # check we have one worker using the resources - # one of the workers should now get the job and use the resources - worker_with_the_task: AnyUrl | None = None - async for attempt in AsyncRetrying(reraise=True, stop=stop_after_delay(10)): - with attempt: - cluster_details = await dask_client.get_cluster_details() - assert cluster_details - assert ( - cluster_details.scheduler.workers - ), f"there are no workers in {cluster_details.scheduler=!r}" - for worker_url, worker_data in cluster_details.scheduler.workers.items(): - if all( - worker_data.used_resources.get(res_name) == res_value - for res_name, res_value in image_params.expected_used_resources.items() - ): - worker_with_the_task = worker_url - assert ( - worker_with_the_task is not None - ), f"there is no worker in {cluster_details.scheduler.workers.keys()=} consuming {image_params.expected_annotations=!r}" - - # using the event we let the remote fct continue - event = distributed.Event(_DASK_EVENT_NAME, client=dask_client.backend.client) - await event.set() # type: ignore - - # wait for the task to complete - await _assert_wait_for_task_status( - published_computation_task[0].job_id, - dask_client, - expected_status=RunningState.SUCCESS, - ) - - # check the resources are released - cluster_details = await dask_client.get_cluster_details() - assert cluster_details - assert cluster_details.scheduler.workers - assert worker_with_the_task - currently_used_resources = cluster_details.scheduler.workers[ - worker_with_the_task - ].used_resources - - assert all(res == 0.0 for res in currently_used_resources.values()) diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py index 5f78823f2880..b60be5bdcc02 100644 --- a/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py @@ -73,12 +73,24 @@ def with_disabled_scheduler_publisher(mocker: MockerFixture) -> mock.Mock: @pytest.fixture -def with_short_max_wait_for_clusters_keeper( +def with_short_max_wait_for_cluster( monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture ) -> datetime.timedelta: - short_time = datetime.timedelta(seconds=5) + short_time = datetime.timedelta(seconds=2) setenvs_from_dict( monkeypatch, {"COMPUTATIONAL_BACKEND_MAX_WAITING_FOR_CLUSTER_TIMEOUT": f"{short_time}"}, ) return short_time + + +@pytest.fixture +def with_short_max_wait_for_retrieving_results( + monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture +) -> datetime.timedelta: + short_time = datetime.timedelta(seconds=2) + setenvs_from_dict( + monkeypatch, + {"COMPUTATIONAL_BACKEND_MAX_WAITING_FOR_RETRIEVING_RESULTS": f"{short_time}"}, + ) + return short_time diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py index 638d10f88197..aa04cf322ff4 100644 --- a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py @@ -11,6 +11,7 @@ import asyncio import datetime +import random from collections.abc import AsyncIterator, Awaitable, Callable from copy import deepcopy from dataclasses import dataclass @@ -26,13 +27,17 @@ assert_comp_runs_empty, assert_comp_tasks_and_comp_run_snapshot_tasks, ) -from dask_task_models_library.container_tasks.errors import TaskCancelledError +from dask_task_models_library.container_tasks.errors import ( + ServiceRuntimeError, + TaskCancelledError, +) from dask_task_models_library.container_tasks.events import TaskProgressEvent from dask_task_models_library.container_tasks.io import TaskOutputData from dask_task_models_library.container_tasks.protocol import TaskOwner from faker import Faker from fastapi.applications import FastAPI from models_library.computations import CollectionRunID +from models_library.errors import ErrorDict from models_library.projects import ProjectAtDB, ProjectID from models_library.projects_nodes_io import NodeID from models_library.projects_state import RunningState @@ -72,6 +77,8 @@ BaseCompScheduler, ) from simcore_service_director_v2.modules.comp_scheduler._scheduler_dask import ( + _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY, + _TASK_RETRIEVAL_ERROR_TYPE, DaskScheduler, ) from simcore_service_director_v2.modules.comp_scheduler._utils import COMPLETED_STATES @@ -834,7 +841,13 @@ async def _return_2nd_task_failed(job_ids: list[str]) -> list[RunningState]: ] mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_failed - mocked_dask_client.get_task_result.side_effect = None + mocked_dask_client.get_task_result.side_effect = ServiceRuntimeError( + service_key="simcore/services/dynamic/some-service", + service_version="1.0.0", + container_id="some-container-id", + exit_code=1, + service_logs="simulated error", + ) await scheduler_api.apply( user_id=run_in_db.user_id, project_id=run_in_db.project_uuid, @@ -2051,7 +2064,7 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_waits_and_eventually_timesout_fails( with_disabled_auto_scheduling: mock.Mock, with_disabled_scheduler_publisher: mock.Mock, - with_short_max_wait_for_clusters_keeper: datetime.timedelta, + with_short_max_wait_for_cluster: datetime.timedelta, initialized_app: FastAPI, scheduler_api: BaseCompScheduler, sqlalchemy_async_engine: AsyncEngine, @@ -2165,7 +2178,7 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_waits_and expected_progress=None, run_id=run_in_db.run_id, ) - await asyncio.sleep(with_short_max_wait_for_clusters_keeper.total_seconds() + 1) + await asyncio.sleep(with_short_max_wait_for_cluster.total_seconds() + 1) # again will trigger the call again, but now it will start failing, first the task will be mark as FAILED await scheduler_api.apply( user_id=run_in_db.user_id, @@ -2273,3 +2286,156 @@ async def test_run_new_pipeline_called_twice_prevents_duplicate_runs( 0, # No new messages expected ComputationalPipelineStatusMessage.model_validate_json, ) + + +async def test_getting_task_result_raises_exception_does_not_fail_task_and_retries( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + with_short_max_wait_for_retrieving_results: datetime.timedelta, + mocked_dask_client: mock.MagicMock, + initialized_app: FastAPI, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, + running_project: RunningProject, + mocked_parse_output_data_fct: mock.Mock, +): + # this tests the behavior of the scheduling when the dask client cannot retrieve + # the result of a task because of some communication error. In this case the task + # should be processed again in the next iteration and not marked as failed + # immediately. + async def mocked_get_tasks_status(job_ids: list[str]) -> list[RunningState]: + return [RunningState.SUCCESS for j in job_ids] + + mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status + + computational_tasks = [ + t for t in running_project.tasks if t.node_class is NodeClass.COMPUTATIONAL + ] + expected_timeouted_tasks = random.sample( + computational_tasks, k=len(computational_tasks) - 1 + ) + successful_tasks = [ + t for t in computational_tasks if t not in expected_timeouted_tasks + ] + + async def mocked_get_task_result(job_id: str) -> TaskOutputData: + if job_id in [t.job_id for t in successful_tasks]: + return TaskOutputData.model_validate({"whatever_output": 123}) + raise ComputationalBackendTaskResultsNotReadyError(job_id=job_id) + + mocked_dask_client.get_task_result.side_effect = mocked_get_task_result + # calling apply should not raise + assert running_project.project.prj_owner + await scheduler_api.apply( + user_id=running_project.project.prj_owner, + project_id=running_project.project.uuid, + iteration=1, + ) + assert mocked_dask_client.get_task_result.call_count == len(computational_tasks) + mocked_dask_client.get_task_result.reset_mock() + + # check the tasks in the DB, the error shall be set there and the task state is set back to STARTED + comp_tasks, _ = await assert_comp_tasks_and_comp_run_snapshot_tasks( + sqlalchemy_async_engine, + project_uuid=running_project.project.uuid, + task_ids=[t.node_id for t in expected_timeouted_tasks], + expected_state=RunningState.STARTED, + expected_progress=0, + run_id=running_project.runs.run_id, + ) + # we should have an error in all these comp_tasks + retrieval_times = [] + for t in comp_tasks: + assert t.errors + assert len(t.errors) == 1 + error_dict = TypeAdapter(ErrorDict).validate_python(t.errors[0]) + assert error_dict["type"] == _TASK_RETRIEVAL_ERROR_TYPE + assert "ctx" in error_dict + assert _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY in error_dict["ctx"] + retrieval_times.append( + error_dict["ctx"][_TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY] + ) + assert len(retrieval_times) == len(expected_timeouted_tasks) + + await assert_comp_tasks_and_comp_run_snapshot_tasks( + sqlalchemy_async_engine, + project_uuid=running_project.project.uuid, + task_ids=[t.node_id for t in successful_tasks], + expected_state=RunningState.SUCCESS, + expected_progress=1.0, + run_id=running_project.runs.run_id, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == running_project.project.prj_owner, + comp_runs.c.project_uuid == f"{running_project.project.uuid}", + ), + ) + + # calling again should not raise neither but try again + assert running_project.project.prj_owner + for _ in range(3): + await scheduler_api.apply( + user_id=running_project.project.prj_owner, + project_id=running_project.project.uuid, + iteration=1, + ) + assert mocked_dask_client.get_task_result.call_count == ( + len(expected_timeouted_tasks) + ) + mocked_dask_client.get_task_result.reset_mock() + + comp_tasks, _ = await assert_comp_tasks_and_comp_run_snapshot_tasks( + sqlalchemy_async_engine, + project_uuid=running_project.project.uuid, + task_ids=[t.node_id for t in expected_timeouted_tasks], + expected_state=RunningState.STARTED, + expected_progress=0, + run_id=running_project.runs.run_id, + ) + # the times shall remain the same + for t in comp_tasks: + assert t.errors + assert len(t.errors) == 1 + error_dict = TypeAdapter(ErrorDict).validate_python(t.errors[0]) + assert error_dict["type"] == _TASK_RETRIEVAL_ERROR_TYPE + assert "ctx" in error_dict + assert _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY in error_dict["ctx"] + # the time shall be the same as before + assert ( + error_dict["ctx"][_TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY] in retrieval_times + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == running_project.project.prj_owner, + comp_runs.c.project_uuid == f"{running_project.project.uuid}", + ), + ) + + # now we wait for the max time and the task should be marked as FAILED + await asyncio.sleep(with_short_max_wait_for_retrieving_results.total_seconds() + 1) + await scheduler_api.apply( + user_id=running_project.project.prj_owner, + project_id=running_project.project.uuid, + iteration=1, + ) + assert mocked_dask_client.get_task_result.call_count == len( + expected_timeouted_tasks + ) + # NOTE: we do not check all tasks here as some are depending on random others + # so some are ABORTED and others are FAILED depending on the random sample above + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.FAILED, + where_statement=and_( + comp_runs.c.user_id == running_project.project.prj_owner, + comp_runs.c.project_uuid == f"{running_project.project.uuid}", + ), + )