diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py
index 1491d2ef781c..97b1bc1cf22c 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py
@@ -35,6 +35,7 @@
 from servicelib.logging_utils import log_catch, log_context
 from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient
 from servicelib.redis import RedisClientSDK
+from servicelib.utils import limited_gather
 from sqlalchemy.ext.asyncio import AsyncEngine
 
 from ...constants import UNDEFINED_STR_METADATA
@@ -79,6 +80,7 @@
 _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta(
     seconds=30
 )
+_PUBLICATION_CONCURRENCY_LIMIT: Final[int] = 10
 
 
 def _auto_schedule_callback(
@@ -336,7 +338,7 @@ def _need_heartbeat(task: CompTaskAtDB) -> bool:
             project_id, dag
         )
         if running_tasks := [t for t in tasks.values() if _need_heartbeat(t)]:
-            await asyncio.gather(
+            await limited_gather(
                 *(
                     publish_service_resource_tracking_heartbeat(
                         self.rabbitmq_client,
@@ -345,17 +347,15 @@ def _need_heartbeat(task: CompTaskAtDB) -> bool:
                         ),
                     )
                     for t in running_tasks
-                )
+                ),
+                log=_logger,
+                limit=_PUBLICATION_CONCURRENCY_LIMIT,
             )
-            comp_tasks_repo = CompTasksRepository(self.db_engine)
-            await asyncio.gather(
-                *(
-                    comp_tasks_repo.update_project_task_last_heartbeat(
-                        t.project_id, t.node_id, run_id, utc_now
-                    )
-                    for t in running_tasks
+            comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
+            for task in running_tasks:
+                await comp_tasks_repo.update_project_task_last_heartbeat(
+                    project_id, task.node_id, run_id, utc_now
                 )
-            )
 
     async def _get_changed_tasks_from_backend(
         self,
@@ -400,7 +400,7 @@ async def _process_started_tasks(
         utc_now = arrow.utcnow().datetime
 
         # resource tracking
-        await asyncio.gather(
+        await limited_gather(
             *(
                 publish_service_resource_tracking_started(
                     self.rabbitmq_client,
@@ -462,10 +462,12 @@ async def _process_started_tasks(
                     service_additional_metadata={},
                 )
                 for t in tasks
-            )
+            ),
+            log=_logger,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )
         # instrumentation
-        await asyncio.gather(
+        await limited_gather(
             *(
                 publish_service_started_metrics(
                     self.rabbitmq_client,
@@ -476,24 +478,22 @@ async def _process_started_tasks(
                     task=t,
                 )
                 for t in tasks
-            )
+            ),
+            log=_logger,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )
 
         # update DB
         comp_tasks_repo = CompTasksRepository(self.db_engine)
-        await asyncio.gather(
-            *(
-                comp_tasks_repo.update_project_tasks_state(
-                    t.project_id,
-                    run_id,
-                    [t.node_id],
-                    t.state,
-                    optional_started=utc_now,
-                    optional_progress=t.progress,
-                )
-                for t in tasks
+        for task in tasks:
+            await comp_tasks_repo.update_project_tasks_state(
+                project_id,
+                run_id,
+                [task.node_id],
+                task.state,
+                optional_started=utc_now,
+                optional_progress=task.progress,
             )
-        )
         await CompRunsRepository.instance(self.db_engine).mark_as_started(
             user_id=user_id,
             project_id=project_id,
@@ -504,18 +504,14 @@
     async def _process_waiting_tasks(
         self, tasks: list[TaskStateTracker], run_id: PositiveInt
     ) -> None:
-        comp_tasks_repo = CompTasksRepository(self.db_engine)
-        await asyncio.gather(
-            *(
-                comp_tasks_repo.update_project_tasks_state(
-                    t.current.project_id,
-                    run_id,
-                    [t.current.node_id],
-                    t.current.state,
-                )
-                for t in tasks
+        comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
+        for task in tasks:
+            await comp_tasks_repo.update_project_tasks_state(
+                task.current.project_id,
+                run_id,
+                [task.current.node_id],
+                task.current.state,
             )
-        )
 
     async def _update_states_from_comp_backend(
         self,
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
index 3cb3d993fd9e..867b59625cbf 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
@@ -48,15 +48,12 @@
     publish_service_stopped_metrics,
 )
 from ..clusters_keeper import get_or_create_on_demand_cluster
-from ..dask_client import DaskClient, PublishedComputationTask
+from ..dask_client import DaskClient
 from ..dask_clients_pool import DaskClientsPool
 from ..db.repositories.comp_runs import (
     CompRunsRepository,
 )
 from ..db.repositories.comp_tasks import CompTasksRepository
-from ._constants import (
-    MAX_CONCURRENT_PIPELINE_SCHEDULING,
-)
 from ._models import TaskStateTracker
 from ._scheduler_base import BaseCompScheduler
 from ._utils import (
@@ -68,6 +65,7 @@
 _DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{project_id}:{run_id}"
 _TASK_RETRIEVAL_ERROR_TYPE: Final[str] = "task-result-retrieval-timeout"
 _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time"
+_PUBLICATION_CONCURRENCY_LIMIT: Final[int] = 10
 
 
 @asynccontextmanager
@@ -149,37 +147,31 @@ async def _start_tasks(
                 RunningState.PENDING,
             )
 
             # each task is started independently
-            results: list[list[PublishedComputationTask]] = await limited_gather(
-                *(
-                    client.send_computation_tasks(
-                        user_id=user_id,
-                        project_id=project_id,
-                        tasks={node_id: task.image},
-                        hardware_info=task.hardware_info,
-                        callback=wake_up_callback,
-                        metadata=comp_run.metadata,
-                        resource_tracking_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
-                            user_id, project_id, node_id, comp_run.iteration
-                        ),
-                    )
-                    for node_id, task in scheduled_tasks.items()
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
-            # update the database so we do have the correct job_ids there
-            await limited_gather(
-                *(
-                    comp_tasks_repo.update_project_task_job_id(
-                        project_id, task.node_id, comp_run.run_id, task.job_id
-                    )
-                    for task_sents in results
-                    for task in task_sents
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
+            for node_id, task in scheduled_tasks.items():
+                published_tasks = await client.send_computation_tasks(
+                    user_id=user_id,
+                    project_id=project_id,
+                    tasks={node_id: task.image},
+                    hardware_info=task.hardware_info,
+                    callback=wake_up_callback,
+                    metadata=comp_run.metadata,
+                    resource_tracking_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
+                        user_id, project_id, node_id, comp_run.iteration
+                    ),
+                )
+
+                # update the database so we do have the correct job_ids there
+                await limited_gather(
+                    *(
+                        comp_tasks_repo.update_project_task_job_id(
+                            project_id, task.node_id, comp_run.run_id, task.job_id
+                        )
+                        for task in published_tasks
+                    ),
+                    log=_logger,
+                    limit=1,
+                )
 
     async def _get_tasks_status(
         self,
@@ -208,7 +200,7 @@ async def _process_executing_tasks(
         tasks: list[CompTaskAtDB],
         comp_run: CompRunsAtDB,
     ) -> None:
-        task_progresses = []
+        task_progress_events = []
         try:
             async with _cluster_dask_client(
                 user_id,
@@ -218,42 +210,33 @@
                 run_id=comp_run.run_id,
                 run_metadata=comp_run.metadata,
             ) as client:
-                task_progresses = [
+                task_progress_events = [
                     t
                     for t in await client.get_tasks_progress(
                         [f"{t.job_id}" for t in tasks],
                     )
                     if t is not None
                 ]
-            await limited_gather(
-                *(
-                    CompTasksRepository(self.db_engine).update_project_task_progress(
-                        t.task_owner.project_id,
-                        t.task_owner.node_id,
-                        comp_run.run_id,
-                        t.progress,
-                    )
-                    for t in task_progresses
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
+            for progress_event in task_progress_events:
+                await CompTasksRepository(self.db_engine).update_project_task_progress(
+                    progress_event.task_owner.project_id,
+                    progress_event.task_owner.node_id,
+                    comp_run.run_id,
+                    progress_event.progress,
+                )
         except ComputationalBackendOnDemandNotReadyError:
             _logger.info("The on demand computational backend is not ready yet...")
 
         comp_tasks_repo = CompTasksRepository(self.db_engine)
+        for task in task_progress_events:
+            await comp_tasks_repo.update_project_task_progress(
+                task.task_owner.project_id,
+                task.task_owner.node_id,
+                comp_run.run_id,
+                task.progress,
+            )
         await limited_gather(
-            *(
-                comp_tasks_repo.update_project_task_progress(
-                    t.task_owner.project_id,
-                    t.task_owner.node_id,
-                    comp_run.run_id,
-                    t.progress,
-                )
-                for t in task_progresses
-                if t
-            ),
             *(
                 publish_service_progress(
                     self.rabbitmq_client,
@@ -262,11 +245,10 @@
                     node_id=t.task_owner.node_id,
                     progress=t.progress,
                 )
-                for t in task_progresses
-                if t
+                for t in task_progress_events
             ),
             log=_logger,
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )
 
     async def _release_resources(self, comp_run: CompRunsAtDB) -> None:
@@ -300,25 +282,14 @@
             run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
-            await limited_gather(
-                *(
-                    client.abort_computation_task(t.job_id)
-                    for t in tasks
-                    if t.job_id
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
-            # tasks that have no-worker must be unpublished as these are blocking forever
-            await limited_gather(
-                *(
-                    client.release_task_result(t.job_id)
-                    for t in tasks
-                    if t.state is RunningState.WAITING_FOR_RESOURCES and t.job_id
-                ),
-                log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-            )
+            for t in tasks:
+                if not t.job_id:
+                    _logger.warning("%s has no job_id, cannot be stopped", t)
+                    continue
+                await client.abort_computation_task(t.job_id)
+                # tasks that have no-worker must be unpublished as these are blocking forever
+                if t.state is RunningState.WAITING_FOR_RESOURCES:
+                    await client.release_task_result(t.job_id)
 
     async def _process_completed_tasks(
         self,
@@ -342,7 +313,7 @@
                 ),
                 reraise=False,
                 log=_logger,
-                limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+                limit=1,  # to avoid overloading the dask scheduler
             )
         async for future in limited_as_completed(
             (
@@ -354,7 +325,7 @@
                 )
                 for task, result in zip(tasks, tasks_results, strict=True)
             ),
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=10,  # this is not accessing the dask-scheduler (only db)
         ):
             with log_catch(_logger, reraise=False):
                 task_can_be_cleaned, job_id = await future
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py
index 9b7b7c17b18b..84195c18d462 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py
@@ -8,7 +8,6 @@
 """
 
-import asyncio
 import logging
 from collections.abc import Callable, Iterable
 from dataclasses import dataclass
@@ -241,9 +240,6 @@ def _comp_sidecar_fct(
             )
             # NOTE: the callback is running in a secondary thread, and takes a future as arg
             task_future.add_done_callback(lambda _: callback())
-            await distributed.Variable(job_id, client=self.backend.client).set(
-                task_future
-            )
 
             await dask_utils.wrap_client_async_routine(
                 self.backend.client.publish_dataset(task_future, name=job_id)
@@ -560,12 +556,6 @@ async def get_task_result(self, job_id: str) -> TaskOutputData:
     async def release_task_result(self, job_id: str) -> None:
         _logger.debug("releasing results for %s", f"{job_id=}")
         try:
-            # NOTE: The distributed Variable holds the future of the tasks in the dask-scheduler
-            # Alas, deleting the variable is done asynchronously and there is no way to ensure
-            # the variable was effectively deleted.
-            # This is annoying as one can re-create the variable without error.
-            var = distributed.Variable(job_id, client=self.backend.client)
-            await asyncio.get_event_loop().run_in_executor(None, var.delete)
             # first check if the key exists
             await dask_utils.wrap_client_async_routine(
                 self.backend.client.get_dataset(name=job_id)