diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py index 28dca04dc536..4fc5c1831ef6 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Literal from models_library.projects import ProjectID @@ -5,15 +6,22 @@ from models_library.users import UserID from ...models.comp_runs import Iteration +from ...models.comp_tasks import CompTaskAtDB class SchedulePipelineRabbitMessage(RabbitMessageBase): - channel_name: Literal[ + channel_name: Literal["simcore.services.director-v2.scheduling"] = ( "simcore.services.director-v2.scheduling" - ] = "simcore.services.director-v2.scheduling" + ) user_id: UserID project_id: ProjectID iteration: Iteration def routing_key(self) -> str | None: # pylint: disable=no-self-use # abstract return None + + +@dataclass(frozen=True, slots=True) +class TaskStateTracker: + previous: CompTaskAtDB + current: CompTaskAtDB diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py index 0cdf52685f08..17a7b8a09afd 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py @@ -63,6 +63,7 @@ from ..db.repositories.comp_pipelines import CompPipelinesRepository from ..db.repositories.comp_runs import CompRunsRepository from ..db.repositories.comp_tasks import CompTasksRepository +from ._models import TaskStateTracker from ._publisher import request_pipeline_scheduling from ._utils import ( COMPLETED_STATES, @@ -76,9 +77,6 @@ _logger = logging.getLogger(__name__) -_Previous = CompTaskAtDB -_Current = CompTaskAtDB - _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta( seconds=30 ) @@ -117,47 +115,49 @@ async def _async_cb() -> None: @dataclass(frozen=True, slots=True) class SortedTasks: started: list[CompTaskAtDB] - completed: list[CompTaskAtDB] - waiting: list[CompTaskAtDB] - potentially_lost: list[CompTaskAtDB] + completed: list[TaskStateTracker] + waiting: list[TaskStateTracker] + potentially_lost: list[TaskStateTracker] async def _triage_changed_tasks( - changed_tasks: list[tuple[_Previous, _Current]], + changed_tasks: list[TaskStateTracker], ) -> SortedTasks: started_tasks = [ - current - for previous, current in changed_tasks - if current.state in RUNNING_STATES + tracker.current + for tracker in changed_tasks + if tracker.current.state in RUNNING_STATES or ( - previous.state in WAITING_FOR_START_STATES - and current.state in COMPLETED_STATES + tracker.previous.state in WAITING_FOR_START_STATES + and tracker.current.state in COMPLETED_STATES ) ] completed_tasks = [ - current for _, current in changed_tasks if current.state in COMPLETED_STATES + tracker + for tracker in changed_tasks + if tracker.current.state in COMPLETED_STATES ] waiting_for_resources_tasks = [ - current - for previous, current in changed_tasks - if current.state in WAITING_FOR_START_STATES + tracker + for tracker in changed_tasks + if tracker.current.state in WAITING_FOR_START_STATES ] lost_tasks = [ - current - for previous, current in changed_tasks - if (current.state is RunningState.UNKNOWN) + tracker + for tracker in changed_tasks + if (tracker.current.state is RunningState.UNKNOWN) and ( - (arrow.utcnow().datetime - previous.modified) + (arrow.utcnow().datetime - tracker.previous.modified) > _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS ) ] if lost_tasks: _logger.warning( "%s are currently in unknown state. TIP: If they are running in an external cluster and it is not yet ready, that might explain it. But inform @sanderegg nevertheless!", - [t.node_id for t in lost_tasks], + [t.current.node_id for t in lost_tasks], ) return SortedTasks( @@ -321,6 +321,7 @@ async def _send_running_tasks_heartbeat( def _need_heartbeat(task: CompTaskAtDB) -> bool: if task.state not in RUNNING_STATES: return False + if task.last_heartbeat is None: assert task.start # nosec return bool( @@ -362,14 +363,14 @@ async def _get_changed_tasks_from_backend( user_id: UserID, processing_tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB, - ) -> tuple[list[tuple[_Previous, _Current]], list[CompTaskAtDB]]: + ) -> tuple[list[TaskStateTracker], list[CompTaskAtDB]]: tasks_backend_status = await self._get_tasks_status( user_id, processing_tasks, comp_run ) return ( [ - ( + TaskStateTracker( task, task.model_copy(update={"state": backend_state}), ) @@ -502,16 +503,16 @@ async def _process_started_tasks( ) async def _process_waiting_tasks( - self, tasks: list[CompTaskAtDB], run_id: PositiveInt + self, tasks: list[TaskStateTracker], run_id: PositiveInt ) -> None: comp_tasks_repo = CompTasksRepository(self.db_engine) await asyncio.gather( *( comp_tasks_repo.update_project_tasks_state( - t.project_id, + t.current.project_id, run_id, - [t.node_id], - t.state, + [t.current.node_id], + t.current.state, ) for t in tasks ) @@ -602,7 +603,7 @@ async def _stop_tasks( async def _process_completed_tasks( self, user_id: UserID, - tasks: list[CompTaskAtDB], + tasks: list[TaskStateTracker], iteration: Iteration, comp_run: CompRunsAtDB, ) -> None: diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py index 2638271d1a0f..193b44eb871d 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import logging from collections.abc import AsyncIterator, Callable @@ -25,7 +24,7 @@ from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE from servicelib.logging_errors import create_troubleshootting_log_kwargs from servicelib.logging_utils import log_catch, log_context -from servicelib.utils import limited_as_completed +from servicelib.utils import limited_as_completed, limited_gather from ...core.errors import ( ComputationalBackendNotConnectedError, @@ -55,6 +54,7 @@ from ._constants import ( MAX_CONCURRENT_PIPELINE_SCHEDULING, ) +from ._models import TaskStateTracker from ._scheduler_base import BaseCompScheduler from ._utils import ( WAITING_FOR_START_STATES, @@ -131,7 +131,7 @@ async def _start_tasks( RunningState.PENDING, ) # each task is started independently - results: list[list[PublishedComputationTask]] = await asyncio.gather( + results: list[list[PublishedComputationTask]] = await limited_gather( *( client.send_computation_tasks( user_id=user_id, @@ -146,17 +146,21 @@ async def _start_tasks( ) for node_id, task in scheduled_tasks.items() ), + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, ) # update the database so we do have the correct job_ids there - await asyncio.gather( + await limited_gather( *( comp_tasks_repo.update_project_task_job_id( project_id, task.node_id, comp_run.run_id, task.job_id ) for task_sents in results for task in task_sents - ) + ), + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, ) async def _get_tasks_status( @@ -196,25 +200,32 @@ async def _process_executing_tasks( run_id=comp_run.run_id, run_metadata=comp_run.metadata, ) as client: - task_progresses = await client.get_tasks_progress( - [f"{t.job_id}" for t in tasks], - ) - for task_progress_event in task_progresses: - if task_progress_event: - await CompTasksRepository( - self.db_engine - ).update_project_task_progress( - task_progress_event.task_owner.project_id, - task_progress_event.task_owner.node_id, + task_progresses = [ + t + for t in await client.get_tasks_progress( + [f"{t.job_id}" for t in tasks], + ) + if t is not None + ] + await limited_gather( + *( + CompTasksRepository(self.db_engine).update_project_task_progress( + t.task_owner.project_id, + t.task_owner.node_id, comp_run.run_id, - task_progress_event.progress, + t.progress, ) + for t in task_progresses + ), + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, + ) except ComputationalBackendOnDemandNotReadyError: _logger.info("The on demand computational backend is not ready yet...") comp_tasks_repo = CompTasksRepository(self.db_engine) - await asyncio.gather( + await limited_gather( *( comp_tasks_repo.update_project_task_progress( t.task_owner.project_id, @@ -224,9 +235,7 @@ async def _process_executing_tasks( ) for t in task_progresses if t - ) - ) - await asyncio.gather( + ), *( publish_service_progress( self.rabbitmq_client, @@ -237,7 +246,9 @@ async def _process_executing_tasks( ) for t in task_progresses if t - ) + ), + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, ) async def _release_resources(self, comp_run: CompRunsAtDB) -> None: @@ -271,29 +282,30 @@ async def _stop_tasks( run_id=comp_run.run_id, run_metadata=comp_run.metadata, ) as client: - await asyncio.gather( - *[ + await limited_gather( + *( client.abort_computation_task(t.job_id) for t in tasks if t.job_id - ] + ), + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, ) # tasks that have no-worker must be unpublished as these are blocking forever - tasks_with_no_worker = [ - t for t in tasks if t.state is RunningState.WAITING_FOR_RESOURCES - ] - await asyncio.gather( - *[ + await limited_gather( + *( client.release_task_result(t.job_id) - for t in tasks_with_no_worker - if t.job_id - ] + for t in tasks + if t.state is RunningState.WAITING_FOR_RESOURCES and t.job_id + ), + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, ) async def _process_completed_tasks( self, user_id: UserID, - tasks: list[CompTaskAtDB], + tasks: list[TaskStateTracker], iteration: Iteration, comp_run: CompRunsAtDB, ) -> None: @@ -305,14 +317,23 @@ async def _process_completed_tasks( run_id=comp_run.run_id, run_metadata=comp_run.metadata, ) as client: - tasks_results = await asyncio.gather( - *[client.get_task_result(t.job_id or "undefined") for t in tasks], - return_exceptions=True, + tasks_results = await limited_gather( + *( + client.get_task_result(t.current.job_id or "undefined") + for t in tasks + ), + reraise=False, + log=_logger, + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, ) async for future in limited_as_completed( ( self._process_task_result( - task, result, comp_run.metadata, iteration, comp_run.run_id + task, + result, + comp_run.metadata, + iteration, + comp_run.run_id, ) for task, result in zip(tasks, tasks_results, strict=True) ), @@ -467,7 +488,7 @@ async def _handle_task_error( async def _process_task_result( self, - task: CompTaskAtDB, + task: TaskStateTracker, result: BaseException | TaskOutputData, run_metadata: RunMetadataDict, iteration: Iteration, @@ -475,23 +496,22 @@ async def _process_task_result( ) -> tuple[bool, str]: """Returns True and the job ID if the task was successfully processed and can be released from the Dask cluster.""" _logger.debug("received %s result: %s", f"{task=}", f"{result=}") - - assert task.job_id # nosec + assert task.current.job_id # nosec ( _service_key, _service_version, user_id, project_id, node_id, - ) = parse_dask_job_id(task.job_id) + ) = parse_dask_job_id(task.current.job_id) - assert task.project_id == project_id # nosec - assert task.node_id == node_id # nosec + assert task.current.project_id == project_id # nosec + assert task.current.node_id == node_id # nosec log_error_context = { "user_id": user_id, "project_id": project_id, "node_id": node_id, - "job_id": task.job_id, + "job_id": task.current.job_id, } if isinstance(result, TaskOutputData): @@ -500,7 +520,9 @@ async def _process_task_result( simcore_platform_status, task_errors, task_completed, - ) = await self._handle_successful_run(task, result, log_error_context) + ) = await self._handle_successful_run( + task.current, result, log_error_context + ) elif isinstance(result, ComputationalBackendTaskResultsNotReadyError): ( @@ -509,7 +531,7 @@ async def _process_task_result( task_errors, task_completed, ) = await self._handle_computational_retrieval_error( - task, user_id, result, log_error_context + task.current, user_id, result, log_error_context ) elif isinstance(result, ComputationalBackendNotConnectedError): ( @@ -518,7 +540,7 @@ async def _process_task_result( task_errors, task_completed, ) = await self._handle_computational_backend_not_connected_error( - task, result, log_error_context + task.current, result, log_error_context ) else: ( @@ -526,7 +548,7 @@ async def _process_task_result( simcore_platform_status, task_errors, task_completed, - ) = await self._handle_task_error(task, result, log_error_context) + ) = await self._handle_task_error(task.current, result, log_error_context) # we need to remove any invalid files in the storage await clean_task_output_and_log_files_if_invalid( @@ -549,21 +571,21 @@ async def _process_task_result( simcore_user_agent=run_metadata.get( "simcore_user_agent", UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE ), - task=task, + task=task.current, task_final_state=task_final_state, ) await CompTasksRepository(self.db_engine).update_project_tasks_state( - task.project_id, + task.current.project_id, run_id, - [task.node_id], - task_final_state if task_completed else RunningState.STARTED, + [task.current.node_id], + task_final_state if task_completed else task.previous.state, errors=task_errors, optional_progress=1 if task_completed else None, optional_stopped=arrow.utcnow().datetime if task_completed else None, ) - return task_completed, task.job_id + return task_completed, task.current.job_id async def _task_progress_change_handler( self, event: tuple[UnixTimestamp, Any] diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py index 17e1697495cf..acbfa239d30f 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py @@ -58,6 +58,7 @@ from pydantic.networks import AnyUrl from servicelib.logging_errors import create_troubleshootting_log_kwargs from servicelib.logging_utils import log_context +from servicelib.utils import limited_gather from settings_library.s3 import S3Settings from simcore_sdk.node_ports_common.exceptions import NodeportsException from simcore_sdk.node_ports_v2 import FileLinkType @@ -88,10 +89,11 @@ _logger = logging.getLogger(__name__) -_DASK_DEFAULT_TIMEOUT_S: Final[int] = 35 +_DASK_DEFAULT_TIMEOUT_S: Final[int] = 10 _UserCallbackInSepThread = Callable[[], None] +_MAX_CONCURRENT_CLIENT_CONNECTIONS: Final[int] = 10 @dataclass(frozen=True, kw_only=True, slots=True) @@ -423,7 +425,11 @@ async def _get_task_progress(job_id: str) -> TaskProgressEvent | None: # we are interested in the last event return TaskProgressEvent.model_validate_json(dask_events[-1][1]) - return await asyncio.gather(*(_get_task_progress(job_id) for job_id in job_ids)) + return await limited_gather( + *(_get_task_progress(job_id) for job_id in job_ids), + log=_logger, + limit=_MAX_CONCURRENT_CLIENT_CONNECTIONS, + ) async def get_tasks_status(self, job_ids: Iterable[str]) -> list[RunningState]: dask_utils.check_scheduler_is_still_the_same( @@ -501,7 +507,11 @@ async def _get_task_state(job_id: str) -> RunningState: return parsed_event.state - return await asyncio.gather(*(_get_task_state(job_id) for job_id in job_ids)) + return await limited_gather( + *(_get_task_state(job_id) for job_id in job_ids), + log=_logger, + limit=_MAX_CONCURRENT_CLIENT_CONNECTIONS, + ) async def abort_computation_task(self, job_id: str) -> None: # Dask future may be cancelled, but only a future that was not already taken by