Skip to content

Commit e2700ab

Browse files
committed
added reference counting on the dask client in the pool
1 parent da342d2 commit e2700ab

File tree

3 files changed

+95
-8
lines changed

3 files changed

+95
-8
lines changed

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,12 @@ async def _process_executing_tasks(
589589
) -> None:
590590
"""process executing tasks from the 3rd party backend"""
591591

592+
@abstractmethod
async def _release_resources(
    self, user_id: UserID, project_id: ProjectID, comp_run: CompRunsAtDB
) -> None:
    """release resources used by the scheduler for a given user and project

    Hook that every concrete scheduler must implement (it is abstract here).
    `apply` awaits it once it finds that the pipeline has no nodes left or
    that the pipeline result is one of COMPLETED_STATES, i.e. right before
    reporting the scheduling of that pipeline as completed.

    Arguments:
        user_id: owner of the computational run
        project_id: project the run belongs to
        comp_run: the run whose resources shall be released
    """
597+
592598
async def apply(
593599
self,
594600
*,
@@ -654,6 +660,7 @@ async def apply(
654660

655661
# 7. Are we done scheduling that pipeline?
656662
if not dag.nodes() or pipeline_result in COMPLETED_STATES:
663+
await self._release_resources(user_id, project_id, comp_run)
657664
# there is nothing left, the run is completed, we're done here
658665
_logger.info(
659666
"pipeline %s scheduling completed with result %s",

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import AsyncIterator, Callable
55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass
7-
from typing import Any
7+
from typing import Any, Final
88

99
import arrow
1010
from dask_task_models_library.container_tasks.errors import TaskCancelledError
@@ -23,7 +23,7 @@
2323
from models_library.users import UserID
2424
from pydantic import PositiveInt
2525
from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
26-
from servicelib.logging_utils import log_catch
26+
from servicelib.logging_utils import log_catch, log_context
2727

2828
from ...core.errors import (
2929
ComputationalBackendNotConnectedError,
@@ -56,13 +56,16 @@
5656

5757
_logger = logging.getLogger(__name__)
5858

59+
_DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{comp_run.run_id}"
60+
5961

6062
@asynccontextmanager
6163
async def _cluster_dask_client(
6264
user_id: UserID,
6365
scheduler: "DaskScheduler",
6466
*,
6567
use_on_demand_clusters: bool,
68+
run_id: PositiveInt,
6669
run_metadata: RunMetadataDict,
6770
) -> AsyncIterator[DaskClient]:
6871
cluster: BaseCluster = scheduler.settings.default_cluster
@@ -72,7 +75,9 @@ async def _cluster_dask_client(
7275
user_id=user_id,
7376
wallet_id=run_metadata.get("wallet_id"),
7477
)
75-
async with scheduler.dask_clients_pool.acquire(cluster) as client:
78+
async with scheduler.dask_clients_pool.acquire(
79+
cluster, ref=_DASK_CLIENT_RUN_REF.format(user_id=user_id, run_id=run_id)
80+
) as client:
7681
yield client
7782

7883

@@ -101,6 +106,7 @@ async def _start_tasks(
101106
user_id,
102107
self,
103108
use_on_demand_clusters=comp_run.use_on_demand_clusters,
109+
run_id=comp_run.run_id,
104110
run_metadata=comp_run.metadata,
105111
) as client:
106112
# Change the tasks state to PENDING
@@ -151,6 +157,7 @@ async def _get_tasks_status(
151157
user_id,
152158
self,
153159
use_on_demand_clusters=comp_run.use_on_demand_clusters,
160+
run_id=comp_run.run_id,
154161
run_metadata=comp_run.metadata,
155162
) as client:
156163
return await client.get_tasks_status([f"{t.job_id}" for t in tasks])
@@ -171,6 +178,7 @@ async def _process_executing_tasks(
171178
user_id,
172179
self,
173180
use_on_demand_clusters=comp_run.use_on_demand_clusters,
181+
run_id=comp_run.run_id,
174182
run_metadata=comp_run.metadata,
175183
) as client:
176184
task_progresses = await client.get_tasks_progress(
@@ -217,6 +225,22 @@ async def _process_executing_tasks(
217225
)
218226
)
219227

228+
async def _release_resources(
    self, user_id: UserID, project_id: ProjectID, comp_run: CompRunsAtDB
) -> None:
    """Give back this run's reference on the pooled dask client.

    The whole operation runs inside `log_catch(..., reraise=False)` and a
    `log_context` at INFO level, so it is presumably best-effort: a failure
    here is expected to be logged rather than propagated to the caller.
    """
    with log_catch(_logger, reraise=False):
        with log_context(
            _logger,
            logging.INFO,
            msg=f"releasing resources for {user_id=}, {project_id=}, {comp_run.run_id=}",
        ):
            # the reference is built inside the guarded scope on purpose, so
            # that a formatting problem is handled like any other failure
            run_ref = _DASK_CLIENT_RUN_REF.format(
                user_id=user_id, run_id=comp_run.run_id
            )
            await self.dask_clients_pool.release_client_ref(ref=run_ref)
243+
220244
async def _stop_tasks(
221245
self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB
222246
) -> None:
@@ -226,6 +250,7 @@ async def _stop_tasks(
226250
user_id,
227251
self,
228252
use_on_demand_clusters=comp_run.use_on_demand_clusters,
253+
run_id=comp_run.run_id,
229254
run_metadata=comp_run.metadata,
230255
) as client:
231256
await asyncio.gather(
@@ -259,6 +284,7 @@ async def _process_completed_tasks(
259284
user_id,
260285
self,
261286
use_on_demand_clusters=comp_run.use_on_demand_clusters,
287+
run_id=comp_run.run_id,
262288
run_metadata=comp_run.metadata,
263289
) as client:
264290
tasks_results = await asyncio.gather(
@@ -278,6 +304,7 @@ async def _process_completed_tasks(
278304
user_id,
279305
self,
280306
use_on_demand_clusters=comp_run.use_on_demand_clusters,
307+
run_id=comp_run.run_id,
281308
run_metadata=comp_run.metadata,
282309
) as client:
283310
await asyncio.gather(

services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from collections import defaultdict
34
from collections.abc import AsyncIterator
45
from contextlib import asynccontextmanager
56
from dataclasses import dataclass, field
@@ -24,6 +25,7 @@
2425

2526

2627
_ClusterUrl: TypeAlias = AnyUrl
28+
ClientRef: TypeAlias = str
2729

2830

2931
@dataclass
@@ -33,6 +35,10 @@ class DaskClientsPool:
3335
_client_acquisition_lock: asyncio.Lock = field(init=False)
3436
_cluster_to_client_map: dict[_ClusterUrl, DaskClient] = field(default_factory=dict)
3537
_task_handlers: TaskHandlers | None = None
38+
# Track references to each client by endpoint
39+
_client_refs: defaultdict[_ClusterUrl, set[str]] = field(
40+
default_factory=lambda: defaultdict(set)
41+
)
3642

3743
def __post_init__(self):
3844
# NOTE: to ensure the correct loop is used
@@ -60,13 +66,54 @@ async def delete(self) -> None:
6066
*[client.delete() for client in self._cluster_to_client_map.values()],
6167
return_exceptions=True,
6268
)
69+
self._cluster_to_client_map.clear()
70+
self._client_refs.clear()
71+
72+
async def release_client_ref(self, ref: ClientRef) -> None:
    """Release a dask client reference by its ref.

    If all the references to the client are released,
    the client will be deleted from the pool.

    Concurrent calls are serialized through `_client_acquisition_lock`.
    That lock is an `asyncio.Lock`, so this protects coroutines running on
    the same event loop; it does not make the method safe to call from
    several OS threads.

    Releasing a ref that is not tracked for any client is a no-op.

    Arguments:
        ref: the reference that was registered for the client
    """
    async with self._client_acquisition_lock:
        # Find which endpoint this ref belongs to
        endpoint_to_remove = None
        for endpoint, refs in self._client_refs.items():
            if ref in refs:
                refs.remove(ref)
                _logger.debug("Released reference %s for client %s", ref, endpoint)
                if not refs:  # No more references to this client
                    endpoint_to_remove = endpoint
                    # NOTE(review): the nesting level of this `break` could not
                    # be confirmed from the flattened listing -- verify it
                    # against the repository
                    break

        # If we found an endpoint with no more refs, clean it up
        # (the client is removed from the pool map before being deleted)
        if endpoint_to_remove and (
            dask_client := self._cluster_to_client_map.pop(endpoint_to_remove, None)
        ):
            _logger.info(
                "Last reference to client %s released, deleting client",
                endpoint_to_remove,
            )
            await dask_client.delete()
            # Clean up the empty refs set
            del self._client_refs[endpoint_to_remove]
            _logger.debug(
                "Remaining clients: %s",
                [f"{k}" for k in self._cluster_to_client_map],
            )
63105

64106
@asynccontextmanager
65-
async def acquire(self, cluster: BaseCluster) -> AsyncIterator[DaskClient]:
66-
"""returns a dask client for the given cluster
107+
async def acquire(
108+
self, cluster: BaseCluster, *, ref: ClientRef
109+
) -> AsyncIterator[DaskClient]:
110+
"""Returns a dask client for the given cluster.
111+
67112
This method is thread-safe and can be called concurrently.
68113
If the cluster is not found in the pool, it will create a new dask client for it.
69114
115+
The passed reference is used to track the client usage, user should call
116+
`release_client_ref` to release the client reference when done.
70117
"""
71118

72119
async def _concurently_safe_acquire_client() -> DaskClient:
@@ -102,15 +149,21 @@ async def _concurently_safe_acquire_client() -> DaskClient:
102149
if self._task_handlers:
103150
dask_client.register_handlers(self._task_handlers)
104151

105-
_logger.debug(
106-
"list of clients: %s", f"{self._cluster_to_client_map=}"
107-
)
152+
# Track the reference
153+
self._client_refs[cluster.endpoint].add(ref)
154+
155+
_logger.debug(
156+
"Client %s now has %d references",
157+
cluster.endpoint,
158+
len(self._client_refs[cluster.endpoint]),
159+
)
108160

109161
assert dask_client # nosec
110162
return dask_client
111163

112164
try:
113165
dask_client = await _concurently_safe_acquire_client()
166+
114167
except Exception as exc:
115168
raise DaskClientAcquisisitonError(cluster=cluster, error=exc) from exc
116169

0 commit comments

Comments
 (0)