@@ -185,21 +185,16 @@ def _list_tasks(
185185 return [
186186 DaskTask (
187187 task_id = _dask_key_to_dask_task_id (task_id ),
188- required_resources = task_resources ,
188+ required_resources = task_resources
189+ | {DASK_WORKER_THREAD_RESOURCE_NAME : 1 },
189190 )
190191 for task_id , task_resources in list_of_tasks .items ()
191192 ]
192193
193194
194- async def list_processing_tasks_per_worker (
195- scheduler_url : AnyUrl ,
196- authentication : ClusterAuthentication ,
197- ) -> dict [DaskWorkerUrl , list [DaskTask ]]:
198- """
199- Raises:
200- DaskSchedulerNotFoundError
201- """
202-
195+ async def _list_cluster_processing_tasks (
196+ client : distributed .Client ,
197+ ) -> dict [DaskWorkerUrl , list [tuple [dask .typing .Key , DaskTaskResources ]]]:
203198 def _list_processing_tasks (
204199 dask_scheduler : distributed .Scheduler ,
205200 ) -> dict [str , list [tuple [dask .typing .Key , DaskTaskResources ]]]:
@@ -211,13 +206,26 @@ def _list_processing_tasks(
211206 )
212207 return worker_to_processing_tasks
213208
209+ list_of_tasks : dict [str , list [tuple [dask .typing .Key , DaskTaskResources ]]] = (
210+ await client .run_on_scheduler (_list_processing_tasks )
211+ )
212+ _logger .debug ("found processing tasks: %s" , list_of_tasks )
213+
214+ return list_of_tasks
215+
216+
217+ async def list_processing_tasks_per_worker (
218+ scheduler_url : AnyUrl ,
219+ authentication : ClusterAuthentication ,
220+ ) -> dict [DaskWorkerUrl , list [DaskTask ]]:
221+ """
222+ Raises:
223+ DaskSchedulerNotFoundError
224+ """
225+
214226 async with _scheduler_client (scheduler_url , authentication ) as client :
215- worker_to_tasks : dict [str , list [tuple [dask .typing .Key , DaskTaskResources ]]] = (
216- await _wrap_client_async_routine (
217- client .run_on_scheduler (_list_processing_tasks )
218- )
219- )
220- _logger .debug ("found processing tasks: %s" , worker_to_tasks )
227+ worker_to_tasks = await _list_cluster_processing_tasks (client )
228+
221229 tasks_per_worker = defaultdict (list )
222230 for worker , tasks in worker_to_tasks .items ():
223231 for task_id , required_resources in tasks :
@@ -277,17 +285,8 @@ def _list_processing_tasks_on_worker(
277285
278286 async with _scheduler_client (scheduler_url , authentication ) as client :
279287 worker_url , _ = _dask_worker_from_ec2_instance (client , ec2_instance )
280-
281- _logger .debug ("looking for processing tasks for %s" , f"{ worker_url = } " )
282-
283- # now get the used resources
284- worker_processing_tasks : list [tuple [dask .typing .Key , DaskTaskResources ]] = (
285- await _wrap_client_async_routine (
286- client .run_on_scheduler (
287- _list_processing_tasks_on_worker , worker_url = worker_url
288- ),
289- )
290- )
288+ worker_to_tasks = await _list_cluster_processing_tasks (client )
289+ worker_processing_tasks = worker_to_tasks .get (worker_url , [])
291290
292291 total_resources_used : collections .Counter [str ] = collections .Counter ()
293292 for _ , task_resources in worker_processing_tasks :
0 commit comments