44import re
55from collections import defaultdict
66from collections .abc import AsyncIterator , Coroutine
7- from typing import Any , Final , TypeAlias
7+ from typing import Any , Final , TypeAlias , TypedDict
88
99import dask .typing
1010import distributed
@@ -119,6 +119,44 @@ def _find_by_worker_host(
119119 return next (iter (filtered_workers .items ()))
120120
121121
122+ class DaskClusterTasks (TypedDict ):
123+ processing : dict [DaskWorkerUrl , list [tuple [dask .typing .Key , DaskTaskResources ]]]
124+ unrunnable : dict [dask .typing .Key , DaskTaskResources ]
125+
126+
127+ async def _list_cluster_known_tasks (
128+ client : distributed .Client ,
129+ ) -> DaskClusterTasks :
130+ def _list_on_scheduler (
131+ dask_scheduler : distributed .Scheduler ,
132+ ) -> DaskClusterTasks :
133+ worker_to_processing_tasks = defaultdict (list )
134+ unrunnable_tasks = {}
135+ for task_key , task_state in dask_scheduler .tasks .items ():
136+ if task_state .processing_on :
137+ worker_to_processing_tasks [task_state .processing_on .address ].append (
138+ (
139+ task_key ,
140+ (task_state .resource_restrictions or {})
141+ | {DASK_WORKER_THREAD_RESOURCE_NAME : 1 },
142+ )
143+ )
144+ elif task_state in dask_scheduler .unrunnable :
145+ unrunnable_tasks [task_key ] = (
146+ task_state .resource_restrictions or {}
147+ ) | {DASK_WORKER_THREAD_RESOURCE_NAME : 1 }
148+
149+ return DaskClusterTasks (
150+ processing = dict (worker_to_processing_tasks ),
151+ unrunnable = unrunnable_tasks ,
152+ )
153+
154+ list_of_tasks : DaskClusterTasks = await client .run_on_scheduler (_list_on_scheduler )
155+ _logger .debug ("found tasks: %s" , list_of_tasks )
156+
157+ return list_of_tasks
158+
159+
122160async def is_worker_connected (
123161 scheduler_url : AnyUrl ,
124162 authentication : ClusterAuthentication ,
@@ -178,10 +216,9 @@ def _list_tasks(
178216 }
179217
180218 async with _scheduler_client (scheduler_url , authentication ) as client :
181- list_of_tasks : dict [dask .typing .Key , DaskTaskResources ] = (
182- await _wrap_client_async_routine (client .run_on_scheduler (_list_tasks ))
183- )
184- _logger .debug ("found unrunnable tasks: %s" , list_of_tasks )
219+ known_tasks = await _list_cluster_known_tasks (client )
220+ list_of_tasks = known_tasks ["unrunnable" ]
221+
185222 return [
186223 DaskTask (
187224 task_id = _dask_key_to_dask_task_id (task_id ),
@@ -192,32 +229,6 @@ def _list_tasks(
192229 ]
193230
194231
195- async def _list_cluster_processing_tasks (
196- client : distributed .Client ,
197- ) -> dict [DaskWorkerUrl , list [tuple [dask .typing .Key , DaskTaskResources ]]]:
198- def _list_processing_tasks (
199- dask_scheduler : distributed .Scheduler ,
200- ) -> dict [str , list [tuple [dask .typing .Key , DaskTaskResources ]]]:
201- worker_to_processing_tasks = defaultdict (list )
202- for task_key , task_state in dask_scheduler .tasks .items ():
203- if task_state .processing_on :
204- worker_to_processing_tasks [task_state .processing_on .address ].append (
205- (
206- task_key ,
207- (task_state .resource_restrictions or {})
208- | {DASK_WORKER_THREAD_RESOURCE_NAME : 1 },
209- )
210- )
211- return worker_to_processing_tasks
212-
213- list_of_tasks : dict [str , list [tuple [dask .typing .Key , DaskTaskResources ]]] = (
214- await client .run_on_scheduler (_list_processing_tasks )
215- )
216- _logger .debug ("found processing tasks: %s" , list_of_tasks )
217-
218- return list_of_tasks
219-
220-
221232async def list_processing_tasks_per_worker (
222233 scheduler_url : AnyUrl ,
223234 authentication : ClusterAuthentication ,
@@ -228,10 +239,10 @@ async def list_processing_tasks_per_worker(
228239 """
229240
230241 async with _scheduler_client (scheduler_url , authentication ) as client :
231- worker_to_tasks = await _list_cluster_processing_tasks (client )
242+ worker_to_tasks = await _list_cluster_known_tasks (client )
232243
233244 tasks_per_worker = defaultdict (list )
234- for worker , tasks in worker_to_tasks .items ():
245+ for worker , tasks in worker_to_tasks [ "processing" ] .items ():
235246 for task_id , required_resources in tasks :
236247 tasks_per_worker [worker ].append (
237248 DaskTask (
@@ -276,8 +287,8 @@ async def get_worker_used_resources(
276287
277288 async with _scheduler_client (scheduler_url , authentication ) as client :
278289 worker_url , _ = _dask_worker_from_ec2_instance (client , ec2_instance )
279- worker_to_tasks = await _list_cluster_processing_tasks (client )
280- worker_processing_tasks = worker_to_tasks .get (worker_url , [])
290+ known_tasks = await _list_cluster_known_tasks (client )
291+ worker_processing_tasks = known_tasks [ "processing" ] .get (worker_url , [])
281292
282293 total_resources_used : collections .Counter [str ] = collections .Counter ()
283294 for _ , task_resources in worker_processing_tasks :
0 commit comments