|
| 1 | +import collections |
1 | 2 | import contextlib |
2 | 3 | import logging |
3 | 4 | import re |
@@ -217,30 +218,39 @@ async def get_worker_used_resources( |
217 | 218 | DaskNoWorkersError |
218 | 219 | """ |
219 | 220 |
|
220 | | - def _get_worker_used_resources( |
| 221 | + def _list_processing_tasks_on_worker( |
221 | 222 | dask_scheduler: distributed.Scheduler, *, worker_url: str |
222 | | - ) -> dict[str, float] | None: |
223 | | - for worker_name, worker_state in dask_scheduler.workers.items(): |
224 | | - if worker_url != worker_name: |
225 | | - continue |
226 | | - if worker_state.status is distributed.Status.closing_gracefully: |
227 | | - # NOTE: when a worker was retired it is in this state |
228 | | - return {} |
229 | | - return dict(worker_state.used_resources) |
230 | | - return None |
| 223 | + ) -> list[tuple[DaskTaskId, DaskTaskResources]]: |
| 224 | + processing_tasks = [] |
| 225 | + for task_key, task_state in dask_scheduler.tasks.items(): |
| 226 | + if task_state.processing_on and ( |
| 227 | + task_state.processing_on.address == worker_url |
| 228 | + ): |
| 229 | + processing_tasks.append((task_key, task_state.resource_restrictions)) |
| 230 | + return processing_tasks |
231 | 231 |
|
232 | 232 | async with _scheduler_client(scheduler_url, authentication) as client: |
233 | 233 | worker_url, _ = _dask_worker_from_ec2_instance(client, ec2_instance) |
234 | 234 |
|
| 235 | + _logger.debug("looking for processing tasks for %s", f"{worker_url=}") |
| 236 | + |
235 | 237 | # now get the used resources |
236 | | - worker_used_resources: dict[str, Any] | None = await _wrap_client_async_routine( |
237 | | - client.run_on_scheduler(_get_worker_used_resources, worker_url=worker_url), |
| 238 | + worker_processing_tasks: list[ |
| 239 | + tuple[DaskTaskId, DaskTaskResources] |
| 240 | + ] = await _wrap_client_async_routine( |
| 241 | + client.run_on_scheduler( |
| 242 | + _list_processing_tasks_on_worker, worker_url=worker_url |
| 243 | + ), |
238 | 244 | ) |
239 | | - if worker_used_resources is None: |
240 | | - raise DaskWorkerNotFoundError(worker_host=worker_url, url=scheduler_url) |
| 245 | + |
| 246 | + total_resources_used: collections.Counter[str] = collections.Counter() |
| 247 | + for _, task_resources in worker_processing_tasks: |
| 248 | + total_resources_used.update(task_resources) |
| 249 | + |
| 250 | + _logger.debug("found %s for %s", f"{total_resources_used=}", f"{worker_url=}") |
241 | 251 | return Resources( |
242 | | - cpus=worker_used_resources.get("CPU", 0), |
243 | | - ram=parse_obj_as(ByteSize, worker_used_resources.get("RAM", 0)), |
| 252 | + cpus=total_resources_used.get("CPU", 0), |
| 253 | + ram=parse_obj_as(ByteSize, total_resources_used.get("RAM", 0)), |
244 | 254 | ) |
245 | 255 |
|
246 | 256 |
|
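For reference, a minimal sketch (not part of the commit) of the aggregation step introduced above: assuming the scheduler-side helper returns `(task id, resource restrictions)` pairs like the hypothetical ones below, `collections.Counter.update()` sums the per-task `CPU`/`RAM` values into the totals that are then wrapped in `Resources`.

```python
import collections

# Hypothetical output of _list_processing_tasks_on_worker: task keys and
# the resource restrictions each task was scheduled with.
worker_processing_tasks = [
    ("task-a", {"CPU": 2.0, "RAM": 1_073_741_824}),  # 1 GiB
    ("task-b", {"CPU": 1.0, "RAM": 536_870_912}),  # 512 MiB
]

total_resources_used: collections.Counter[str] = collections.Counter()
for _, task_resources in worker_processing_tasks:
    # Counter.update() with a mapping adds values key by key instead of
    # replacing them, so CPU and RAM accumulate across the tasks.
    total_resources_used.update(task_resources)

assert total_resources_used["CPU"] == 3.0
assert total_resources_used["RAM"] == 1_610_612_736
```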