4040from dask_task_models_library .container_tasks .utils import generate_dask_job_id
4141from dask_task_models_library .models import (
4242 TASK_LIFE_CYCLE_EVENT ,
43+ TASK_RUNNING_PROGRESS_EVENT ,
4344 DaskJobID ,
4445 DaskResources ,
4546 TaskLifeCycleState ,
@@ -407,30 +408,25 @@ async def send_computation_tasks(
407408
408409 async def get_tasks_progress (
409410 self , job_ids : list [str ]
410- ) -> tuple [TaskProgressEvent | None , ... ]:
411+ ) -> list [TaskProgressEvent | None ]:
411412 dask_utils .check_scheduler_is_still_the_same (
412413 self .backend .scheduler_id , self .backend .client
413414 )
414415 dask_utils .check_communication_with_scheduler_is_open (self .backend .client )
415416 dask_utils .check_scheduler_status (self .backend .client )
416417
417- dask_events = await self .backend .client .get_events (
418- TaskProgressEvent .topic_name ()
419- )
420-
421- if not dask_events :
422- return tuple ([None ] * len (job_ids ))
423- last_task_progress = []
424- for job_id in job_ids :
425- progress_event = None
426- for dask_event in reversed (dask_events ):
427- parsed_event = TaskProgressEvent .model_validate_json (dask_event [1 ])
428- if parsed_event .job_id == job_id :
429- progress_event = parsed_event
430- break
431- last_task_progress .append (progress_event )
418+ async def _get_task_progress (job_id : str ) -> TaskProgressEvent | None :
419+ dask_events : tuple [tuple [UnixTimestamp , str ], ...] = (
420+ await self .backend .client .get_events (
421+ TASK_RUNNING_PROGRESS_EVENT .format (key = job_id )
422+ )
423+ )
424+ if not dask_events :
425+ return None
426+ # we are interested in the last event
427+ return TaskProgressEvent .model_validate_json (dask_events [- 1 ][1 ])
432428
433- return tuple ( last_task_progress )
429+ return await asyncio . gather ( * ( _get_task_progress ( job_id ) for job_id in job_ids ) )
434430
435431 async def get_tasks_status (self , job_ids : Iterable [str ]) -> list [RunningState ]:
436432 dask_utils .check_scheduler_is_still_the_same (
@@ -439,7 +435,7 @@ async def get_tasks_status(self, job_ids: Iterable[str]) -> list[RunningState]:
439435 dask_utils .check_communication_with_scheduler_is_open (self .backend .client )
440436 dask_utils .check_scheduler_status (self .backend .client )
441437
442- async def _get_job_id_status (job_id : str ) -> RunningState :
438+ async def _get_task_state (job_id : str ) -> RunningState :
443439 dask_events : tuple [tuple [UnixTimestamp , str ], ...] = (
444440 await self .backend .client .get_events (
445441 TASK_LIFE_CYCLE_EVENT .format (key = job_id )
@@ -480,7 +476,7 @@ async def _get_job_id_status(job_id: str) -> RunningState:
480476
481477 return parsed_event .state
482478
483- return await asyncio .gather (* (_get_job_id_status (job_id ) for job_id in job_ids ))
479+ return await asyncio .gather (* (_get_task_state (job_id ) for job_id in job_ids ))
484480
485481 async def abort_computation_task (self , job_id : str ) -> None :
486482 # Dask future may be cancelled, but only a future that was not already taken by
0 commit comments