Skip to content

Commit b8663e5

Browse files
committed
simplify
1 parent e7327c5 commit b8663e5

File tree

2 files changed

+32
-28
lines changed

2 files changed

+32
-28
lines changed

services/autoscaling/src/simcore_service_autoscaling/modules/dask.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -185,21 +185,16 @@ def _list_tasks(
185185
return [
186186
DaskTask(
187187
task_id=_dask_key_to_dask_task_id(task_id),
188-
required_resources=task_resources,
188+
required_resources=task_resources
189+
| {DASK_WORKER_THREAD_RESOURCE_NAME: 1},
189190
)
190191
for task_id, task_resources in list_of_tasks.items()
191192
]
192193

193194

194-
async def list_processing_tasks_per_worker(
195-
scheduler_url: AnyUrl,
196-
authentication: ClusterAuthentication,
197-
) -> dict[DaskWorkerUrl, list[DaskTask]]:
198-
"""
199-
Raises:
200-
DaskSchedulerNotFoundError
201-
"""
202-
195+
async def _list_cluster_processing_tasks(
196+
client: distributed.Client,
197+
) -> dict[DaskWorkerUrl, list[tuple[dask.typing.Key, DaskTaskResources]]]:
203198
def _list_processing_tasks(
204199
dask_scheduler: distributed.Scheduler,
205200
) -> dict[str, list[tuple[dask.typing.Key, DaskTaskResources]]]:
@@ -211,13 +206,26 @@ def _list_processing_tasks(
211206
)
212207
return worker_to_processing_tasks
213208

209+
list_of_tasks: dict[str, list[tuple[dask.typing.Key, DaskTaskResources]]] = (
210+
await client.run_on_scheduler(_list_processing_tasks)
211+
)
212+
_logger.debug("found processing tasks: %s", list_of_tasks)
213+
214+
return list_of_tasks
215+
216+
217+
async def list_processing_tasks_per_worker(
218+
scheduler_url: AnyUrl,
219+
authentication: ClusterAuthentication,
220+
) -> dict[DaskWorkerUrl, list[DaskTask]]:
221+
"""
222+
Raises:
223+
DaskSchedulerNotFoundError
224+
"""
225+
214226
async with _scheduler_client(scheduler_url, authentication) as client:
215-
worker_to_tasks: dict[str, list[tuple[dask.typing.Key, DaskTaskResources]]] = (
216-
await _wrap_client_async_routine(
217-
client.run_on_scheduler(_list_processing_tasks)
218-
)
219-
)
220-
_logger.debug("found processing tasks: %s", worker_to_tasks)
227+
worker_to_tasks = await _list_cluster_processing_tasks(client)
228+
221229
tasks_per_worker = defaultdict(list)
222230
for worker, tasks in worker_to_tasks.items():
223231
for task_id, required_resources in tasks:
@@ -277,17 +285,8 @@ def _list_processing_tasks_on_worker(
277285

278286
async with _scheduler_client(scheduler_url, authentication) as client:
279287
worker_url, _ = _dask_worker_from_ec2_instance(client, ec2_instance)
280-
281-
_logger.debug("looking for processing tasks for %s", f"{worker_url=}")
282-
283-
# now get the used resources
284-
worker_processing_tasks: list[tuple[dask.typing.Key, DaskTaskResources]] = (
285-
await _wrap_client_async_routine(
286-
client.run_on_scheduler(
287-
_list_processing_tasks_on_worker, worker_url=worker_url
288-
),
289-
)
290-
)
288+
worker_to_tasks = await _list_cluster_processing_tasks(client)
289+
worker_processing_tasks = worker_to_tasks.get(worker_url, [])
291290

292291
total_resources_used: collections.Counter[str] = collections.Counter()
293292
for _, task_resources in worker_processing_tasks:

services/autoscaling/tests/unit/test_modules_dask.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,12 @@ async def test_list_unrunnable_tasks(
126126
future = create_dask_task(dask_task_impossible_resources)
127127
assert future
128128
assert await list_unrunnable_tasks(scheduler_url, scheduler_authentication) == [
129-
DaskTask(task_id=future.key, required_resources=dask_task_impossible_resources)
129+
DaskTask(
130+
task_id=future.key,
131+
required_resources=(
132+
dask_task_impossible_resources | {DASK_WORKER_THREAD_RESOURCE_NAME: 1}
133+
),
134+
)
130135
]
131136
# remove that future, will remove the task
132137
del future

0 commit comments

Comments (0)