diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/clusters_keeper/clusters.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/clusters_keeper/clusters.py
index ada0c66d26d9..ca409aa7e651 100644
--- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/clusters_keeper/clusters.py
+++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/clusters_keeper/clusters.py
@@ -6,6 +6,7 @@
 from models_library.rabbitmq_basic_types import RPCMethodName
 from models_library.users import UserID
 from models_library.wallets import WalletID
+from pydantic import TypeAdapter
 
 from ....async_utils import run_sequentially_in_context
 from ..._client_rpc import RabbitMQRPCClient
@@ -13,6 +14,10 @@
 
 _TTL_CACHE_ON_CLUSTERS_S: Final[int] = 5
 
+_GET_OR_CREATE_CLUSTER_METHOD_NAME: Final[RPCMethodName] = TypeAdapter(
+    RPCMethodName
+).validate_python("get_or_create_cluster")
+
 
 @run_sequentially_in_context(target_args=["user_id", "wallet_id"])
 @cached(
@@ -32,7 +37,7 @@ async def get_or_create_cluster(
     # the 2nd decorator ensure that many calls in a short time will return quickly the same value
     on_demand_cluster: OnDemandCluster = await client.request(
         CLUSTERS_KEEPER_RPC_NAMESPACE,
-        RPCMethodName("get_or_create_cluster"),
+        _GET_OR_CREATE_CLUSTER_METHOD_NAME,
         timeout_s=RPC_REMOTE_METHOD_TIMEOUT_S,
         user_id=user_id,
         wallet_id=wallet_id,
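Note: the hunk above moves validation of the RPC method name out of the request path: TypeAdapter(RPCMethodName).validate_python(...) runs once at import time and the resulting constant is reused by every client.request(...) call. A minimal self-contained sketch of the pattern; the RPCMethodName stand-in below is an assumption, not the real models_library type:

    from typing import Annotated, Final

    from pydantic import StringConstraints, TypeAdapter

    # hypothetical stand-in for models_library.rabbitmq_basic_types.RPCMethodName
    RPCMethodName = Annotated[str, StringConstraints(pattern=r"^[a-z_][a-z0-9_]*$")]

    # validated once at import time, reused for every RPC request
    _GET_OR_CREATE_CLUSTER_METHOD_NAME: Final[str] = TypeAdapter(
        RPCMethodName
    ).validate_python("get_or_create_cluster")

    assert _GET_OR_CREATE_CLUSTER_METHOD_NAME == "get_or_create_cluster"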
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py
index e1a37378dd1f..b14e7e4f0cbe 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py
@@ -589,6 +589,12 @@ async def _process_executing_tasks(
     ) -> None:
         """process executing tasks from the 3rd party backend"""
 
+    @abstractmethod
+    async def _release_resources(
+        self, user_id: UserID, project_id: ProjectID, comp_run: CompRunsAtDB
+    ) -> None:
+        """release resources used by the scheduler for a given user and project"""
+
     async def apply(
         self,
         *,
@@ -654,6 +660,7 @@ async def apply(
         # 7. Are we done scheduling that pipeline?
         if not dag.nodes() or pipeline_result in COMPLETED_STATES:
+            await self._release_resources(user_id, project_id, comp_run)
             # there is nothing left, the run is completed, we're done here
             _logger.info(
                 "pipeline %s scheduling completed with result %s",
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
index 39a19d7cc5f3..213167693431 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py
@@ -4,7 +4,7 @@
 from collections.abc import AsyncIterator, Callable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Final
 
 import arrow
 from dask_task_models_library.container_tasks.errors import TaskCancelledError
@@ -23,7 +23,7 @@
 from models_library.users import UserID
 from pydantic import PositiveInt
 from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
-from servicelib.logging_utils import log_catch
+from servicelib.logging_utils import log_catch, log_context
 
 from ...core.errors import (
     ComputationalBackendNotConnectedError,
@@ -56,6 +56,8 @@
 
 _logger = logging.getLogger(__name__)
 
+_DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{run_id}"
+
 
 @asynccontextmanager
 async def _cluster_dask_client(
@@ -63,6 +65,7 @@
     scheduler: "DaskScheduler",
     *,
     use_on_demand_clusters: bool,
+    run_id: PositiveInt,
    run_metadata: RunMetadataDict,
 ) -> AsyncIterator[DaskClient]:
     cluster: BaseCluster = scheduler.settings.default_cluster
@@ -72,7 +75,9 @@ async def _cluster_dask_client(
             user_id=user_id,
             wallet_id=run_metadata.get("wallet_id"),
         )
-    async with scheduler.dask_clients_pool.acquire(cluster) as client:
+    async with scheduler.dask_clients_pool.acquire(
+        cluster, ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=run_id)
+    ) as client:
         yield client
@@ -101,6 +106,7 @@ async def _start_tasks(
             user_id,
             self,
             use_on_demand_clusters=comp_run.use_on_demand_clusters,
+            run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
             # Change the tasks state to PENDING
@@ -151,6 +157,7 @@ async def _get_tasks_status(
             user_id,
             self,
             use_on_demand_clusters=comp_run.use_on_demand_clusters,
+            run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
             return await client.get_tasks_status([f"{t.job_id}" for t in tasks])
@@ -171,6 +178,7 @@ async def _process_executing_tasks(
             user_id,
             self,
             use_on_demand_clusters=comp_run.use_on_demand_clusters,
+            run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
             task_progresses = await client.get_tasks_progress(
@@ -217,6 +225,22 @@ async def _process_executing_tasks(
                 )
             )
 
+    async def _release_resources(
+        self, user_id: UserID, project_id: ProjectID, comp_run: CompRunsAtDB
+    ) -> None:
+        """release resources used by the scheduler for a given user and project"""
+        with (
+            log_catch(_logger, reraise=False),
+            log_context(
+                _logger,
+                logging.INFO,
+                msg=f"releasing resources for {user_id=}, {project_id=}, {comp_run.run_id=}",
+            ),
+        ):
+            await self.dask_clients_pool.release_client_ref(
+                ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=comp_run.run_id)
+            )
+
     async def _stop_tasks(
         self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB
     ) -> None:
@@ -226,6 +250,7 @@
             user_id,
             self,
             use_on_demand_clusters=comp_run.use_on_demand_clusters,
+            run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
             await asyncio.gather(
@@ -259,6 +284,7 @@ async def _process_completed_tasks(
             user_id,
             self,
             use_on_demand_clusters=comp_run.use_on_demand_clusters,
+            run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
             tasks_results = await asyncio.gather(
@@ -278,6 +304,7 @@
             user_id,
             self,
             use_on_demand_clusters=comp_run.use_on_demand_clusters,
+            run_id=comp_run.run_id,
             run_metadata=comp_run.metadata,
         ) as client:
             await asyncio.gather(
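Note: the scheduler changes above thread one reference string per (user, run) through every client acquisition, and _release_resources drops that same reference once the pipeline leaves the scheduling loop. A minimal sketch of the ref scheme, with only the constant's name and format taken from the diff; make_run_ref is a hypothetical helper, the production code formats the constant inline:

    from typing import Final

    _DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{run_id}"


    def make_run_ref(user_id: int, run_id: int) -> str:
        # the same (user_id, run_id) pair always yields the same ref, so every
        # scheduling step of a run re-acquires the same pooled client and the
        # final release can recompute the ref instead of storing it
        return _DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=run_id)


    assert make_run_ref(user_id=3, run_id=7) == "3:7"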
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py
index 6ff8714811f6..fd09e3db015d 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_client.py
@@ -120,43 +120,42 @@ async def create(
         tasks_file_link_type: FileLinkType,
         cluster_type: ClusterTypeInModel,
     ) -> "DaskClient":
-        _logger.info(
-            "Initiating connection to %s with auth: %s, type: %s",
-            f"dask-scheduler at {endpoint}",
-            authentication,
-            cluster_type,
-        )
-        async for attempt in AsyncRetrying(
-            reraise=True,
-            before_sleep=before_sleep_log(_logger, logging.INFO),
-            wait=wait_fixed(0.3),
-            stop=stop_after_attempt(3),
+        with log_context(
+            _logger,
+            logging.INFO,
+            msg=f"create dask client to dask-scheduler at {endpoint=} with {authentication=}, {cluster_type=}",
         ):
-            with attempt:
-                _logger.debug(
-                    "Connecting to %s, attempt %s...",
-                    endpoint,
-                    attempt.retry_state.attempt_number,
-                )
-                backend = await connect_to_dask_scheduler(endpoint, authentication)
-                dask_utils.check_scheduler_status(backend.client)
-                instance = cls(
-                    app=app,
-                    backend=backend,
-                    settings=settings,
-                    tasks_file_link_type=tasks_file_link_type,
-                    cluster_type=cluster_type,
-                )
-                _logger.info(
-                    "Connection to %s succeeded [%s]",
-                    f"dask-scheduler at {endpoint}",
-                    json_dumps(attempt.retry_state.retry_object.statistics),
-                )
-                _logger.info(
-                    "Scheduler info:\n%s",
-                    json_dumps(backend.client.scheduler_info(), indent=2),
-                )
-                return instance
+            async for attempt in AsyncRetrying(
+                reraise=True,
+                before_sleep=before_sleep_log(_logger, logging.INFO),
+                wait=wait_fixed(0.3),
+                stop=stop_after_attempt(3),
+            ):
+                with attempt:
+                    _logger.debug(
+                        "Connecting to %s, attempt %s...",
+                        endpoint,
+                        attempt.retry_state.attempt_number,
+                    )
+                    backend = await connect_to_dask_scheduler(endpoint, authentication)
+                    dask_utils.check_scheduler_status(backend.client)
+                    instance = cls(
+                        app=app,
+                        backend=backend,
+                        settings=settings,
+                        tasks_file_link_type=tasks_file_link_type,
+                        cluster_type=cluster_type,
+                    )
+                    _logger.info(
+                        "Connection to %s succeeded [%s]",
+                        f"dask-scheduler at {endpoint}",
+                        json_dumps(attempt.retry_state.retry_object.statistics),
+                    )
+                    _logger.info(
+                        "Scheduler info:\n%s",
+                        json_dumps(backend.client.scheduler_info(), indent=2),
+                    )
+                    return instance
         # this is to satisfy pylance
         err_msg = "Could not create client"
         raise ValueError(err_msg)
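Note: the rewrite above replaces hand-rolled start/success logs with a single log_context spanning the whole retry loop. A self-contained sketch of that shape, with a simplified stand-in for servicelib.logging_utils.log_context and a hypothetical connect() in place of connect_to_dask_scheduler:

    import asyncio
    import logging
    from contextlib import contextmanager

    from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed

    logging.basicConfig(level=logging.INFO)
    _logger = logging.getLogger(__name__)


    @contextmanager
    def log_context(logger: logging.Logger, level: int, *, msg: str):
        # simplified stand-in: logs once on entry and once on exit
        logger.log(level, "start: %s", msg)
        yield
        logger.log(level, "done: %s", msg)


    async def connect(endpoint: str) -> str:
        # pretend this may fail transiently and deserve the retries
        return f"client-for-{endpoint}"


    async def create(endpoint: str) -> str:
        with log_context(_logger, logging.INFO, msg=f"create dask client for {endpoint=}"):
            async for attempt in AsyncRetrying(
                reraise=True, wait=wait_fixed(0.3), stop=stop_after_attempt(3)
            ):
                with attempt:
                    return await connect(endpoint)
        raise ValueError("unreachable, keeps type checkers happy")


    print(asyncio.run(create("tcp://dask-scheduler:8786")))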
diff --git a/services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py b/services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py
index 31177b5a6162..b4f75c68d725 100644
--- a/services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py
+++ b/services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+from collections import defaultdict
 from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
@@ -8,6 +9,7 @@
 from fastapi import FastAPI
 from models_library.clusters import BaseCluster, ClusterTypeInModel
 from pydantic import AnyUrl
+from servicelib.logging_utils import log_context
 
 from ..core.errors import (
     ComputationalBackendNotConnectedError,
@@ -19,10 +21,11 @@
 from ..utils.dask_client_utils import TaskHandlers
 from .dask_client import DaskClient
 
-logger = logging.getLogger(__name__)
+_logger = logging.getLogger(__name__)
 
 _ClusterUrl: TypeAlias = AnyUrl
+ClientRef: TypeAlias = str
 
 
 @dataclass
@@ -32,6 +35,11 @@ class DaskClientsPool:
     _client_acquisition_lock: asyncio.Lock = field(init=False)
     _cluster_to_client_map: dict[_ClusterUrl, DaskClient] = field(default_factory=dict)
     _task_handlers: TaskHandlers | None = None
+    # Track references to each client by endpoint
+    _client_to_refs: defaultdict[_ClusterUrl, set[ClientRef]] = field(
+        default_factory=lambda: defaultdict(set)
+    )
+    _ref_to_clients: dict[ClientRef, _ClusterUrl] = field(default_factory=dict)
 
     def __post_init__(self):
         # NOTE: to ensure the correct loop is used
@@ -59,54 +67,102 @@ async def delete(self) -> None:
             *[client.delete() for client in self._cluster_to_client_map.values()],
             return_exceptions=True,
         )
+        self._cluster_to_client_map.clear()
+        self._client_to_refs.clear()
+        self._ref_to_clients.clear()
+
+    async def release_client_ref(self, ref: ClientRef) -> None:
+        """Release a dask client reference by its ref.
+
+        If all the references to the client are released,
+        the client will be deleted from the pool.
+        This method is safe to call concurrently from multiple tasks.
+        """
+        async with self._client_acquisition_lock:
+            # Find which endpoint this ref belongs to
+            if cluster_endpoint := self._ref_to_clients.pop(ref, None):
+                # we have a client, remove our reference and check if there are any more references
+                assert ref in self._client_to_refs[cluster_endpoint]  # nosec
+                self._client_to_refs[cluster_endpoint].discard(ref)
+
+                # If we found an endpoint with no more refs, clean it up
+                if not self._client_to_refs[cluster_endpoint] and (
+                    dask_client := self._cluster_to_client_map.pop(
+                        cluster_endpoint, None
+                    )
+                ):
+                    _logger.info(
+                        "Last reference to client %s released, deleting client",
+                        cluster_endpoint,
+                    )
+                    await dask_client.delete()
+                    _logger.debug(
+                        "Remaining clients: %s",
+                        [f"{k}" for k in self._cluster_to_client_map],
+                    )
 
     @asynccontextmanager
-    async def acquire(self, cluster: BaseCluster) -> AsyncIterator[DaskClient]:
+    async def acquire(
+        self, cluster: BaseCluster, *, ref: ClientRef
+    ) -> AsyncIterator[DaskClient]:
+        """Returns a dask client for the given cluster.
+
+        This method is safe to call concurrently from multiple tasks.
+        If the cluster is not found in the pool, it will create a new dask client for it.
+
+        The passed reference is used to track the client usage, the caller should call
+        `release_client_ref` to release the client reference when done.
+        """
+
         async def _concurently_safe_acquire_client() -> DaskClient:
             async with self._client_acquisition_lock:
-                dask_client = self._cluster_to_client_map.get(cluster.endpoint)
-
-                # we create a new client if that cluster was never used before
-                logger.debug(
-                    "acquiring connection to cluster %s:%s",
-                    cluster.endpoint,
-                    cluster.name,
-                )
-                if not dask_client:
-                    tasks_file_link_type = (
-                        self.settings.COMPUTATIONAL_BACKEND_DEFAULT_FILE_LINK_TYPE
-                    )
-                    if cluster == self.settings.default_cluster:
+                with log_context(
+                    _logger,
+                    logging.DEBUG,
+                    f"acquire dask client for {cluster.name=}:{cluster.endpoint}",
+                ):
+                    dask_client = self._cluster_to_client_map.get(cluster.endpoint)
+                    if not dask_client:
                         tasks_file_link_type = (
-                            self.settings.COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_FILE_LINK_TYPE
+                            self.settings.COMPUTATIONAL_BACKEND_DEFAULT_FILE_LINK_TYPE
                         )
-                    if cluster.type == ClusterTypeInModel.ON_DEMAND.value:
-                        tasks_file_link_type = (
-                            self.settings.COMPUTATIONAL_BACKEND_ON_DEMAND_CLUSTERS_FILE_LINK_TYPE
+                        if cluster == self.settings.default_cluster:
+                            tasks_file_link_type = (
+                                self.settings.COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_FILE_LINK_TYPE
+                            )
+                        if cluster.type == ClusterTypeInModel.ON_DEMAND.value:
+                            tasks_file_link_type = (
+                                self.settings.COMPUTATIONAL_BACKEND_ON_DEMAND_CLUSTERS_FILE_LINK_TYPE
+                            )
+                        self._cluster_to_client_map[cluster.endpoint] = dask_client = (
+                            await DaskClient.create(
+                                app=self.app,
+                                settings=self.settings,
+                                endpoint=cluster.endpoint,
+                                authentication=cluster.authentication,
+                                tasks_file_link_type=tasks_file_link_type,
+                                cluster_type=cluster.type,
+                            )
                         )
-                    self._cluster_to_client_map[
-                        cluster.endpoint
-                    ] = dask_client = await DaskClient.create(
-                        app=self.app,
-                        settings=self.settings,
-                        endpoint=cluster.endpoint,
-                        authentication=cluster.authentication,
-                        tasks_file_link_type=tasks_file_link_type,
-                        cluster_type=cluster.type,
-                    )
-                    if self._task_handlers:
-                        dask_client.register_handlers(self._task_handlers)
+                        if self._task_handlers:
+                            dask_client.register_handlers(self._task_handlers)
 
-                    logger.debug("created new client to cluster %s", f"{cluster=}")
-                    logger.debug(
-                        "list of clients: %s", f"{self._cluster_to_client_map=}"
+                    # Track the reference
+                    self._client_to_refs[cluster.endpoint].add(ref)
+                    self._ref_to_clients[ref] = cluster.endpoint
+
+                    _logger.debug(
+                        "Client %s now has %d references",
+                        cluster.endpoint,
+                        len(self._client_to_refs[cluster.endpoint]),
                     )
-                assert dask_client  # nosec
-                return dask_client
+                    assert dask_client  # nosec
+                    return dask_client
 
         try:
             dask_client = await _concurently_safe_acquire_client()
+
         except Exception as exc:
             raise DaskClientAcquisisitonError(cluster=cluster, error=exc) from exc
@@ -129,7 +185,7 @@ async def on_startup() -> None:
             app=app, settings=settings
         )
-        logger.info(
+        _logger.info(
             "Default cluster is set to %s",
             f"{settings.default_cluster!r}",
         )
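Note: the heart of this file is plain reference counting keyed by cluster endpoint: acquire(..., ref=...) registers a reference, release_client_ref(ref) drops it, and the client is deleted only when its last reference goes away. A self-contained sketch of that bookkeeping; RefCountedPool is hypothetical, while the real pool stores DaskClient instances, keys them by cluster.endpoint, and guards the maps with the same asyncio lock:

    import asyncio
    from collections import defaultdict


    class RefCountedPool:
        def __init__(self) -> None:
            self._lock = asyncio.Lock()
            self._key_to_resource: dict[str, object] = {}
            self._key_to_refs: defaultdict[str, set[str]] = defaultdict(set)
            self._ref_to_key: dict[str, str] = {}

        async def acquire(self, key: str, *, ref: str) -> object:
            async with self._lock:
                if key not in self._key_to_resource:
                    # create the shared resource only on first acquisition
                    self._key_to_resource[key] = object()
                self._key_to_refs[key].add(ref)
                self._ref_to_key[ref] = key
                return self._key_to_resource[key]

        async def release(self, ref: str) -> None:
            async with self._lock:
                key = self._ref_to_key.pop(ref, None)
                if key is None:
                    return  # releasing an unknown or already-released ref is a no-op
                self._key_to_refs[key].discard(ref)
                if not self._key_to_refs[key]:
                    # last reference gone: drop (in the real pool: delete) the client
                    self._key_to_resource.pop(key, None)


    async def _demo() -> None:
        pool = RefCountedPool()
        a = await pool.acquire("tcp://cluster:8786", ref="1:7")  # user 1, run 7
        b = await pool.acquire("tcp://cluster:8786", ref="1:8")  # user 1, run 8
        assert a is b  # both runs share the same client
        await pool.release("1:7")  # still held by run 8
        await pool.release("1:8")  # last ref: client removed
        assert await pool.acquire("tcp://cluster:8786", ref="1:9") is not a


    asyncio.run(_demo())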
+ """ + async def _concurently_safe_acquire_client() -> DaskClient: async with self._client_acquisition_lock: - dask_client = self._cluster_to_client_map.get(cluster.endpoint) - - # we create a new client if that cluster was never used before - logger.debug( - "acquiring connection to cluster %s:%s", - cluster.endpoint, - cluster.name, - ) - if not dask_client: - tasks_file_link_type = ( - self.settings.COMPUTATIONAL_BACKEND_DEFAULT_FILE_LINK_TYPE - ) - if cluster == self.settings.default_cluster: + with log_context( + _logger, + logging.DEBUG, + f"acquire dask client for {cluster.name=}:{cluster.endpoint}", + ): + dask_client = self._cluster_to_client_map.get(cluster.endpoint) + if not dask_client: tasks_file_link_type = ( - self.settings.COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_FILE_LINK_TYPE + self.settings.COMPUTATIONAL_BACKEND_DEFAULT_FILE_LINK_TYPE ) - if cluster.type == ClusterTypeInModel.ON_DEMAND.value: - tasks_file_link_type = ( - self.settings.COMPUTATIONAL_BACKEND_ON_DEMAND_CLUSTERS_FILE_LINK_TYPE + if cluster == self.settings.default_cluster: + tasks_file_link_type = ( + self.settings.COMPUTATIONAL_BACKEND_DEFAULT_CLUSTER_FILE_LINK_TYPE + ) + if cluster.type == ClusterTypeInModel.ON_DEMAND.value: + tasks_file_link_type = ( + self.settings.COMPUTATIONAL_BACKEND_ON_DEMAND_CLUSTERS_FILE_LINK_TYPE + ) + self._cluster_to_client_map[cluster.endpoint] = dask_client = ( + await DaskClient.create( + app=self.app, + settings=self.settings, + endpoint=cluster.endpoint, + authentication=cluster.authentication, + tasks_file_link_type=tasks_file_link_type, + cluster_type=cluster.type, + ) ) - self._cluster_to_client_map[ - cluster.endpoint - ] = dask_client = await DaskClient.create( - app=self.app, - settings=self.settings, - endpoint=cluster.endpoint, - authentication=cluster.authentication, - tasks_file_link_type=tasks_file_link_type, - cluster_type=cluster.type, - ) - if self._task_handlers: - dask_client.register_handlers(self._task_handlers) + if self._task_handlers: + dask_client.register_handlers(self._task_handlers) - logger.debug("created new client to cluster %s", f"{cluster=}") - logger.debug( - "list of clients: %s", f"{self._cluster_to_client_map=}" + # Track the reference + self._client_to_refs[cluster.endpoint].add(ref) + self._ref_to_clients[ref] = cluster.endpoint + + _logger.debug( + "Client %s now has %d references", + cluster.endpoint, + len(self._client_to_refs[cluster.endpoint]), ) - assert dask_client # nosec - return dask_client + assert dask_client # nosec + return dask_client try: dask_client = await _concurently_safe_acquire_client() + except Exception as exc: raise DaskClientAcquisisitonError(cluster=cluster, error=exc) from exc @@ -129,7 +185,7 @@ async def on_startup() -> None: app=app, settings=settings ) - logger.info( + _logger.info( "Default cluster is set to %s", f"{settings.default_cluster!r}", ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py index 19989577215b..f623663b4aa3 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py @@ -282,7 +282,7 @@ async def list_for_user__only_latest_iterations( ComputationRunRpcGet.model_validate( { **row, - "state": DB_TO_RUNNING_STATE[row["state"]], + "state": DB_TO_RUNNING_STATE[row.state], } ) async for row in await 
diff --git a/services/director-v2/tests/unit/test_modules_dask_clients_pool.py b/services/director-v2/tests/unit/test_modules_dask_clients_pool.py
index d3c6274fa7c4..b0c982647dd6 100644
--- a/services/director-v2/tests/unit/test_modules_dask_clients_pool.py
+++ b/services/director-v2/tests/unit/test_modules_dask_clients_pool.py
@@ -164,11 +164,11 @@ async def test_dask_clients_pool_acquisition_creates_client_on_demand(
                 cluster_type=ClusterTypeInModel.ON_PREMISE,
             )
         )
-        async with clients_pool.acquire(cluster):
+        async with clients_pool.acquire(cluster, ref=f"test-ref-{cluster.name}"):
             # on start it is created
             mocked_dask_client.create.assert_has_calls(mocked_creation_calls)
 
-        async with clients_pool.acquire(cluster):
+        async with clients_pool.acquire(cluster, ref=f"test-ref-{cluster.name}-2"):
             # the connection already exists, so there is no new call to create
             mocked_dask_client.create.assert_has_calls(mocked_creation_calls)
 
@@ -196,7 +196,9 @@ async def test_acquiring_wrong_cluster_raises_exception(
 
     non_existing_cluster = fake_clusters(1)[0]
     with pytest.raises(DaskClientAcquisisitonError):
-        async with clients_pool.acquire(non_existing_cluster):
+        async with clients_pool.acquire(
+            non_existing_cluster, ref="test-non-existing-ref"
+        ):
             ...
 
@@ -239,7 +241,9 @@ async def test_acquire_default_cluster(
     dask_scheduler_settings = the_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND
     default_cluster = dask_scheduler_settings.default_cluster
     assert default_cluster
-    async with dask_clients_pool.acquire(default_cluster) as dask_client:
+    async with dask_clients_pool.acquire(
+        default_cluster, ref="test-default-cluster-ref"
+    ) as dask_client:
 
         def just_a_quick_fct(x, y):
             return x + y
@@ -252,3 +256,63 @@ def just_a_quick_fct(x, y):
         assert future
         result = await future.result(timeout=10)
         assert result == 35
+
+
+async def test_dask_clients_pool_reference_counting(
+    minimal_dask_config: None,
+    mocker: MockerFixture,
+    client: TestClient,
+    fake_clusters: Callable[[int], list[BaseCluster]],
+):
+    """Test that the reference counting mechanism works correctly."""
+    assert client.app
+    the_app = cast(FastAPI, client.app)
+    mocked_dask_client = mocker.patch(
+        "simcore_service_director_v2.modules.dask_clients_pool.DaskClient",
+        autospec=True,
+    )
+    mocked_dask_client.create.return_value = mocked_dask_client
+    clients_pool = DaskClientsPool.instance(the_app)
+
+    # Create a cluster
+    cluster = fake_clusters(1)[0]
+
+    # Acquire the client with first reference
+    ref1 = "test-ref-1"
+    async with clients_pool.acquire(cluster, ref=ref1):
+        # Client should be created
+        mocked_dask_client.create.assert_called_once()
+        # Reset the mock to check the next call
+        mocked_dask_client.create.reset_mock()
+        mocked_dask_client.delete.assert_not_called()
+
+    # calling again with the same reference should not create a new client
+    async with clients_pool.acquire(cluster, ref=ref1):
+        # Client should NOT be re-created
+        mocked_dask_client.create.assert_not_called()
+
+    mocked_dask_client.delete.assert_not_called()
+
+    # Acquire the same client with second reference
+    ref2 = "test-ref-2"
+    async with clients_pool.acquire(cluster, ref=ref2):
+        # No new client should be created
+        mocked_dask_client.create.assert_not_called()
+        mocked_dask_client.delete.assert_not_called()
+
+    # Release first reference, client should still exist
+    await clients_pool.release_client_ref(ref1)
+    mocked_dask_client.delete.assert_not_called()
+
+    # Release second reference, which should delete the client
+    await clients_pool.release_client_ref(ref2)
+    mocked_dask_client.delete.assert_called_once()
+
+    # calling again should not raise and not delete more
+    await clients_pool.release_client_ref(ref2)
+    mocked_dask_client.delete.assert_called_once()
+
+    # Acquire again should create a new client
+    mocked_dask_client.create.reset_mock()
+    async with clients_pool.acquire(cluster, ref="test-ref-3"):
+        mocked_dask_client.create.assert_called_once()
diff --git a/services/director-v2/tests/unit/with_dbs/test_utils_dask.py b/services/director-v2/tests/unit/with_dbs/test_utils_dask.py
index 682e24825fcc..66035ccfcda1 100644
--- a/services/director-v2/tests/unit/with_dbs/test_utils_dask.py
+++ b/services/director-v2/tests/unit/with_dbs/test_utils_dask.py
@@ -443,7 +443,7 @@ async def test_clean_task_output_and_log_files_if_invalid(
     ]
 
     def _add_is_directory(entry: mock._Call) -> mock._Call:
-        new_kwargs: dict[str, Any] = deepcopy(entry.kwargs)
+        new_kwargs = dict(deepcopy(entry.kwargs))
         new_kwargs["is_directory"] = False
         return mock.call(**new_kwargs)
 
@@ -520,7 +520,9 @@ async def test_check_if_cluster_is_able_to_run_pipeline(
     )
     default_cluster = dask_scheduler_settings.default_cluster
     dask_clients_pool = DaskClientsPool.instance(initialized_app)
-    async with dask_clients_pool.acquire(default_cluster) as dask_client:
+    async with dask_clients_pool.acquire(
+        default_cluster, ref="test-utils-dask-ref"
+    ) as dask_client:
         check_if_cluster_is_able_to_run_pipeline(
             project_id=project_id,
             node_id=node_id,