Skip to content

Commit e5b3e39

Browse files
committed
correct the AI
1 parent 5558197 commit e5b3e39

File tree

2 files changed

+34
-28
lines changed

2 files changed

+34
-28
lines changed

services/director-v2/src/simcore_service_director_v2/modules/dask_clients_pool.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ class DaskClientsPool:
3636
_cluster_to_client_map: dict[_ClusterUrl, DaskClient] = field(default_factory=dict)
3737
_task_handlers: TaskHandlers | None = None
3838
# Track references to each client by endpoint
39-
_client_refs: defaultdict[_ClusterUrl, set[str]] = field(
39+
_client_to_refs: defaultdict[_ClusterUrl, set[ClientRef]] = field(
4040
default_factory=lambda: defaultdict(set)
4141
)
42+
_ref_to_clients: dict[ClientRef, _ClusterUrl] = field(default_factory=dict)
4243

4344
def __post_init__(self):
4445
# NOTE: to ensure the correct loop is used
@@ -67,7 +68,8 @@ async def delete(self) -> None:
6768
return_exceptions=True,
6869
)
6970
self._cluster_to_client_map.clear()
70-
self._client_refs.clear()
71+
self._client_to_refs.clear()
72+
self._ref_to_clients.clear()
7173

7274
async def release_client_ref(self, ref: ClientRef) -> None:
7375
"""Release a dask client reference by its ref.
@@ -78,30 +80,26 @@ async def release_client_ref(self, ref: ClientRef) -> None:
7880
"""
7981
async with self._client_acquisition_lock:
8082
# Find which endpoint this ref belongs to
81-
endpoint_to_remove = None
82-
for endpoint, refs in self._client_refs.items():
83-
if ref in refs:
84-
refs.remove(ref)
85-
_logger.debug("Released reference %s for client %s", ref, endpoint)
86-
if not refs: # No more references to this client
87-
endpoint_to_remove = endpoint
88-
break
89-
90-
# If we found an endpoint with no more refs, clean it up
91-
if endpoint_to_remove and (
92-
dask_client := self._cluster_to_client_map.pop(endpoint_to_remove, None)
93-
):
94-
_logger.info(
95-
"Last reference to client %s released, deleting client",
96-
endpoint_to_remove,
97-
)
98-
await dask_client.delete()
99-
# Clean up the empty refs set
100-
del self._client_refs[endpoint_to_remove]
101-
_logger.debug(
102-
"Remaining clients: %s",
103-
[f"{k}" for k in self._cluster_to_client_map],
104-
)
83+
if cluster_endpoint := self._ref_to_clients.pop(ref, None):
84+
# we have a client, remove our reference and check if there are any more references
85+
assert ref in self._client_to_refs[cluster_endpoint] # nosec
86+
self._client_to_refs[cluster_endpoint].discard(ref)
87+
88+
# If we found an endpoint with no more refs, clean it up
89+
if not self._client_to_refs[cluster_endpoint] and (
90+
dask_client := self._cluster_to_client_map.pop(
91+
cluster_endpoint, None
92+
)
93+
):
94+
_logger.info(
95+
"Last reference to client %s released, deleting client",
96+
cluster_endpoint,
97+
)
98+
await dask_client.delete()
99+
_logger.debug(
100+
"Remaining clients: %s",
101+
[f"{k}" for k in self._cluster_to_client_map],
102+
)
105103

106104
@asynccontextmanager
107105
async def acquire(
@@ -150,12 +148,13 @@ async def _concurently_safe_acquire_client() -> DaskClient:
150148
dask_client.register_handlers(self._task_handlers)
151149

152150
# Track the reference
153-
self._client_refs[cluster.endpoint].add(ref)
151+
self._client_to_refs[cluster.endpoint].add(ref)
152+
self._ref_to_clients[ref] = cluster.endpoint
154153

155154
_logger.debug(
156155
"Client %s now has %d references",
157156
cluster.endpoint,
158-
len(self._client_refs[cluster.endpoint]),
157+
len(self._client_to_refs[cluster.endpoint]),
159158
)
160159

161160
assert dask_client # nosec

services/director-v2/tests/unit/test_modules_dask_clients_pool.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,15 @@ async def test_dask_clients_pool_reference_counting(
284284
mocked_dask_client.create.assert_called_once()
285285
# Reset the mock to check the next call
286286
mocked_dask_client.create.reset_mock()
287+
mocked_dask_client.delete.assert_not_called()
288+
289+
# calling again with the same reference should not create a new client
290+
async with clients_pool.acquire(cluster, ref=ref1):
291+
# Client should NOT be re-created
292+
mocked_dask_client.create.assert_not_called()
287293

288294
mocked_dask_client.delete.assert_not_called()
295+
289296
# Acquire the same client with second reference
290297
ref2 = "test-ref-2"
291298
async with clients_pool.acquire(cluster, ref=ref2):

0 commit comments

Comments
 (0)