diff --git a/packages/service-library/src/servicelib/redis_utils.py b/packages/service-library/src/servicelib/redis_utils.py index 10f32ae59443..559349cbb0db 100644 --- a/packages/service-library/src/servicelib/redis_utils.py +++ b/packages/service-library/src/servicelib/redis_utils.py @@ -3,7 +3,7 @@ import logging from collections.abc import Awaitable, Callable from datetime import timedelta -from typing import Any +from typing import Any, ParamSpec, TypeVar import arrow @@ -12,10 +12,16 @@ _logger = logging.getLogger(__file__) +P = ParamSpec("P") +R = TypeVar("R") + def exclusive( - redis: RedisClientSDK, *, lock_key: str, lock_value: bytes | str | None = None -): + redis: RedisClientSDK | Callable[..., RedisClientSDK], + *, + lock_key: str | Callable[..., str], + lock_value: bytes | str | None = None, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: """ Define a method to run exclusively across processes by leveraging a Redis Lock. @@ -24,12 +30,30 @@ def exclusive( redis: the redis client SDK lock_key: a string as the name of the lock (good practice: app_name:lock_name) lock_value: some additional data that can be retrieved by another client + + Raises: + - ValueError if used incorrectly + - CouldNotAcquireLockError if the lock could not be acquired """ - def decorator(func): + if not lock_key: + msg = "lock_key cannot be empty string!" + raise ValueError(msg) + + def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: @functools.wraps(func) - async def wrapper(*args, **kwargs): - async with redis.lock_context(lock_key=lock_key, lock_value=lock_value): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + redis_lock_key = ( + lock_key(*args, **kwargs) if callable(lock_key) else lock_key + ) + assert isinstance(redis_lock_key, str) # nosec + + redis_client = redis(*args, **kwargs) if callable(redis) else redis + assert isinstance(redis_client, RedisClientSDK) # nosec + + async with redis_client.lock_context( + lock_key=redis_lock_key, lock_value=lock_value + ): return await func(*args, **kwargs) return wrapper diff --git a/packages/service-library/tests/test_redis_utils.py b/packages/service-library/tests/test_redis_utils.py index f897fc7c3998..26f749cd894c 100644 --- a/packages/service-library/tests/test_redis_utils.py +++ b/packages/service-library/tests/test_redis_utils.py @@ -5,6 +5,7 @@ from contextlib import AbstractAsyncContextManager from datetime import timedelta from itertools import chain +from typing import Awaitable from unittest.mock import Mock import arrow @@ -32,39 +33,117 @@ async def _is_locked(redis_client_sdk: RedisClientSDK, lock_name: str) -> bool: @pytest.fixture def lock_name(faker: Faker) -> str: - return faker.uuid4() # type: ignore + return faker.pystr() + + +def _exclusive_sleeping_task( + redis_client_sdk: RedisClientSDK | Callable[..., RedisClientSDK], + lock_name: str | Callable[..., str], + sleep_duration: float, +) -> Callable[..., Awaitable[float]]: + @exclusive(redis_client_sdk, lock_key=lock_name) + async def _() -> float: + resolved_client = ( + redis_client_sdk() if callable(redis_client_sdk) else redis_client_sdk + ) + resolved_lock_name = lock_name() if callable(lock_name) else lock_name + assert await _is_locked(resolved_client, resolved_lock_name) + await asyncio.sleep(sleep_duration) + assert await _is_locked(resolved_client, resolved_lock_name) + return sleep_duration + + return _ + + +@pytest.fixture +def sleep_duration(faker: Faker) -> float: + return faker.pyfloat(positive=True, 
min_value=0.2, max_value=0.8) -async def _contained_client( +async def test_exclusive_decorator( get_redis_client_sdk: Callable[ [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] ], lock_name: str, - task_duration: float, -) -> None: - async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk: - assert not await _is_locked(redis_client_sdk, lock_name) - - @exclusive(redis_client_sdk, lock_key=lock_name) - async def _some_task() -> None: - assert await _is_locked(redis_client_sdk, lock_name) - await asyncio.sleep(task_duration) - assert await _is_locked(redis_client_sdk, lock_name) - - await _some_task() + sleep_duration: float, +): - assert not await _is_locked(redis_client_sdk, lock_name) + async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client: + for _ in range(3): + assert ( + await _exclusive_sleeping_task( + redis_client, lock_name, sleep_duration + )() + == sleep_duration + ) -@pytest.mark.parametrize("task_duration", [0.1, 1, 2]) -async def test_exclusive_sequentially( +async def test_exclusive_decorator_with_key_builder( get_redis_client_sdk: Callable[ [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] ], lock_name: str, - task_duration: float, + sleep_duration: float, ): - await _contained_client(get_redis_client_sdk, lock_name, task_duration) + def _get_lock_name(*args, **kwargs) -> str: + assert args is not None + assert kwargs is not None + return lock_name + + async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client: + for _ in range(3): + assert ( + await _exclusive_sleeping_task( + redis_client, _get_lock_name, sleep_duration + )() + == sleep_duration + ) + + +async def test_exclusive_decorator_with_client_builder( + get_redis_client_sdk: Callable[ + [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] + ], + lock_name: str, + sleep_duration: float, +): + async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client: + + def _get_redis_client_builder(*args, **kwargs) -> RedisClientSDK: + assert args is not None + assert kwargs is not None + return redis_client + + for _ in range(3): + assert ( + await _exclusive_sleeping_task( + _get_redis_client_builder, lock_name, sleep_duration + )() + == sleep_duration + ) + + +async def _acquire_lock_and_exclusively_sleep( + get_redis_client_sdk: Callable[ + [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] + ], + lock_name: str | Callable[..., str], + sleep_duration: float, +) -> None: + async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk: + redis_lock_name = lock_name() if callable(lock_name) else lock_name + assert not await _is_locked(redis_client_sdk, redis_lock_name) + + @exclusive(redis_client_sdk, lock_key=lock_name) + async def _() -> float: + assert await _is_locked(redis_client_sdk, redis_lock_name) + await asyncio.sleep(sleep_duration) + assert await _is_locked(redis_client_sdk, redis_lock_name) + return sleep_duration + + assert await _() == sleep_duration + + assert not await _is_locked(redis_client_sdk, redis_lock_name) async def test_exclusive_parallel_lock_is_released_and_reacquired( @@ -76,17 +155,19 @@ async def test_exclusive_parallel_lock_is_released_and_reacquired( parallel_tasks = 10 results = await logged_gather( *[ - _contained_client(get_redis_client_sdk, lock_name, task_duration=0.1) + _acquire_lock_and_exclusively_sleep( + get_redis_client_sdk, lock_name, sleep_duration=0.1 + ) for _ in range(parallel_tasks) ], - reraise=False + reraise=False, ) assert results.count(None) == 1 
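# --- editor's note: illustrative sketch, not part of the patch ---------------
# The reworked `exclusive` decorator (see redis_utils.py above) now accepts a
# callable for both the Redis client and the lock key; the callable is invoked
# with the decorated function's own (*args, **kwargs). A minimal usage sketch,
# assuming `redis_client` is an existing RedisClientSDK and `sync_user` is a
# hypothetical coroutine:
#
#   from servicelib.redis_utils import exclusive
#
#   @exclusive(redis_client, lock_key=lambda user_id: f"app:sync:{user_id}")
#   async def sync_user(user_id: int) -> None:
#       ...  # body runs under the per-user Redis lock
#
# A second concurrent call with the same user_id raises
# CouldNotAcquireLockError instead of entering the body, which is exactly what
# the parallel test around this note asserts.
# ------------------------------------------------------------------------------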
assert [isinstance(x, CouldNotAcquireLockError) for x in results].count( True ) == parallel_tasks - 1 - # check lock is being released + # check lock is released async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk: assert not await _is_locked(redis_client_sdk, lock_name) @@ -168,7 +249,7 @@ async def test_start_exclusive_periodic_task_parallel_all_finish( _assert_task_completes_once(get_redis_client_sdk, stop_after=60) for _ in range(parallel_tasks) ], - reraise=False + reraise=False, ) # check no error occurred diff --git a/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py b/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py index a09036087895..aa01af1f34bc 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py @@ -1,3 +1,5 @@ +from typing import Annotated + from fastapi import Depends, FastAPI, Request from ...core.settings import ComputationalBackendSettings @@ -11,7 +13,7 @@ def get_scheduler(request: Request) -> BaseCompScheduler: def get_scheduler_settings( - app: FastAPI = Depends(get_app), + app: Annotated[FastAPI, Depends(get_app)] ) -> ComputationalBackendSettings: settings: ComputationalBackendSettings = ( app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py index 1eb6c3dab107..d06c37457b73 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py @@ -1,7 +1,38 @@ +import logging +from collections.abc import Callable, Coroutine +from typing import Any, cast + from fastapi import FastAPI +from servicelib.logging_utils import log_context +from . import _scheduler_factory from ._base_scheduler import BaseCompScheduler -from ._task import on_app_shutdown, on_app_startup + +_logger = logging.getLogger(__name__) + + +def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]: + async def start_scheduler() -> None: + with log_context( + _logger, level=logging.INFO, msg="starting computational scheduler" + ): + app.state.scheduler = scheduler = await _scheduler_factory.create_from_db( + app + ) + scheduler.recover_scheduling() + + return start_scheduler + + +def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]: + async def stop_scheduler() -> None: + await get_scheduler(app).shutdown() + + return stop_scheduler + + +def get_scheduler(app: FastAPI) -> BaseCompScheduler: + return cast(BaseCompScheduler, app.state.scheduler) def setup(app: FastAPI): @@ -12,4 +43,5 @@ def setup(app: FastAPI): __all__: tuple[str, ...] 
= ( "setup", "BaseCompScheduler", + "get_scheduler", ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index cae539596d41..097afd952883 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -12,7 +12,9 @@ """ import asyncio +import contextlib import datetime +import functools import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -29,9 +31,12 @@ from models_library.users import UserID from networkx.classes.reportviews import InDegreeView from pydantic import PositiveInt +from servicelib.background_task import start_periodic_task, stop_periodic_task from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE +from servicelib.logging_utils import log_context from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient -from servicelib.utils import limited_gather +from servicelib.redis import CouldNotAcquireLockError, RedisClientSDK +from servicelib.redis_utils import exclusive from ...constants import UNDEFINED_STR_METADATA from ...core.errors import ( @@ -76,6 +81,10 @@ _Previous = CompTaskAtDB _Current = CompTaskAtDB _MAX_WAITING_FOR_CLUSTER_TIMEOUT_IN_MIN: Final[int] = 10 +_SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) +_TASK_NAME_TEMPLATE: Final[ + str +] = "computational-scheduler-{user_id}:{project_id}:{iteration}" @dataclass(frozen=True, slots=True) @@ -134,6 +143,12 @@ class ScheduledPipelineParams: mark_for_cancellation: datetime.datetime | None use_on_demand_clusters: bool + scheduler_task: asyncio.Task | None = None + scheduler_waker: asyncio.Event = field(default_factory=asyncio.Event) + + def wake_up(self) -> None: + self.scheduler_waker.set() + @dataclass class BaseCompScheduler(ABC): @@ -146,6 +161,7 @@ class BaseCompScheduler(ABC): rabbitmq_rpc_client: RabbitMQRPCClient settings: ComputationalBackendSettings service_runtime_heartbeat_interval: datetime.timedelta + redis_client: RedisClientSDK async def run_new_pipeline( self, @@ -178,7 +194,7 @@ async def run_new_pipeline( ) self.scheduled_pipelines[ (user_id, project_id, new_run.iteration) - ] = ScheduledPipelineParams( + ] = pipeline_params = ScheduledPipelineParams( cluster_id=cluster_id, run_metadata=new_run.metadata, use_on_demand_clusters=use_on_demand_clusters, @@ -191,8 +207,8 @@ async def run_new_pipeline( log=f"Project pipeline scheduled using {'on-demand clusters' if use_on_demand_clusters else 'pre-defined clusters'}, starting soon...", log_level=logging.INFO, ) - # ensure the scheduler starts right away - self._wake_up_scheduler_now() + + self._start_scheduling(pipeline_params, user_id, project_id, new_run.iteration) async def stop_pipeline( self, user_id: UserID, project_id: ProjectID, iteration: int | None = None @@ -224,29 +240,76 @@ async def stop_pipeline( (user_id, project_id, selected_iteration) ].mark_for_cancellation = updated_comp_run.cancelled # ensure the scheduler starts right away - self._wake_up_scheduler_now() + self.scheduled_pipelines[ + (user_id, project_id, selected_iteration) + ].wake_up() - async def schedule_all_pipelines(self) -> None: - self.wake_up_event.clear() - # if one of the task throws, the other are NOT cancelled which is what we want - await limited_gather( + def 
recover_scheduling(self) -> None: + for ( + user_id, + project_id, + iteration, + ), params in self.scheduled_pipelines.items(): + self._start_scheduling(params, user_id, project_id, iteration) + + async def shutdown(self) -> None: + # cancel all current scheduling processes + await asyncio.gather( *( - self._schedule_pipeline( + stop_periodic_task(p.scheduler_task, timeout=3) + for p in self.scheduled_pipelines.values() + if p.scheduler_task + ), + return_exceptions=True, + ) + + def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iteration: + # if no iteration given find the latest one in the list + possible_iterations = { + it + for u_id, p_id, it in self.scheduled_pipelines + if u_id == user_id and p_id == project_id + } + if not possible_iterations: + msg = f"There are no pipeline scheduled for {user_id}:{project_id}" + raise SchedulerError(msg) + return max(possible_iterations) + + def _start_scheduling( + self, + pipeline_params: ScheduledPipelineParams, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + ) -> None: + async def _exclusive_safe_schedule_pipeline( + *, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + pipeline_params: ScheduledPipelineParams, + ) -> None: + with contextlib.suppress(CouldNotAcquireLockError): + await self._schedule_pipeline( user_id=user_id, project_id=project_id, iteration=iteration, pipeline_params=pipeline_params, ) - for ( - user_id, - project_id, - iteration, - ), pipeline_params in self.scheduled_pipelines.items() + + pipeline_params.scheduler_task = start_periodic_task( + functools.partial( + _exclusive_safe_schedule_pipeline, + user_id=user_id, + project_id=project_id, + iteration=iteration, + pipeline_params=pipeline_params, + ), + interval=_SCHEDULER_INTERVAL, + task_name=_TASK_NAME_TEMPLATE.format( + user_id=user_id, project_id=project_id, iteration=iteration ), - reraise=False, - log=_logger, - limit=40, - tasks_group_prefix="computational-scheduled-pipeline", + early_wake_up_event=pipeline_params.scheduler_waker, ) async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph: @@ -610,6 +673,22 @@ async def _process_completed_tasks( ) -> None: ... + @staticmethod + def _build_exclusive_lock_key(*args, **kwargs) -> str: + assert args # nosec + return f"{kwargs['user_id']}:{kwargs['project_id']}:{kwargs['iteration']}" + + @staticmethod + def _redis_client_getter(*args, **kwargs) -> RedisClientSDK: + assert kwargs # nosec + zelf = args[0] + assert isinstance(zelf, BaseCompScheduler) # nosec + return zelf.redis_client + + @exclusive( + redis=_redis_client_getter, + lock_key=_build_exclusive_lock_key, + ) async def _schedule_pipeline( self, *, @@ -618,98 +697,99 @@ async def _schedule_pipeline( iteration: PositiveInt, pipeline_params: ScheduledPipelineParams, ) -> None: - _logger.debug( - "checking run of project [%s:%s] for user [%s]", - f"{project_id=}", - f"{iteration=}", - f"{user_id=}", - ) - dag: nx.DiGraph = nx.DiGraph() - try: - dag = await self._get_pipeline_dag(project_id) - # 1. Update our list of tasks with data from backend (state, results) - await self._update_states_from_comp_backend( - user_id, project_id, iteration, dag, pipeline_params=pipeline_params - ) - # 2. Any task following a FAILED task shall be ABORTED - comp_tasks = await self._set_states_following_failed_to_aborted( - project_id, dag - ) - # 3. do we want to stop the pipeline now? 
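# --- editor's note: illustrative sketch, not part of the patch ---------------
# The pattern introduced above (periodic task + @exclusive lock + suppressed
# CouldNotAcquireLockError) is what lets several director-v2 replicas run the
# same per-pipeline loop safely: only one replica enters the body per tick.
# A condensed sketch, with `step` standing in for the decorated
# _schedule_pipeline and `waker` for pipeline_params.scheduler_waker:
#
#   import contextlib, datetime, functools
#   from servicelib.background_task import start_periodic_task
#   from servicelib.redis import CouldNotAcquireLockError
#
#   async def safe_step(**kwargs) -> None:
#       with contextlib.suppress(CouldNotAcquireLockError):
#           await step(**kwargs)  # guarded by @exclusive elsewhere
#
#   task = start_periodic_task(
#       functools.partial(safe_step, user_id=1, project_id="p", iteration=1),
#       interval=datetime.timedelta(seconds=5),
#       task_name="computational-scheduler-1:p:1",
#       early_wake_up_event=waker,  # waker.set() re-runs the step early
#   )
# ------------------------------------------------------------------------------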
- if pipeline_params.mark_for_cancellation: - await self._schedule_tasks_to_stop( - user_id, project_id, comp_tasks, pipeline_params + with log_context( + _logger, + level=logging.INFO, + msg=f"scheduling pipeline {user_id=}:{project_id=}:{iteration=}", + ): + dag: nx.DiGraph = nx.DiGraph() + try: + dag = await self._get_pipeline_dag(project_id) + # 1. Update our list of tasks with data from backend (state, results) + await self._update_states_from_comp_backend( + user_id, project_id, iteration, dag, pipeline_params=pipeline_params ) - else: - # let's get the tasks to schedule then - comp_tasks = await self._schedule_tasks_to_start( - user_id=user_id, - project_id=project_id, - comp_tasks=comp_tasks, - dag=dag, - pipeline_params=pipeline_params, + # 2. Any task following a FAILED task shall be ABORTED + comp_tasks = await self._set_states_following_failed_to_aborted( + project_id, dag + ) + # 3. do we want to stop the pipeline now? + if pipeline_params.mark_for_cancellation: + await self._schedule_tasks_to_stop( + user_id, project_id, comp_tasks, pipeline_params + ) + else: + # let's get the tasks to schedule then + comp_tasks = await self._schedule_tasks_to_start( + user_id=user_id, + project_id=project_id, + comp_tasks=comp_tasks, + dag=dag, + pipeline_params=pipeline_params, + ) + # 4. timeout if waiting for cluster has been there for more than X minutes + comp_tasks = await self._timeout_if_waiting_for_cluster_too_long( + user_id, project_id, comp_tasks + ) + # 5. send a heartbeat + await self._send_running_tasks_heartbeat( + user_id, project_id, iteration, dag ) - # 4. timeout if waiting for cluster has been there for more than X minutes - comp_tasks = await self._timeout_if_waiting_for_cluster_too_long( - user_id, project_id, comp_tasks - ) - # 5. send a heartbeat - await self._send_running_tasks_heartbeat( - user_id, project_id, iteration, dag - ) - # 6. Update the run result - pipeline_result = await self._update_run_result_from_tasks( - user_id, project_id, iteration, comp_tasks - ) + # 6. Update the run result + pipeline_result = await self._update_run_result_from_tasks( + user_id, project_id, iteration, comp_tasks + ) - # 7. Are we done scheduling that pipeline? - if not dag.nodes() or pipeline_result in COMPLETED_STATES: - # there is nothing left, the run is completed, we're done here + # 7. Are we done scheduling that pipeline? + if not dag.nodes() or pipeline_result in COMPLETED_STATES: + # there is nothing left, the run is completed, we're done here + self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + _logger.info( + "pipeline %s scheduling completed with result %s", + f"{project_id=}", + f"{pipeline_result=}", + ) + assert pipeline_params.scheduler_task is not None # nosec + pipeline_params.scheduler_task.cancel() + except PipelineNotFoundError: + _logger.warning( + "pipeline %s does not exist in comp_pipeline table, it will be removed from scheduler", + f"{project_id=}", + ) + await self._set_run_result( + user_id, project_id, iteration, RunningState.ABORTED + ) self.scheduled_pipelines.pop((user_id, project_id, iteration), None) - _logger.info( - "pipeline %s scheduling completed with result %s", + except InvalidPipelineError as exc: + _logger.warning( + "pipeline %s appears to be misconfigured, it will be removed from scheduler. 
Please check pipeline:\n%s", f"{project_id=}", - f"{pipeline_result=}", + exc, ) - except PipelineNotFoundError: - _logger.warning( - "pipeline %s does not exist in comp_pipeline table, it will be removed from scheduler", - f"{project_id=}", - ) - await self._set_run_result( - user_id, project_id, iteration, RunningState.ABORTED - ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) - except InvalidPipelineError as exc: - _logger.warning( - "pipeline %s appears to be misconfigured, it will be removed from scheduler. Please check pipeline:\n%s", - f"{project_id=}", - exc, - ) - await self._set_run_result( - user_id, project_id, iteration, RunningState.ABORTED - ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) - except (DaskClientAcquisisitonError, ClustersKeeperNotAvailableError): - _logger.exception( - "Unexpected error while connecting with computational backend, aborting pipeline" - ) - tasks: dict[NodeIDStr, CompTaskAtDB] = await self._get_pipeline_tasks( - project_id, dag - ) - comp_tasks_repo = CompTasksRepository(self.db_engine) - await comp_tasks_repo.update_project_tasks_state( - project_id, - [t.node_id for t in tasks.values()], - RunningState.FAILED, - ) - await self._set_run_result( - user_id, project_id, iteration, RunningState.FAILED - ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) - except ComputationalBackendNotConnectedError: - _logger.exception("Computational backend is not connected!") + await self._set_run_result( + user_id, project_id, iteration, RunningState.ABORTED + ) + self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + except (DaskClientAcquisisitonError, ClustersKeeperNotAvailableError): + _logger.exception( + "Unexpected error while connecting with computational backend, aborting pipeline" + ) + tasks: dict[NodeIDStr, CompTaskAtDB] = await self._get_pipeline_tasks( + project_id, dag + ) + comp_tasks_repo = CompTasksRepository(self.db_engine) + await comp_tasks_repo.update_project_tasks_state( + project_id, + [t.node_id for t in tasks.values()], + RunningState.FAILED, + ) + await self._set_run_result( + user_id, project_id, iteration, RunningState.FAILED + ) + self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + except ComputationalBackendNotConnectedError: + _logger.exception("Computational backend is not connected!") async def _schedule_tasks_to_stop( self, @@ -910,6 +990,3 @@ async def _timeout_if_waiting_for_cluster_too_long( log_level=logging.ERROR, ) return comp_tasks - - def _wake_up_scheduler_now(self) -> None: - self.wake_up_event.set() diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py index 51fb3b1a3fb0..512df1b17124 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py @@ -124,7 +124,7 @@ async def _start_tasks( cluster_id=pipeline_params.cluster_id, tasks={node_id: task.image}, hardware_info=task.hardware_info, - callback=self._wake_up_scheduler_now, + callback=pipeline_params.wake_up, metadata=pipeline_params.run_metadata, ) for node_id, task in scheduled_tasks.items() diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py 
b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py index f8b648eaf489..4f7812816ccd 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py @@ -2,6 +2,8 @@ from fastapi import FastAPI from models_library.clusters import DEFAULT_CLUSTER_ID +from servicelib.logging_utils import log_context +from settings_library.redis import RedisDatabase from ...core.errors import ConfigurationError from ...core.settings import AppSettings @@ -10,10 +12,11 @@ from ..dask_clients_pool import DaskClientsPool from ..db.repositories.comp_runs import CompRunsRepository from ..rabbitmq import get_rabbitmq_client, get_rabbitmq_rpc_client +from ..redis import get_redis_client_manager from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams from ._dask_scheduler import DaskScheduler -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) async def create_from_db(app: FastAPI) -> BaseCompScheduler: @@ -28,29 +31,32 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler: filter_by_state=SCHEDULED_STATES ) - logger.debug( + _logger.debug( "Following scheduled comp_runs found still to be scheduled: %s", runs if runs else "NONE", ) - logger.info("Creating Dask-based scheduler...") - app_settings: AppSettings = app.state.settings - return DaskScheduler( - settings=app_settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, - dask_clients_pool=DaskClientsPool.instance(app), - rabbitmq_client=get_rabbitmq_client(app), - rabbitmq_rpc_client=get_rabbitmq_rpc_client(app), - db_engine=db_engine, - scheduled_pipelines={ - (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams( - cluster_id=( - r.cluster_id if r.cluster_id is not None else DEFAULT_CLUSTER_ID - ), - run_metadata=r.metadata, - mark_for_cancellation=r.cancelled, - use_on_demand_clusters=r.use_on_demand_clusters, - ) - for r in runs - }, - service_runtime_heartbeat_interval=app_settings.SERVICE_TRACKING_HEARTBEAT, - ) + with log_context( + _logger, logging.INFO, msg="Creating Dask-based computational scheduler" + ): + app_settings: AppSettings = app.state.settings + return DaskScheduler( + settings=app_settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, + dask_clients_pool=DaskClientsPool.instance(app), + rabbitmq_client=get_rabbitmq_client(app), + rabbitmq_rpc_client=get_rabbitmq_rpc_client(app), + redis_client=get_redis_client_manager(app).client(RedisDatabase.LOCKS), + db_engine=db_engine, + scheduled_pipelines={ + (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams( + cluster_id=( + r.cluster_id if r.cluster_id is not None else DEFAULT_CLUSTER_ID + ), + run_metadata=r.metadata, + mark_for_cancellation=r.cancelled, + use_on_demand_clusters=r.use_on_demand_clusters, + ) + for r in runs + }, + service_runtime_heartbeat_interval=app_settings.SERVICE_TRACKING_HEARTBEAT, + ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_task.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_task.py deleted file mode 100644 index 989b310687cd..000000000000 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_task.py +++ /dev/null @@ -1,51 +0,0 @@ -import datetime -import logging -from collections.abc import Callable, Coroutine -from typing import Any, Final - -from fastapi import FastAPI -from servicelib.background_task import 
start_periodic_task, stop_periodic_task -from servicelib.logging_utils import log_context -from servicelib.redis import RedisClientsManager -from servicelib.redis_utils import exclusive -from settings_library.redis import RedisDatabase - -from ..._meta import APP_NAME -from . import _scheduler_factory - -_logger = logging.getLogger(__name__) - -_COMPUTATIONAL_SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta( - seconds=5 -) -_TASK_NAME: Final[str] = "computational services scheduler" - - -def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]: - async def start_scheduler() -> None: - with log_context( - _logger, level=logging.INFO, msg="starting computational scheduler" - ): - redis_clients_manager: RedisClientsManager = app.state.redis_clients_manager - lock_key = f"{APP_NAME}:computational_scheduler" - app.state.scheduler = scheduler = await _scheduler_factory.create_from_db( - app - ) - app.state.computational_scheduler_task = start_periodic_task( - exclusive( - redis_clients_manager.client(RedisDatabase.LOCKS), - lock_key=lock_key, - )(scheduler.schedule_all_pipelines), - interval=_COMPUTATIONAL_SCHEDULER_INTERVAL, - task_name=_TASK_NAME, - early_wake_up_event=scheduler.wake_up_event, - ) - - return start_scheduler - - -def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]: - async def stop_scheduler() -> None: - await stop_periodic_task(app.state.computational_scheduler_task) - - return stop_scheduler diff --git a/services/director-v2/src/simcore_service_director_v2/modules/redis.py b/services/director-v2/src/simcore_service_director_v2/modules/redis.py index e7da01afef75..273061cb188a 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/redis.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/redis.py @@ -1,3 +1,5 @@ +from typing import cast + from fastapi import FastAPI from servicelib.redis import RedisClientsManager, RedisManagerDBConfig from settings_library.redis import RedisDatabase @@ -29,3 +31,7 @@ async def on_shutdown() -> None: app.add_event_handler("startup", on_startup) app.add_event_handler("shutdown", on_shutdown) + + +def get_redis_client_manager(app: FastAPI) -> RedisClientsManager: + return cast(RedisClientsManager, app.state.redis_clients_manager) diff --git a/services/director-v2/tests/unit/_helpers.py b/services/director-v2/tests/unit/_helpers.py index 2654c63a3e1f..779d6cdd1179 100644 --- a/services/director-v2/tests/unit/_helpers.py +++ b/services/director-v2/tests/unit/_helpers.py @@ -1,4 +1,3 @@ -import asyncio from dataclasses import dataclass from typing import Any @@ -11,9 +10,6 @@ from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB from simcore_service_director_v2.models.comp_runs import CompRunsAtDB from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB -from simcore_service_director_v2.modules.comp_scheduler._base_scheduler import ( - BaseCompScheduler, -) @dataclass @@ -28,13 +24,6 @@ class RunningProject(PublishedProject): runs: CompRunsAtDB -async def trigger_comp_scheduler(scheduler: BaseCompScheduler) -> None: - # trigger the scheduler - scheduler._wake_up_scheduler_now() # pylint: disable=protected-access # noqa: SLF001 - # let the scheduler be actually triggered - await asyncio.sleep(1) - - async def set_comp_task_state( aiopg_engine: aiopg.sa.engine.Engine, node_id: str, state: StateType ) -> None: diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py 
b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index f9e5ff33c4b3..1df1ae09d39e 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -45,6 +45,7 @@ from pytest_mock.plugin import MockerFixture from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.rabbitmq import RabbitMQClient +from servicelib.redis import CouldNotAcquireLockError from settings_library.rabbit import RabbitSettings from settings_library.redis import RedisSettings from simcore_postgres_database.models.comp_runs import comp_runs @@ -66,8 +67,12 @@ from simcore_service_director_v2.models.comp_runs import CompRunsAtDB, RunMetadataDict from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB, Image from simcore_service_director_v2.models.dask_subsystem import DaskClientTaskState -from simcore_service_director_v2.modules.comp_scheduler._base_scheduler import ( +from simcore_service_director_v2.modules.comp_scheduler import ( BaseCompScheduler, + get_scheduler, +) +from simcore_service_director_v2.modules.comp_scheduler._base_scheduler import ( + ScheduledPipelineParams, ) from simcore_service_director_v2.modules.comp_scheduler._dask_scheduler import ( DaskScheduler, @@ -155,8 +160,38 @@ async def _assert_comp_tasks_db( ), f"{expected_progress=}, found: {[t.progress for t in tasks]}" -async def run_comp_scheduler(scheduler: BaseCompScheduler) -> None: - await scheduler.schedule_all_pipelines() +async def schedule_all_pipelines(scheduler: BaseCompScheduler) -> None: + # NOTE: we take a copy of the pipelines, as this could change quickly if there are + # misconfigured pipelines that would be removed from the scheduler + # NOTE: we simulate multiple dv-2 replicas by running several times + # the same pipeline scheduling + local_pipelines = deepcopy(scheduler.scheduled_pipelines) + results = await asyncio.gather( + *( + scheduler._schedule_pipeline( # noqa: SLF001 + user_id=user_id, + project_id=project_id, + iteration=iteration, + pipeline_params=params, + ) + for _ in range(3) + for ( + user_id, + project_id, + iteration, + ), params in local_pipelines.items() + ), + return_exceptions=True, + ) + # we should have exceptions 2/3 of the time + could_not_acquire_lock_count = sum( + isinstance(r, CouldNotAcquireLockError) for r in results + ) + total_results_count = len(results) + + # Check if 2/3 of the results are CouldNotAcquireLockError + # checks that scheduling is done exclusively + assert could_not_acquire_lock_count == (2 / 3) * total_results_count @pytest.fixture @@ -185,11 +220,11 @@ def minimal_dask_scheduler_config( def scheduler( minimal_dask_scheduler_config: None, aiopg_engine: aiopg.sa.engine.Engine, - # dask_spec_local_cluster: SpecCluster, minimal_app: FastAPI, ) -> BaseCompScheduler: - assert minimal_app.state.scheduler is not None - return minimal_app.state.scheduler + scheduler = get_scheduler(minimal_app) + assert scheduler is not None + return scheduler @pytest.fixture @@ -220,16 +255,21 @@ def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.MagicMock: @pytest.fixture -def with_disabled_scheduler_task(mocker: MockerFixture) -> None: +def with_disabled_auto_scheduling(mocker: MockerFixture) -> mock.MagicMock: """disables the scheduler task, note that it needs to be triggered manually then""" - mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._task.start_periodic_task", 
- autospec=True, - ) - mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._task.stop_periodic_task", + def _fake_starter( + self: BaseCompScheduler, + pipeline_params: ScheduledPipelineParams, + *args, + **kwargs, + ) -> None: + pipeline_params.scheduler_task = mocker.MagicMock() + + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._base_scheduler.BaseCompScheduler._start_scheduling", autospec=True, + side_effect=_fake_starter, ) @@ -258,7 +298,7 @@ async def test_scheduler_gracefully_starts_and_stops( minimal_app: FastAPI, ): # check it started correctly - assert minimal_app.state.computational_scheduler_task is not None + assert get_scheduler(minimal_app) is not None @pytest.mark.parametrize( @@ -287,7 +327,7 @@ def test_scheduler_raises_exception_for_missing_dependencies( async def test_empty_pipeline_is_not_scheduled( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, scheduler: BaseCompScheduler, registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], @@ -319,9 +359,6 @@ async def test_empty_pipeline_is_not_scheduled( use_on_demand_clusters=False, ) assert len(scheduler.scheduled_pipelines) == 0 - assert ( - scheduler.wake_up_event.is_set() is False - ), "the scheduler was woken up on an empty pipeline!" # check the database is empty async with aiopg_engine.acquire() as conn: result = await conn.scalar( @@ -334,7 +371,7 @@ async def test_empty_pipeline_is_not_scheduled( async def test_misconfigured_pipeline_is_not_scheduled( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, scheduler: BaseCompScheduler, registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], @@ -361,9 +398,6 @@ async def test_misconfigured_pipeline_is_not_scheduled( use_on_demand_clusters=False, ) assert len(scheduler.scheduled_pipelines) == 1 - assert ( - scheduler.wake_up_event.is_set() is True - ), "the scheduler was NOT woken up on the scheduled pipeline!" for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): assert u_id == user["id"] assert p_id == sleepers_project.uuid @@ -380,7 +414,7 @@ async def test_misconfigured_pipeline_is_not_scheduled( run_entry = CompRunsAtDB.parse_obj(await result.first()) assert run_entry.result == RunningState.PUBLISHED # let the scheduler kick in - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) # check the scheduled pipelines is again empty since it's misconfigured assert len(scheduler.scheduled_pipelines) == 0 # check the database entry is correctly updated @@ -412,9 +446,6 @@ async def _assert_start_pipeline( use_on_demand_clusters=False, ) assert len(scheduler.scheduled_pipelines) == 1, "the pipeline is not scheduled!" - assert ( - scheduler.wake_up_event.is_set() is True - ), "the scheduler was NOT woken up on the scheduled pipeline!" 
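# --- editor's note: illustrative sketch, not part of the patch ---------------
# The global `scheduler.wake_up_event` assertions are dropped throughout these
# tests because waking is now per pipeline: each ScheduledPipelineParams
# carries its own `scheduler_waker` event. The production-side equivalent of
# the old global wake-up, for a given (user_id, project_id, iteration) key, is:
#
#   scheduler.scheduled_pipelines[(user_id, project_id, iteration)].wake_up()
#
# which sets that pipeline's waker so only its periodic task runs early.
# ------------------------------------------------------------------------------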
for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): assert u_id == published_project.project.prj_owner assert p_id == published_project.project.uuid @@ -434,7 +465,7 @@ async def _assert_start_pipeline( return exp_published_tasks -async def _assert_schedule_pipeline_PENDING( +async def _assert_schedule_pipeline_PENDING( # noqa: N802 aiopg_engine, published_project: PublishedProject, published_tasks: list[CompTaskAtDB], @@ -452,7 +483,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] return [DaskClientTaskState.PENDING for job_id in job_ids] mocked_dask_client.get_tasks_status.side_effect = _return_tasks_pending - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) _assert_dask_client_correctly_initialized(mocked_dask_client, scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) await _assert_comp_tasks_db( @@ -471,6 +502,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] expected_progress=None, # since we bypass the API entrypoint this is correct ) # tasks were send to the backend + assert published_project.project.prj_owner is not None mocked_dask_client.send_computation_tasks.assert_has_calls( calls=[ mock.call( @@ -478,7 +510,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] project_id=published_project.project.uuid, cluster_id=DEFAULT_CLUSTER_ID, tasks={f"{p.node_id}": p.image}, - callback=scheduler._wake_up_scheduler_now, # noqa: SLF001 + callback=mock.ANY, metadata=mock.ANY, hardware_info=mock.ANY, ) @@ -490,7 +522,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] mocked_dask_client.get_tasks_status.assert_not_called() mocked_dask_client.get_task_result.assert_not_called() # there is a second run of the scheduler to move comp_runs to pending, the rest does not change - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) await _assert_comp_tasks_db( aiopg_engine, @@ -616,7 +648,7 @@ async def _trigger_progress_event( @pytest.mark.acceptance_test() async def test_proper_pipeline_is_scheduled( # noqa: PLR0915 - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler, aiopg_engine: aiopg.sa.engine.Engine, @@ -661,7 +693,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) await _assert_comp_tasks_db( @@ -707,7 +739,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta node_id=exp_started_task.node_id, ) - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) # comp_run, the comp_task switch to STARTED await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) await _assert_comp_tasks_db( @@ -771,7 +803,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData: return TaskOutputData.parse_obj({"out_1": None, "out_2": 45}) mocked_dask_client.get_task_result.side_effect = _return_random_task_result - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) await _assert_comp_run_db(aiopg_engine, 
published_project, RunningState.STARTED) await _assert_comp_tasks_db( aiopg_engine, @@ -819,7 +851,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData: tasks={ f"{next_pending_task.node_id}": next_pending_task.image, }, - callback=scheduler._wake_up_scheduler_now, # noqa: SLF001 + callback=mock.ANY, metadata=mock.ANY, hardware_info=mock.ANY, ) @@ -866,7 +898,7 @@ async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskSta project_id=exp_started_task.project_id, node_id=exp_started_task.node_id, ) - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) await _assert_comp_tasks_db( aiopg_engine, @@ -908,7 +940,7 @@ async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskStat mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_failed mocked_dask_client.get_task_result.side_effect = None - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) await _assert_comp_tasks_db( aiopg_engine, @@ -955,7 +987,7 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta mocked_dask_client.get_task_result.side_effect = _return_random_task_result # trigger the scheduler, it should switch to FAILED, as we are done - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) await _assert_comp_tasks_db( @@ -991,7 +1023,7 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta async def test_task_progress_triggers( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler, aiopg_engine: aiopg.sa.engine.Engine, @@ -1054,7 +1086,7 @@ async def test_task_progress_triggers( ], ) async def test_handling_of_disconnected_dask_scheduler( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler, aiopg_engine: aiopg.sa.engine.Engine, @@ -1098,7 +1130,7 @@ async def test_handling_of_disconnected_dask_scheduler( project_id=published_project.project.uuid, ) # we ensure the scheduler was run - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) # after this step the tasks are marked as ABORTED await _assert_comp_tasks_db( aiopg_engine, @@ -1112,7 +1144,7 @@ async def test_handling_of_disconnected_dask_scheduler( expected_progress=1, ) # then we have another scheduler run - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) # now the run should be ABORTED await _assert_comp_run_db(aiopg_engine, published_project, RunningState.ABORTED) @@ -1197,7 +1229,7 @@ class RebootState: ], ) async def test_handling_scheduling_after_reboot( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, mocked_dask_client: mock.MagicMock, aiopg_engine: aiopg.sa.engine.Engine, running_project: RunningProject, @@ -1222,7 +1254,7 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: mocked_dask_client.get_task_result.side_effect = mocked_get_task_result - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) # the status will be called once for all RUNNING tasks mocked_dask_client.get_tasks_status.assert_called_once() if 
reboot_state.expected_run_state in COMPLETED_STATES: @@ -1279,7 +1311,7 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: async def test_handling_cancellation_of_jobs_after_reboot( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, mocked_dask_client: mock.MagicMock, aiopg_engine: aiopg.sa.engine.Engine, running_project_mark_for_cancellation: RunningProject, @@ -1309,7 +1341,7 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskStat mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status # Running the scheduler, should actually cancel the run now - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) mocked_dask_client.abort_computation_task.assert_called() assert mocked_dask_client.abort_computation_task.call_count == len( [ @@ -1346,7 +1378,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData: raise TaskCancelledError mocked_dask_client.get_task_result.side_effect = _return_random_task_result - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) # now should be stopped await _assert_comp_tasks_db( aiopg_engine, @@ -1373,7 +1405,7 @@ def with_fast_service_heartbeat_s(monkeypatch: pytest.MonkeyPatch) -> int: async def test_running_pipeline_triggers_heartbeat( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, with_fast_service_heartbeat_s: int, mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler, @@ -1420,7 +1452,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta project_id=exp_started_task.project_id, node_id=exp_started_task.node_id, ) - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) messages = await _assert_message_received( resource_tracking_rabbit_client_parser, @@ -1432,8 +1464,8 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta # ------------------------------------------------------------------------------- # 3. wait a bit and run again we should get another heartbeat, but only one! await asyncio.sleep(with_fast_service_heartbeat_s + 1) - await run_comp_scheduler(scheduler) - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) + await schedule_all_pipelines(scheduler) messages = await _assert_message_received( resource_tracking_rabbit_client_parser, 1, @@ -1444,8 +1476,8 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta # ------------------------------------------------------------------------------- # 4. wait a bit and run again we should get another heartbeat, but only one! 
await asyncio.sleep(with_fast_service_heartbeat_s + 1) - await run_comp_scheduler(scheduler) - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) + await schedule_all_pipelines(scheduler) messages = await _assert_message_received( resource_tracking_rabbit_client_parser, 1, @@ -1463,7 +1495,7 @@ async def mocked_get_or_create_cluster(mocker: MockerFixture) -> mock.Mock: async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, scheduler: BaseCompScheduler, aiopg_engine: aiopg.sa.engine.Engine, published_project: PublishedProject, @@ -1501,7 +1533,7 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( published_project.tasks[1], published_project.tasks[3], ] - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) mocked_get_or_create_cluster.assert_called() assert mocked_get_or_create_cluster.call_count == 1 mocked_get_or_create_cluster.reset_mock() @@ -1516,7 +1548,7 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( expected_progress=None, ) # again will trigger the same response - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) mocked_get_or_create_cluster.assert_called() assert mocked_get_or_create_cluster.call_count == 1 mocked_get_or_create_cluster.reset_mock() @@ -1537,7 +1569,7 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( [ClustersKeeperNotAvailableError], ) async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( - with_disabled_scheduler_task: None, + with_disabled_auto_scheduling: None, scheduler: BaseCompScheduler, aiopg_engine: aiopg.sa.engine.Engine, published_project: PublishedProject, @@ -1570,7 +1602,7 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( published_project.tasks[1], published_project.tasks[3], ] - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) mocked_get_or_create_cluster.assert_called() assert mocked_get_or_create_cluster.call_count == 1 mocked_get_or_create_cluster.reset_mock() @@ -1583,7 +1615,7 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( expected_progress=1.0, ) # again will not re-trigger the call to clusters-keeper - await run_comp_scheduler(scheduler) + await schedule_all_pipelines(scheduler) mocked_get_or_create_cluster.assert_not_called() await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) await _assert_comp_tasks_db(
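# --- editor's note: illustrative sketch, not part of the patch ---------------
# (The diff is truncated above.) The core testing idea this PR introduces is
# visible in `schedule_all_pipelines`: fire the same `_schedule_pipeline` call
# several times concurrently and assert that all but one attempt per pipeline
# fail with CouldNotAcquireLockError. A minimal standalone sketch of that
# check, assuming `step` is any coroutine guarded by @exclusive:
#
#   import asyncio
#   from servicelib.redis import CouldNotAcquireLockError
#
#   async def assert_exclusive(step, n: int = 3) -> None:
#       results = await asyncio.gather(
#           *(step() for _ in range(n)), return_exceptions=True
#       )
#       # exactly one attempt wins the lock; the rest are rejected
#       assert sum(isinstance(r, CouldNotAcquireLockError) for r in results) == n - 1
# ------------------------------------------------------------------------------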