Skip to content

Commit cc71ed1

Browse files
committed
poll task progress of running computational tasks
1 parent 1437d82 commit cc71ed1

File tree

3 files changed

+129
-19
lines changed

3 files changed

+129
-19
lines changed

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ class SortedTasks:
116116

117117

118118
async def _triage_changed_tasks(
119-
changed_tasks: list[tuple[_Previous, _Current]],
119+
changed_tasks_or_executing: list[tuple[_Previous, _Current]],
120120
) -> SortedTasks:
121121
started_tasks = [
122122
current
123-
for previous, current in changed_tasks
123+
for previous, current in changed_tasks_or_executing
124124
if current.state in RUNNING_STATES
125125
or (
126126
previous.state in WAITING_FOR_START_STATES
@@ -130,17 +130,21 @@ async def _triage_changed_tasks(
130130

131131
# NOTE: some tasks can be both started and completed since we might have the time they were running
132132
completed_tasks = [
133-
current for _, current in changed_tasks if current.state in COMPLETED_STATES
133+
current
134+
for _, current in changed_tasks_or_executing
135+
if current.state in COMPLETED_STATES
134136
]
135137

136138
waiting_for_resources_tasks = [
137139
current
138-
for previous, current in changed_tasks
140+
for previous, current in changed_tasks_or_executing
139141
if current.state in WAITING_FOR_START_STATES
140142
]
141143

142144
lost_or_momentarily_lost_tasks = [
143-
current for _, current in changed_tasks if current.state is RunningState.UNKNOWN
145+
current
146+
for _, current in changed_tasks_or_executing
147+
if current.state is RunningState.UNKNOWN
144148
]
145149
if lost_or_momentarily_lost_tasks:
146150
_logger.warning(
@@ -321,21 +325,30 @@ async def _get_changed_tasks_from_backend(
321325
user_id: UserID,
322326
processing_tasks: list[CompTaskAtDB],
323327
comp_run: CompRunsAtDB,
324-
) -> list[tuple[_Previous, _Current]]:
328+
) -> tuple[list[tuple[_Previous, _Current]], list[CompTaskAtDB]]:
325329
tasks_backend_status = await self._get_tasks_status(
326330
user_id, processing_tasks, comp_run
327331
)
328332

329-
return [
330-
(
331-
task,
332-
task.model_copy(update={"state": backend_state}),
333-
)
334-
for task, backend_state in zip(
335-
processing_tasks, tasks_backend_status, strict=True
336-
)
337-
if task.state is not backend_state
338-
]
333+
return (
334+
[
335+
(
336+
task,
337+
task.model_copy(update={"state": backend_state}),
338+
)
339+
for task, backend_state in zip(
340+
processing_tasks, tasks_backend_status, strict=True
341+
)
342+
if task.state is not backend_state
343+
],
344+
[
345+
task
346+
for task, backend_state in zip(
347+
processing_tasks, tasks_backend_status, strict=True
348+
)
349+
if task.state is backend_state is RunningState.STARTED
350+
],
351+
)
339352

340353
async def _process_started_tasks(
341354
self,
@@ -476,7 +489,10 @@ async def _update_states_from_comp_backend(
476489
return
477490

478491
# get the tasks which state actually changed since last check
479-
tasks_with_changed_states = await self._get_changed_tasks_from_backend(
492+
(
493+
tasks_with_changed_states,
494+
executing_tasks,
495+
) = await self._get_changed_tasks_from_backend(
480496
user_id, tasks_inprocess, comp_run
481497
)
482498
# NOTE: typical states a task goes through
@@ -511,6 +527,9 @@ async def _update_states_from_comp_backend(
511527
if sorted_tasks.waiting:
512528
await self._process_waiting_tasks(sorted_tasks.waiting)
513529

530+
if executing_tasks:
531+
await self._process_executing_tasks(user_id, executing_tasks, comp_run)
532+
514533
@abstractmethod
515534
async def _start_tasks(
516535
self,
@@ -545,6 +564,15 @@ async def _process_completed_tasks(
545564
) -> None:
546565
"""process tasks from the 3rd party backend"""
547566

567+
    @abstractmethod
    async def _process_executing_tasks(
        self,
        user_id: UserID,
        tasks: list[CompTaskAtDB],
        comp_run: CompRunsAtDB,
    ) -> None:
        """process executing tasks from the 3rd party backend

        Called by the scheduler with the tasks whose backend state is
        currently STARTED (presumably so implementations can poll and
        propagate task progress — confirm against concrete subclasses).

        :param user_id: owner of the computational run
        :param tasks: tasks reported as executing by the backend
        :param comp_run: the computational run the tasks belong to
        """
575+
548576
async def apply(
549577
self,
550578
*,

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from ...utils.dask_client_utils import TaskHandlers, UnixTimestamp
3939
from ...utils.rabbitmq import (
40+
publish_service_progress,
4041
publish_service_resource_tracking_stopped,
4142
publish_service_stopped_metrics,
4243
)
@@ -150,6 +151,62 @@ async def _get_tasks_status(
150151
_logger.info("The on demand computational backend is not ready yet...")
151152
return [RunningState.WAITING_FOR_CLUSTER] * len(tasks)
152153

154+
async def _process_executing_tasks(
155+
self,
156+
user_id: UserID,
157+
tasks: list[CompTaskAtDB],
158+
comp_run: CompRunsAtDB,
159+
) -> None:
160+
task_progresses = []
161+
try:
162+
async with _cluster_dask_client(
163+
user_id,
164+
self,
165+
use_on_demand_clusters=comp_run.use_on_demand_clusters,
166+
run_metadata=comp_run.metadata,
167+
) as client:
168+
task_progresses = await client.get_tasks_progress(
169+
[f"{t.job_id}" for t in tasks],
170+
)
171+
for task_progress_event in task_progresses:
172+
if task_progress_event:
173+
await CompTasksRepository(
174+
self.db_engine
175+
).update_project_task_progress(
176+
task_progress_event.task_owner.project_id,
177+
task_progress_event.task_owner.node_id,
178+
task_progress_event.progress,
179+
)
180+
181+
except ComputationalBackendOnDemandNotReadyError:
182+
_logger.info("The on demand computational backend is not ready yet...")
183+
184+
comp_tasks_repo = CompTasksRepository(self.db_engine)
185+
await asyncio.gather(
186+
*(
187+
comp_tasks_repo.update_project_task_progress(
188+
t.task_owner.project_id,
189+
t.task_owner.node_id,
190+
t.progress,
191+
)
192+
for t in task_progresses
193+
if t
194+
)
195+
)
196+
await asyncio.gather(
197+
*(
198+
publish_service_progress(
199+
self.rabbitmq_client,
200+
user_id=t.task_owner.user_id,
201+
project_id=t.task_owner.project_id,
202+
node_id=t.task_owner.node_id,
203+
progress=t.progress,
204+
)
205+
for t in task_progresses
206+
if t
207+
)
208+
)
209+
153210
async def _stop_tasks(
154211
self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB
155212
) -> None:

services/director-v2/src/simcore_service_director_v2/modules/dask_client.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,33 @@ async def send_computation_tasks(
405405

406406
return list_of_node_id_to_job_id
407407

408+
async def get_tasks_progress(
409+
self, job_ids: list[str]
410+
) -> tuple[TaskProgressEvent | None, ...]:
411+
dask_utils.check_scheduler_is_still_the_same(
412+
self.backend.scheduler_id, self.backend.client
413+
)
414+
dask_utils.check_communication_with_scheduler_is_open(self.backend.client)
415+
dask_utils.check_scheduler_status(self.backend.client)
416+
417+
dask_events = await self.backend.client.get_events(
418+
TaskProgressEvent.topic_name()
419+
)
420+
421+
if not dask_events:
422+
return tuple([None] * len(job_ids))
423+
last_task_progress = []
424+
for job_id in job_ids:
425+
progress_event = None
426+
for dask_event in reversed(dask_events):
427+
parsed_event = TaskProgressEvent.model_validate_json(dask_event[1])
428+
if parsed_event.job_id == job_id:
429+
progress_event = parsed_event
430+
break
431+
last_task_progress.append(progress_event)
432+
433+
return tuple(last_task_progress)
434+
408435
async def get_tasks_status(self, job_ids: Iterable[str]) -> list[RunningState]:
409436
dask_utils.check_scheduler_is_still_the_same(
410437
self.backend.scheduler_id, self.backend.client
@@ -413,8 +440,6 @@ async def get_tasks_status(self, job_ids: Iterable[str]) -> list[RunningState]:
413440
dask_utils.check_scheduler_status(self.backend.client)
414441

415442
async def _get_job_id_status(job_id: str) -> RunningState:
416-
# TODO: maybe we should define an event just for that, instead of multiple calls
417-
# but the max length by default is 1000. We should test it
418443
dask_events: tuple[tuple[UnixTimestamp, str], ...] = (
419444
await self.backend.client.get_events(
420445
TASK_LIFE_CYCLE_EVENT.format(key=job_id)

0 commit comments

Comments
 (0)