5555from ._constants import (
5656 MAX_CONCURRENT_PIPELINE_SCHEDULING ,
5757)
58+ from ._models import TaskStateTracker
5859from ._scheduler_base import BaseCompScheduler
5960from ._utils import (
6061 WAITING_FOR_START_STATES ,
@@ -293,7 +294,7 @@ async def _stop_tasks(
293294 async def _process_completed_tasks (
294295 self ,
295296 user_id : UserID ,
296- tasks : list [CompTaskAtDB ],
297+ tasks : list [TaskStateTracker ],
297298 iteration : Iteration ,
298299 comp_run : CompRunsAtDB ,
299300 ) -> None :
@@ -306,13 +307,20 @@ async def _process_completed_tasks(
306307 run_metadata = comp_run .metadata ,
307308 ) as client :
308309 tasks_results = await asyncio .gather (
309- * [client .get_task_result (t .job_id or "undefined" ) for t in tasks ],
310+ * [
311+ client .get_task_result (t .current .job_id or "undefined" )
312+ for t in tasks
313+ ],
310314 return_exceptions = True ,
311315 )
312316 async for future in limited_as_completed (
313317 (
314318 self ._process_task_result (
315- task , result , comp_run .metadata , iteration , comp_run .run_id
319+ task ,
320+ result ,
321+ comp_run .metadata ,
322+ iteration ,
323+ comp_run .run_id ,
316324 )
317325 for task , result in zip (tasks , tasks_results , strict = True )
318326 ),
@@ -467,31 +475,30 @@ async def _handle_task_error(
467475
468476 async def _process_task_result (
469477 self ,
470- task : CompTaskAtDB ,
478+ task : TaskStateTracker ,
471479 result : BaseException | TaskOutputData ,
472480 run_metadata : RunMetadataDict ,
473481 iteration : Iteration ,
474482 run_id : PositiveInt ,
475483 ) -> tuple [bool , str ]:
476484 """Returns True and the job ID if the task was successfully processed and can be released from the Dask cluster."""
477485 _logger .debug ("received %s result: %s" , f"{ task = } " , f"{ result = } " )
478-
479- assert task .job_id # nosec
486+ assert task .current .job_id # nosec
480487 (
481488 _service_key ,
482489 _service_version ,
483490 user_id ,
484491 project_id ,
485492 node_id ,
486- ) = parse_dask_job_id (task .job_id )
493+ ) = parse_dask_job_id (task .current . job_id )
487494
488- assert task .project_id == project_id # nosec
489- assert task .node_id == node_id # nosec
495+ assert task .current . project_id == project_id # nosec
496+ assert task .current . node_id == node_id # nosec
490497 log_error_context = {
491498 "user_id" : user_id ,
492499 "project_id" : project_id ,
493500 "node_id" : node_id ,
494- "job_id" : task .job_id ,
501+ "job_id" : task .current . job_id ,
495502 }
496503
497504 if isinstance (result , TaskOutputData ):
@@ -500,7 +507,9 @@ async def _process_task_result(
500507 simcore_platform_status ,
501508 task_errors ,
502509 task_completed ,
503- ) = await self ._handle_successful_run (task , result , log_error_context )
510+ ) = await self ._handle_successful_run (
511+ task .current , result , log_error_context
512+ )
504513
505514 elif isinstance (result , ComputationalBackendTaskResultsNotReadyError ):
506515 (
@@ -509,7 +518,7 @@ async def _process_task_result(
509518 task_errors ,
510519 task_completed ,
511520 ) = await self ._handle_computational_retrieval_error (
512- task , user_id , result , log_error_context
521+ task . current , user_id , result , log_error_context
513522 )
514523 elif isinstance (result , ComputationalBackendNotConnectedError ):
515524 (
@@ -518,15 +527,15 @@ async def _process_task_result(
518527 task_errors ,
519528 task_completed ,
520529 ) = await self ._handle_computational_backend_not_connected_error (
521- task , result , log_error_context
530+ task . current , result , log_error_context
522531 )
523532 else :
524533 (
525534 task_final_state ,
526535 simcore_platform_status ,
527536 task_errors ,
528537 task_completed ,
529- ) = await self ._handle_task_error (task , result , log_error_context )
538+ ) = await self ._handle_task_error (task . current , result , log_error_context )
530539
531540 # we need to remove any invalid files in the storage
532541 await clean_task_output_and_log_files_if_invalid (
@@ -549,21 +558,21 @@ async def _process_task_result(
549558 simcore_user_agent = run_metadata .get (
550559 "simcore_user_agent" , UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
551560 ),
552- task = task ,
561+ task = task . current ,
553562 task_final_state = task_final_state ,
554563 )
555564
556565 await CompTasksRepository (self .db_engine ).update_project_tasks_state (
557- task .project_id ,
566+ task .current . project_id ,
558567 run_id ,
559- [task .node_id ],
560- task_final_state if task_completed else RunningState . STARTED ,
568+ [task .current . node_id ],
569+ task_final_state if task_completed else task . previous . state ,
561570 errors = task_errors ,
562571 optional_progress = 1 if task_completed else None ,
563572 optional_stopped = arrow .utcnow ().datetime if task_completed else None ,
564573 )
565574
566- return task_completed , task .job_id
575+ return task_completed , task .current . job_id
567576
568577 async def _task_progress_change_handler (
569578 self , event : tuple [UnixTimestamp , Any ]
0 commit comments