@@ -36,9 +36,10 @@ class DaskClientsPool:
3636 _cluster_to_client_map : dict [_ClusterUrl , DaskClient ] = field (default_factory = dict )
3737 _task_handlers : TaskHandlers | None = None
3838 # Track references to each client by endpoint
39- _client_refs : defaultdict [_ClusterUrl , set [str ]] = field (
39+ _client_to_refs : defaultdict [_ClusterUrl , set [ClientRef ]] = field (
4040 default_factory = lambda : defaultdict (set )
4141 )
42+ _ref_to_clients : dict [ClientRef , _ClusterUrl ] = field (default_factory = dict )
4243
4344 def __post_init__ (self ):
4445 # NOTE: to ensure the correct loop is used
@@ -67,7 +68,8 @@ async def delete(self) -> None:
6768 return_exceptions = True ,
6869 )
6970 self ._cluster_to_client_map .clear ()
70- self ._client_refs .clear ()
71+ self ._client_to_refs .clear ()
72+ self ._ref_to_clients .clear ()
7173
7274 async def release_client_ref (self , ref : ClientRef ) -> None :
7375 """Release a dask client reference by its ref.
@@ -78,30 +80,26 @@ async def release_client_ref(self, ref: ClientRef) -> None:
7880 """
7981 async with self ._client_acquisition_lock :
8082 # Find which endpoint this ref belongs to
81- endpoint_to_remove = None
82- for endpoint , refs in self ._client_refs .items ():
83- if ref in refs :
84- refs .remove (ref )
85- _logger .debug ("Released reference %s for client %s" , ref , endpoint )
86- if not refs : # No more references to this client
87- endpoint_to_remove = endpoint
88- break
89-
90- # If we found an endpoint with no more refs, clean it up
91- if endpoint_to_remove and (
92- dask_client := self ._cluster_to_client_map .pop (endpoint_to_remove , None )
93- ):
94- _logger .info (
95- "Last reference to client %s released, deleting client" ,
96- endpoint_to_remove ,
97- )
98- await dask_client .delete ()
99- # Clean up the empty refs set
100- del self ._client_refs [endpoint_to_remove ]
101- _logger .debug (
102- "Remaining clients: %s" ,
103- [f"{ k } " for k in self ._cluster_to_client_map ],
104- )
83+ if cluster_endpoint := self ._ref_to_clients .pop (ref , None ):
84+ # we have a client, remove our reference and check if there are any more references
85+ assert ref in self ._client_to_refs [cluster_endpoint ] # nosec
86+ self ._client_to_refs [cluster_endpoint ].discard (ref )
87+
88+ # If we found an endpoint with no more refs, clean it up
89+ if not self ._client_to_refs [cluster_endpoint ] and (
90+ dask_client := self ._cluster_to_client_map .pop (
91+ cluster_endpoint , None
92+ )
93+ ):
94+ _logger .info (
95+ "Last reference to client %s released, deleting client" ,
96+ cluster_endpoint ,
97+ )
98+ await dask_client .delete ()
99+ _logger .debug (
100+ "Remaining clients: %s" ,
101+ [f"{ k } " for k in self ._cluster_to_client_map ],
102+ )
105103
106104 @asynccontextmanager
107105 async def acquire (
@@ -150,12 +148,13 @@ async def _concurently_safe_acquire_client() -> DaskClient:
150148 dask_client .register_handlers (self ._task_handlers )
151149
152150 # Track the reference
153- self ._client_refs [cluster .endpoint ].add (ref )
151+ self ._client_to_refs [cluster .endpoint ].add (ref )
152+ self ._ref_to_clients [ref ] = cluster .endpoint
154153
155154 _logger .debug (
156155 "Client %s now has %d references" ,
157156 cluster .endpoint ,
158- len (self ._client_refs [cluster .endpoint ]),
157+ len (self ._client_to_refs [cluster .endpoint ]),
159158 )
160159
161160 assert dask_client # nosec
0 commit comments