| 
63 | 63 | from ..db.repositories.comp_pipelines import CompPipelinesRepository  | 
64 | 64 | from ..db.repositories.comp_runs import CompRunsRepository  | 
65 | 65 | from ..db.repositories.comp_tasks import CompTasksRepository  | 
 | 66 | +from ._models import TaskStateTracker  | 
66 | 67 | from ._publisher import request_pipeline_scheduling  | 
67 | 68 | from ._utils import (  | 
68 | 69 |     COMPLETED_STATES,  | 
 | 
76 | 77 | _logger = logging.getLogger(__name__)  | 
77 | 78 | 
 
  | 
78 | 79 | 
 
  | 
79 |  | -_Previous = CompTaskAtDB  | 
80 |  | -_Current = CompTaskAtDB  | 
81 |  | - | 
82 | 80 | _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta(  | 
83 | 81 |     seconds=30  | 
84 | 82 | )  | 
@@ -117,47 +115,49 @@ async def _async_cb() -> None:  | 
117 | 115 | @dataclass(frozen=True, slots=True)  | 
118 | 116 | class SortedTasks:  | 
119 | 117 |     started: list[CompTaskAtDB]  | 
120 |  | -    completed: list[CompTaskAtDB]  | 
121 |  | -    waiting: list[CompTaskAtDB]  | 
122 |  | -    potentially_lost: list[CompTaskAtDB]  | 
 | 118 | +    completed: list[TaskStateTracker]  | 
 | 119 | +    waiting: list[TaskStateTracker]  | 
 | 120 | +    potentially_lost: list[TaskStateTracker]  | 
123 | 121 | 
 
  | 
124 | 122 | 
 
  | 
125 | 123 | async def _triage_changed_tasks(  | 
126 |  | -    changed_tasks: list[tuple[_Previous, _Current]],  | 
 | 124 | +    changed_tasks: list[TaskStateTracker],  | 
127 | 125 | ) -> SortedTasks:  | 
128 | 126 |     started_tasks = [  | 
129 |  | -        current  | 
130 |  | -        for previous, current in changed_tasks  | 
131 |  | -        if current.state in RUNNING_STATES  | 
 | 127 | +        tracker.current  | 
 | 128 | +        for tracker in changed_tasks  | 
 | 129 | +        if tracker.current.state in RUNNING_STATES  | 
132 | 130 |         or (  | 
133 |  | -            previous.state in WAITING_FOR_START_STATES  | 
134 |  | -            and current.state in COMPLETED_STATES  | 
 | 131 | +            tracker.previous.state in WAITING_FOR_START_STATES  | 
 | 132 | +            and tracker.current.state in COMPLETED_STATES  | 
135 | 133 |         )  | 
136 | 134 |     ]  | 
137 | 135 | 
 
  | 
138 | 136 |     completed_tasks = [  | 
139 |  | -        current for _, current in changed_tasks if current.state in COMPLETED_STATES  | 
 | 137 | +        tracker  | 
 | 138 | +        for tracker in changed_tasks  | 
 | 139 | +        if tracker.current.state in COMPLETED_STATES  | 
140 | 140 |     ]  | 
141 | 141 | 
 
  | 
142 | 142 |     waiting_for_resources_tasks = [  | 
143 |  | -        current  | 
144 |  | -        for previous, current in changed_tasks  | 
145 |  | -        if current.state in WAITING_FOR_START_STATES  | 
 | 143 | +        tracker  | 
 | 144 | +        for tracker in changed_tasks  | 
 | 145 | +        if tracker.current.state in WAITING_FOR_START_STATES  | 
146 | 146 |     ]  | 
147 | 147 | 
 
  | 
148 | 148 |     lost_tasks = [  | 
149 |  | -        current  | 
150 |  | -        for previous, current in changed_tasks  | 
151 |  | -        if (current.state is RunningState.UNKNOWN)  | 
 | 149 | +        tracker  | 
 | 150 | +        for tracker in changed_tasks  | 
 | 151 | +        if (tracker.current.state is RunningState.UNKNOWN)  | 
152 | 152 |         and (  | 
153 |  | -            (arrow.utcnow().datetime - previous.modified)  | 
 | 153 | +            (arrow.utcnow().datetime - tracker.previous.modified)  | 
154 | 154 |             > _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS  | 
155 | 155 |         )  | 
156 | 156 |     ]  | 
157 | 157 |     if lost_tasks:  | 
158 | 158 |         _logger.warning(  | 
159 | 159 |             "%s are currently in unknown state. TIP: If they are running in an external cluster and it is not yet ready, that might explain it. But inform @sanderegg nevertheless!",  | 
160 |  | -            [t.node_id for t in lost_tasks],  | 
 | 160 | +            [t.current.node_id for t in lost_tasks],  | 
161 | 161 |         )  | 
162 | 162 | 
 
  | 
163 | 163 |     return SortedTasks(  | 
@@ -321,6 +321,7 @@ async def _send_running_tasks_heartbeat(  | 
321 | 321 |         def _need_heartbeat(task: CompTaskAtDB) -> bool:  | 
322 | 322 |             if task.state not in RUNNING_STATES:  | 
323 | 323 |                 return False  | 
 | 324 | + | 
324 | 325 |             if task.last_heartbeat is None:  | 
325 | 326 |                 assert task.start  # nosec  | 
326 | 327 |                 return bool(  | 
@@ -362,14 +363,14 @@ async def _get_changed_tasks_from_backend(  | 
362 | 363 |         user_id: UserID,  | 
363 | 364 |         processing_tasks: list[CompTaskAtDB],  | 
364 | 365 |         comp_run: CompRunsAtDB,  | 
365 |  | -    ) -> tuple[list[tuple[_Previous, _Current]], list[CompTaskAtDB]]:  | 
 | 366 | +    ) -> tuple[list[TaskStateTracker], list[CompTaskAtDB]]:  | 
366 | 367 |         tasks_backend_status = await self._get_tasks_status(  | 
367 | 368 |             user_id, processing_tasks, comp_run  | 
368 | 369 |         )  | 
369 | 370 | 
 
  | 
370 | 371 |         return (  | 
371 | 372 |             [  | 
372 |  | -                (  | 
 | 373 | +                TaskStateTracker(  | 
373 | 374 |                     task,  | 
374 | 375 |                     task.model_copy(update={"state": backend_state}),  | 
375 | 376 |                 )  | 
@@ -502,16 +503,16 @@ async def _process_started_tasks(  | 
502 | 503 |         )  | 
503 | 504 | 
 
  | 
504 | 505 |     async def _process_waiting_tasks(  | 
505 |  | -        self, tasks: list[CompTaskAtDB], run_id: PositiveInt  | 
 | 506 | +        self, tasks: list[TaskStateTracker], run_id: PositiveInt  | 
506 | 507 |     ) -> None:  | 
507 | 508 |         comp_tasks_repo = CompTasksRepository(self.db_engine)  | 
508 | 509 |         await asyncio.gather(  | 
509 | 510 |             *(  | 
510 | 511 |                 comp_tasks_repo.update_project_tasks_state(  | 
511 |  | -                    t.project_id,  | 
 | 512 | +                    t.current.project_id,  | 
512 | 513 |                     run_id,  | 
513 |  | -                    [t.node_id],  | 
514 |  | -                    t.state,  | 
 | 514 | +                    [t.current.node_id],  | 
 | 515 | +                    t.current.state,  | 
515 | 516 |                 )  | 
516 | 517 |                 for t in tasks  | 
517 | 518 |             )  | 
@@ -602,7 +603,7 @@ async def _stop_tasks(  | 
602 | 603 |     async def _process_completed_tasks(  | 
603 | 604 |         self,  | 
604 | 605 |         user_id: UserID,  | 
605 |  | -        tasks: list[CompTaskAtDB],  | 
 | 606 | +        tasks: list[TaskStateTracker],  | 
606 | 607 |         iteration: Iteration,  | 
607 | 608 |         comp_run: CompRunsAtDB,  | 
608 | 609 |     ) -> None:  | 
 | 
0 commit comments