diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py index dc904a5b435f..3cb3d993fd9e 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py @@ -23,7 +23,6 @@ from pydantic import PositiveInt from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE from servicelib.logging_utils import log_catch, log_context -from servicelib.redis._client import RedisClientSDK from servicelib.redis._semaphore_decorator import ( with_limited_concurrency_cm, ) @@ -71,40 +70,6 @@ _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time" -def _get_redis_client_from_scheduler( - _user_id: UserID, - scheduler: "DaskScheduler", - **kwargs, # pylint: disable=unused-argument # noqa: ARG001 -) -> RedisClientSDK: - return scheduler.redis_client - - -def _get_semaphore_cluster_redis_key( - user_id: UserID, - *args, # pylint: disable=unused-argument # noqa: ARG001 - run_metadata: RunMetadataDict, - **kwargs, # pylint: disable=unused-argument # noqa: ARG001 -) -> str: - return f"{APP_NAME}-cluster-user_id_{user_id}-wallet_id_{run_metadata.get('wallet_id')}" - - -def _get_semaphore_capacity_from_scheduler( - _user_id: UserID, - scheduler: "DaskScheduler", - **kwargs, # pylint: disable=unused-argument # noqa: ARG001 -) -> int: - return ( - scheduler.settings.COMPUTATIONAL_BACKEND_PER_CLUSTER_MAX_DISTRIBUTED_CONCURRENT_CONNECTIONS - ) - - -@with_limited_concurrency_cm( - _get_redis_client_from_scheduler, - key=_get_semaphore_cluster_redis_key, - capacity=_get_semaphore_capacity_from_scheduler, - blocking=True, - blocking_timeout=None, -) @asynccontextmanager async def _cluster_dask_client( user_id: UserID, @@ -122,12 +87,27 @@ async def _cluster_dask_client( user_id=user_id, wallet_id=run_metadata.get("wallet_id"), ) - async with scheduler.dask_clients_pool.acquire( - cluster, - ref=_DASK_CLIENT_RUN_REF.format( - user_id=user_id, project_id=project_id, run_id=run_id - ), - ) as client: + + @with_limited_concurrency_cm( + scheduler.redis_client, + key=f"{APP_NAME}-cluster-user_id_{user_id}-wallet_id_{run_metadata.get('wallet_id')}", + capacity=scheduler.settings.COMPUTATIONAL_BACKEND_PER_CLUSTER_MAX_DISTRIBUTED_CONCURRENT_CONNECTIONS, + blocking=True, + blocking_timeout=None, + ) + @asynccontextmanager + async def _acquire_client( + user_id: UserID, scheduler: "DaskScheduler" + ) -> AsyncIterator[DaskClient]: + async with scheduler.dask_clients_pool.acquire( + cluster, + ref=_DASK_CLIENT_RUN_REF.format( + user_id=user_id, project_id=project_id, run_id=run_id + ), + ) as client: + yield client + + async with _acquire_client(user_id, scheduler) as client: yield client