Commit 8fafa40

🎨 Computational backend: improvements step 6 (#8397)
1 parent a26142b commit 8fafa40

File tree: 3 files changed, +87 -130 lines


services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py

Lines changed: 33 additions & 37 deletions
@@ -35,6 +35,7 @@
 from servicelib.logging_utils import log_catch, log_context
 from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient
 from servicelib.redis import RedisClientSDK
+from servicelib.utils import limited_gather
 from sqlalchemy.ext.asyncio import AsyncEngine

 from ...constants import UNDEFINED_STR_METADATA
@@ -79,6 +80,7 @@
 _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta(
     seconds=30
 )
+_PUBLICATION_CONCURRENCY_LIMIT: Final[int] = 10


 def _auto_schedule_callback(
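
For context, limited_gather (imported above from servicelib.utils) is used in the hunks below in place of asyncio.gather: it takes the same awaitables but also a log= logger and a limit= cap on how many of them run at once, with _PUBLICATION_CONCURRENCY_LIMIT bounding concurrent RabbitMQ publications to 10. The following is only a rough sketch of that bounded-gather idea using a plain asyncio.Semaphore, not servicelib's actual implementation:

import asyncio
from collections.abc import Awaitable
from typing import TypeVar

T = TypeVar("T")


async def bounded_gather(*awaitables: Awaitable[T], limit: int) -> list[T]:
    # behaves like asyncio.gather, but at most `limit` awaitables are active at a time
    semaphore = asyncio.Semaphore(limit)

    async def _run(awaitable: Awaitable[T]) -> T:
        async with semaphore:
            return await awaitable

    return await asyncio.gather(*(_run(a) for a in awaitables))
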
@@ -336,7 +338,7 @@ def _need_heartbeat(task: CompTaskAtDB) -> bool:
             project_id, dag
         )
         if running_tasks := [t for t in tasks.values() if _need_heartbeat(t)]:
-            await asyncio.gather(
+            await limited_gather(
                 *(
                     publish_service_resource_tracking_heartbeat(
                         self.rabbitmq_client,
@@ -345,17 +347,15 @@ def _need_heartbeat(task: CompTaskAtDB) -> bool:
                         ),
                     )
                     for t in running_tasks
-                )
+                ),
+                log=_logger,
+                limit=_PUBLICATION_CONCURRENCY_LIMIT,
             )
-            comp_tasks_repo = CompTasksRepository(self.db_engine)
-            await asyncio.gather(
-                *(
-                    comp_tasks_repo.update_project_task_last_heartbeat(
-                        t.project_id, t.node_id, run_id, utc_now
-                    )
-                    for t in running_tasks
+            comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
+            for task in running_tasks:
+                await comp_tasks_repo.update_project_task_last_heartbeat(
+                    project_id, task.node_id, run_id, utc_now
                 )
-            )

     async def _get_changed_tasks_from_backend(
         self,
@@ -400,7 +400,7 @@ async def _process_started_tasks(
         utc_now = arrow.utcnow().datetime

         # resource tracking
-        await asyncio.gather(
+        await limited_gather(
             *(
                 publish_service_resource_tracking_started(
                     self.rabbitmq_client,
@@ -462,10 +462,12 @@ async def _process_started_tasks(
                     service_additional_metadata={},
                 )
                 for t in tasks
-            )
+            ),
+            log=_logger,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )
         # instrumentation
-        await asyncio.gather(
+        await limited_gather(
             *(
                 publish_service_started_metrics(
                     self.rabbitmq_client,
@@ -476,24 +478,22 @@ async def _process_started_tasks(
                     task=t,
                 )
                 for t in tasks
-            )
+            ),
+            log=_logger,
+            limit=_PUBLICATION_CONCURRENCY_LIMIT,
         )

         # update DB
         comp_tasks_repo = CompTasksRepository(self.db_engine)
-        await asyncio.gather(
-            *(
-                comp_tasks_repo.update_project_tasks_state(
-                    t.project_id,
-                    run_id,
-                    [t.node_id],
-                    t.state,
-                    optional_started=utc_now,
-                    optional_progress=t.progress,
-                )
-                for t in tasks
+        for task in tasks:
+            await comp_tasks_repo.update_project_tasks_state(
+                project_id,
+                run_id,
+                [task.node_id],
+                task.state,
+                optional_started=utc_now,
+                optional_progress=task.progress,
             )
-        )
         await CompRunsRepository.instance(self.db_engine).mark_as_started(
             user_id=user_id,
             project_id=project_id,
@@ -504,18 +504,14 @@
     async def _process_waiting_tasks(
         self, tasks: list[TaskStateTracker], run_id: PositiveInt
     ) -> None:
-        comp_tasks_repo = CompTasksRepository(self.db_engine)
-        await asyncio.gather(
-            *(
-                comp_tasks_repo.update_project_tasks_state(
-                    t.current.project_id,
-                    run_id,
-                    [t.current.node_id],
-                    t.current.state,
-                )
-                for t in tasks
+        comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
+        for task in tasks:
+            await comp_tasks_repo.update_project_tasks_state(
+                task.current.project_id,
+                run_id,
+                [task.current.node_id],
+                task.current.state,
             )
-        )

     async def _update_states_from_comp_backend(
         self,
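
A recurring change in this file: the database bookkeeping (heartbeats, task states) moves from a single asyncio.gather that fires every repository call at once to a plain sequential loop, while only the RabbitMQ publications keep a (now bounded) gather. Schematically, with a hypothetical update_one coroutine standing in for the repository methods:

import asyncio
from collections.abc import Awaitable, Callable


async def update_all_at_once(
    items: list[str], update_one: Callable[[str], Awaitable[None]]
) -> None:
    # previous pattern: every update scheduled concurrently
    await asyncio.gather(*(update_one(item) for item in items))


async def update_one_by_one(
    items: list[str], update_one: Callable[[str], Awaitable[None]]
) -> None:
    # new pattern: one update awaited at a time
    for item in items:
        await update_one(item)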

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 54 additions & 83 deletions
@@ -48,15 +48,12 @@
     publish_service_stopped_metrics,
 )
 from ..clusters_keeper import get_or_create_on_demand_cluster
-from ..dask_client import DaskClient, PublishedComputationTask
+from ..dask_client import DaskClient
 from ..dask_clients_pool import DaskClientsPool
 from ..db.repositories.comp_runs import (
     CompRunsRepository,
 )
 from ..db.repositories.comp_tasks import CompTasksRepository
-from ._constants import (
-    MAX_CONCURRENT_PIPELINE_SCHEDULING,
-)
 from ._models import TaskStateTracker
 from ._scheduler_base import BaseCompScheduler
 from ._utils import (
@@ -68,6 +65,7 @@
 _DASK_CLIENT_RUN_REF: Final[str] = "{user_id}:{project_id}:{run_id}"
 _TASK_RETRIEVAL_ERROR_TYPE: Final[str] = "task-result-retrieval-timeout"
 _TASK_RETRIEVAL_ERROR_CONTEXT_TIME_KEY: Final[str] = "check_time"
+_PUBLICATION_CONCURRENCY_LIMIT: Final[int] = 10


 @asynccontextmanager
@@ -149,37 +147,31 @@ async def _start_tasks(
             RunningState.PENDING,
         )
         # each task is started independently
-        results: list[list[PublishedComputationTask]] = await limited_gather(
-            *(
-                client.send_computation_tasks(
-                    user_id=user_id,
-                    project_id=project_id,
-                    tasks={node_id: task.image},
-                    hardware_info=task.hardware_info,
-                    callback=wake_up_callback,
-                    metadata=comp_run.metadata,
-                    resource_tracking_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
-                        user_id, project_id, node_id, comp_run.iteration
-                    ),
-                )
-                for node_id, task in scheduled_tasks.items()
-            ),
-            log=_logger,
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-        )

-        # update the database so we do have the correct job_ids there
-        await limited_gather(
-            *(
-                comp_tasks_repo.update_project_task_job_id(
-                    project_id, task.node_id, comp_run.run_id, task.job_id
-                )
-                for task_sents in results
-                for task in task_sents
-            ),
-            log=_logger,
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
-        )
+        for node_id, task in scheduled_tasks.items():
+            published_tasks = await client.send_computation_tasks(
+                user_id=user_id,
+                project_id=project_id,
+                tasks={node_id: task.image},
+                hardware_info=task.hardware_info,
+                callback=wake_up_callback,
+                metadata=comp_run.metadata,
+                resource_tracking_run_id=ServiceRunID.get_resource_tracking_run_id_for_computational(
+                    user_id, project_id, node_id, comp_run.iteration
+                ),
+            )
+
+            # update the database so we do have the correct job_ids there
+            await limited_gather(
+                *(
+                    comp_tasks_repo.update_project_task_job_id(
+                        project_id, task.node_id, comp_run.run_id, task.job_id
+                    )
+                    for task in published_tasks
+                ),
+                log=_logger,
+                limit=1,
+            )

     async def _get_tasks_status(
         self,
208200
tasks: list[CompTaskAtDB],
209201
comp_run: CompRunsAtDB,
210202
) -> None:
211-
task_progresses = []
203+
task_progress_events = []
212204
try:
213205
async with _cluster_dask_client(
214206
user_id,
@@ -218,42 +210,33 @@ async def _process_executing_tasks(
218210
run_id=comp_run.run_id,
219211
run_metadata=comp_run.metadata,
220212
) as client:
221-
task_progresses = [
213+
task_progress_events = [
222214
t
223215
for t in await client.get_tasks_progress(
224216
[f"{t.job_id}" for t in tasks],
225217
)
226218
if t is not None
227219
]
228-
await limited_gather(
229-
*(
230-
CompTasksRepository(self.db_engine).update_project_task_progress(
231-
t.task_owner.project_id,
232-
t.task_owner.node_id,
233-
comp_run.run_id,
234-
t.progress,
235-
)
236-
for t in task_progresses
237-
),
238-
log=_logger,
239-
limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
240-
)
220+
for progress_event in task_progress_events:
221+
await CompTasksRepository(self.db_engine).update_project_task_progress(
222+
progress_event.task_owner.project_id,
223+
progress_event.task_owner.node_id,
224+
comp_run.run_id,
225+
progress_event.progress,
226+
)
241227

242228
except ComputationalBackendOnDemandNotReadyError:
243229
_logger.info("The on demand computational backend is not ready yet...")
244230

245231
comp_tasks_repo = CompTasksRepository(self.db_engine)
232+
for task in task_progress_events:
233+
await comp_tasks_repo.update_project_task_progress(
234+
task.task_owner.project_id,
235+
task.task_owner.node_id,
236+
comp_run.run_id,
237+
task.progress,
238+
)
246239
await limited_gather(
247-
*(
248-
comp_tasks_repo.update_project_task_progress(
249-
t.task_owner.project_id,
250-
t.task_owner.node_id,
251-
comp_run.run_id,
252-
t.progress,
253-
)
254-
for t in task_progresses
255-
if t
256-
),
257240
*(
258241
publish_service_progress(
259242
self.rabbitmq_client,
@@ -262,11 +245,10 @@ async def _process_executing_tasks(
262245
node_id=t.task_owner.node_id,
263246
progress=t.progress,
264247
)
265-
for t in task_progresses
266-
if t
248+
for t in task_progress_events
267249
),
268250
log=_logger,
269-
limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
251+
limit=_PUBLICATION_CONCURRENCY_LIMIT,
270252
)
271253

272254
async def _release_resources(self, comp_run: CompRunsAtDB) -> None:
@@ -300,25 +282,14 @@ async def _stop_tasks(
300282
run_id=comp_run.run_id,
301283
run_metadata=comp_run.metadata,
302284
) as client:
303-
await limited_gather(
304-
*(
305-
client.abort_computation_task(t.job_id)
306-
for t in tasks
307-
if t.job_id
308-
),
309-
log=_logger,
310-
limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
311-
)
312-
# tasks that have no-worker must be unpublished as these are blocking forever
313-
await limited_gather(
314-
*(
315-
client.release_task_result(t.job_id)
316-
for t in tasks
317-
if t.state is RunningState.WAITING_FOR_RESOURCES and t.job_id
318-
),
319-
log=_logger,
320-
limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
321-
)
285+
for t in tasks:
286+
if not t.job_id:
287+
_logger.warning("%s has no job_id, cannot be stopped", t)
288+
continue
289+
await client.abort_computation_task(t.job_id)
290+
# tasks that have no-worker must be unpublished as these are blocking forever
291+
if t.state is RunningState.WAITING_FOR_RESOURCES:
292+
await client.release_task_result(t.job_id)
322293

323294
async def _process_completed_tasks(
324295
self,
@@ -342,7 +313,7 @@ async def _process_completed_tasks(
             ),
             reraise=False,
             log=_logger,
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=1,  # to avoid overloading the dask scheduler
         )
         async for future in limited_as_completed(
             (
@@ -354,7 +325,7 @@ async def _process_completed_tasks(
                 )
                 for task, result in zip(tasks, tasks_results, strict=True)
             ),
-            limit=MAX_CONCURRENT_PIPELINE_SCHEDULING,
+            limit=10,  # this is not accessing the dask-scheduler (only db)
         ):
             with log_catch(_logger, reraise=False):
                 task_can_be_cleaned, job_id = await future
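
The two hard-coded limits above differ by what they talk to: retrieving results goes through the dask scheduler and is kept at limit=1 to avoid overloading it, while the follow-up bookkeeping only touches the database and runs through limited_as_completed with limit=10, which yields futures as they finish. A rough sketch of such a bounded as-completed loop built on asyncio primitives, not servicelib's actual implementation:

import asyncio
from collections.abc import AsyncIterator, Coroutine
from typing import Any


async def bounded_as_completed(
    coros: list[Coroutine[Any, Any, Any]], *, limit: int
) -> AsyncIterator[asyncio.Task]:
    # at most `limit` coroutine bodies run concurrently; finished tasks are
    # yielded in completion order
    semaphore = asyncio.Semaphore(limit)

    async def _run(coro: Coroutine[Any, Any, Any]) -> Any:
        async with semaphore:
            return await coro

    pending = {asyncio.create_task(_run(c)) for c in coros}
    while pending:
        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            yield task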

services/director-v2/src/simcore_service_director_v2/modules/dask_client.py

Lines changed: 0 additions & 10 deletions
@@ -8,7 +8,6 @@
 """

-import asyncio
 import logging
 from collections.abc import Callable, Iterable
 from dataclasses import dataclass
@@ -241,9 +240,6 @@ def _comp_sidecar_fct(
             )
             # NOTE: the callback is running in a secondary thread, and takes a future as arg
             task_future.add_done_callback(lambda _: callback())
-            await distributed.Variable(job_id, client=self.backend.client).set(
-                task_future
-            )

             await dask_utils.wrap_client_async_routine(
                 self.backend.client.publish_dataset(task_future, name=job_id)
@@ -560,12 +556,6 @@ async def get_task_result(self, job_id: str) -> TaskOutputData:
     async def release_task_result(self, job_id: str) -> None:
         _logger.debug("releasing results for %s", f"{job_id=}")
         try:
-            # NOTE: The distributed Variable holds the future of the tasks in the dask-scheduler
-            # Alas, deleting the variable is done asynchronously and there is no way to ensure
-            # the variable was effectively deleted.
-            # This is annoying as one can re-create the variable without error.
-            var = distributed.Variable(job_id, client=self.backend.client)
-            await asyncio.get_event_loop().run_in_executor(None, var.delete)
             # first check if the key exists
             await dask_utils.wrap_client_async_routine(
                 self.backend.client.get_dataset(name=job_id)
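
The deleted lines kept a second reference to each task future in a distributed.Variable and later tried to delete it, which, as the removed NOTE explains, cannot be confirmed. The future is also published on the scheduler as a named dataset (see the publish_dataset and get_dataset calls kept above), which keeps a reference on the dask scheduler. A hedged sketch of that remaining keep-alive mechanism, assuming an asynchronous dask distributed.Client; the helper names are illustrative, not this module's API:

import distributed


async def keep_result_alive(
    client: distributed.Client, job_id: str, task_future: distributed.Future
) -> None:
    # publishing the future under a well-known name keeps it referenced on the
    # dask scheduler until it is explicitly unpublished
    await client.publish_dataset(task_future, name=job_id)


async def release_result(client: distributed.Client, job_id: str) -> None:
    # raises KeyError if the dataset is unknown, then drops the reference
    await client.get_dataset(name=job_id)
    await client.unpublish_dataset(name=job_id)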
