Skip to content

Commit ed0099a

Browse files
committed
simplify
1 parent 2049fc7 commit ed0099a

File tree

1 file changed

+46
-35
lines changed
  • services/autoscaling/src/simcore_service_autoscaling/modules

1 file changed

+46
-35
lines changed

services/autoscaling/src/simcore_service_autoscaling/modules/dask.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import re
55
from collections import defaultdict
66
from collections.abc import AsyncIterator, Coroutine
7-
from typing import Any, Final, TypeAlias
7+
from typing import Any, Final, TypeAlias, TypedDict
88

99
import dask.typing
1010
import distributed
@@ -119,6 +119,44 @@ def _find_by_worker_host(
119119
return next(iter(filtered_workers.items()))
120120

121121

122+
class DaskClusterTasks(TypedDict):
    """Snapshot of the tasks known to a dask scheduler, split by state."""

    # tasks currently being executed, grouped by the URL of the worker running them;
    # each entry pairs the task key with its required resources
    processing: dict[DaskWorkerUrl, list[tuple[dask.typing.Key, DaskTaskResources]]]
    # tasks the scheduler could not place on any worker, with their required resources
    unrunnable: dict[dask.typing.Key, DaskTaskResources]
126+
127+
async def _list_cluster_known_tasks(
    client: distributed.Client,
) -> DaskClusterTasks:
    """Collect all tasks known to the cluster's scheduler.

    The collection runs directly on the scheduler process (via
    ``client.run_on_scheduler``) so that the full task state is available.
    Every task's resource restrictions are augmented with one
    ``DASK_WORKER_THREAD_RESOURCE_NAME`` unit, since each task occupies a
    worker thread while it runs.

    Arguments:
        client: connected dask client used to reach the scheduler

    Returns:
        the cluster's processing and unrunnable tasks with their resources
    """

    def _collect(scheduler: distributed.Scheduler) -> DaskClusterTasks:
        # NOTE: executed inside the scheduler process
        processing: defaultdict[str, list[tuple[dask.typing.Key, DaskTaskResources]]] = (
            defaultdict(list)
        )
        unrunnable: dict[dask.typing.Key, DaskTaskResources] = {}
        for key, state in scheduler.tasks.items():
            # every task additionally consumes one worker thread
            required_resources = (state.resource_restrictions or {}) | {
                DASK_WORKER_THREAD_RESOURCE_NAME: 1
            }
            if state.processing_on:
                processing[state.processing_on.address].append((key, required_resources))
            elif state in scheduler.unrunnable:
                unrunnable[key] = required_resources
        return DaskClusterTasks(
            processing=dict(processing),
            unrunnable=unrunnable,
        )

    known_tasks: DaskClusterTasks = await client.run_on_scheduler(_collect)
    _logger.debug("found tasks: %s", known_tasks)

    return known_tasks
158+
159+
122160
async def is_worker_connected(
123161
scheduler_url: AnyUrl,
124162
authentication: ClusterAuthentication,
@@ -178,10 +216,9 @@ def _list_tasks(
178216
}
179217

180218
async with _scheduler_client(scheduler_url, authentication) as client:
181-
list_of_tasks: dict[dask.typing.Key, DaskTaskResources] = (
182-
await _wrap_client_async_routine(client.run_on_scheduler(_list_tasks))
183-
)
184-
_logger.debug("found unrunnable tasks: %s", list_of_tasks)
219+
known_tasks = await _list_cluster_known_tasks(client)
220+
list_of_tasks = known_tasks["unrunnable"]
221+
185222
return [
186223
DaskTask(
187224
task_id=_dask_key_to_dask_task_id(task_id),
@@ -192,32 +229,6 @@ def _list_tasks(
192229
]
193230

194231

195-
async def _list_cluster_processing_tasks(
196-
client: distributed.Client,
197-
) -> dict[DaskWorkerUrl, list[tuple[dask.typing.Key, DaskTaskResources]]]:
198-
def _list_processing_tasks(
199-
dask_scheduler: distributed.Scheduler,
200-
) -> dict[str, list[tuple[dask.typing.Key, DaskTaskResources]]]:
201-
worker_to_processing_tasks = defaultdict(list)
202-
for task_key, task_state in dask_scheduler.tasks.items():
203-
if task_state.processing_on:
204-
worker_to_processing_tasks[task_state.processing_on.address].append(
205-
(
206-
task_key,
207-
(task_state.resource_restrictions or {})
208-
| {DASK_WORKER_THREAD_RESOURCE_NAME: 1},
209-
)
210-
)
211-
return worker_to_processing_tasks
212-
213-
list_of_tasks: dict[str, list[tuple[dask.typing.Key, DaskTaskResources]]] = (
214-
await client.run_on_scheduler(_list_processing_tasks)
215-
)
216-
_logger.debug("found processing tasks: %s", list_of_tasks)
217-
218-
return list_of_tasks
219-
220-
221232
async def list_processing_tasks_per_worker(
222233
scheduler_url: AnyUrl,
223234
authentication: ClusterAuthentication,
@@ -228,10 +239,10 @@ async def list_processing_tasks_per_worker(
228239
"""
229240

230241
async with _scheduler_client(scheduler_url, authentication) as client:
231-
worker_to_tasks = await _list_cluster_processing_tasks(client)
242+
worker_to_tasks = await _list_cluster_known_tasks(client)
232243

233244
tasks_per_worker = defaultdict(list)
234-
for worker, tasks in worker_to_tasks.items():
245+
for worker, tasks in worker_to_tasks["processing"].items():
235246
for task_id, required_resources in tasks:
236247
tasks_per_worker[worker].append(
237248
DaskTask(
@@ -276,8 +287,8 @@ async def get_worker_used_resources(
276287

277288
async with _scheduler_client(scheduler_url, authentication) as client:
278289
worker_url, _ = _dask_worker_from_ec2_instance(client, ec2_instance)
279-
worker_to_tasks = await _list_cluster_processing_tasks(client)
280-
worker_processing_tasks = worker_to_tasks.get(worker_url, [])
290+
known_tasks = await _list_cluster_known_tasks(client)
291+
worker_processing_tasks = known_tasks["processing"].get(worker_url, [])
281292

282293
total_resources_used: collections.Counter[str] = collections.Counter()
283294
for _, task_resources in worker_processing_tasks:

0 commit comments

Comments
 (0)