|
63 | 63 | from ..db.repositories.comp_pipelines import CompPipelinesRepository |
64 | 64 | from ..db.repositories.comp_runs import CompRunsRepository |
65 | 65 | from ..db.repositories.comp_tasks import CompTasksRepository |
| 66 | +from ._models import TaskStateTracker |
66 | 67 | from ._publisher import request_pipeline_scheduling |
67 | 68 | from ._utils import ( |
68 | 69 | COMPLETED_STATES, |
|
76 | 77 | _logger = logging.getLogger(__name__) |
77 | 78 |
|
78 | 79 |
|
79 | | -_Previous = CompTaskAtDB |
80 | | -_Current = CompTaskAtDB |
81 | | - |
82 | 80 | _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta( |
83 | 81 | seconds=30 |
84 | 82 | ) |
@@ -117,47 +115,49 @@ async def _async_cb() -> None: |
117 | 115 | @dataclass(frozen=True, slots=True) |
118 | 116 | class SortedTasks: |
119 | 117 | started: list[CompTaskAtDB] |
120 | | - completed: list[CompTaskAtDB] |
121 | | - waiting: list[CompTaskAtDB] |
122 | | - potentially_lost: list[CompTaskAtDB] |
| 118 | + completed: list[TaskStateTracker] |
| 119 | + waiting: list[TaskStateTracker] |
| 120 | + potentially_lost: list[TaskStateTracker] |
123 | 121 |
|
124 | 122 |
|
125 | 123 | async def _triage_changed_tasks( |
126 | | - changed_tasks: list[tuple[_Previous, _Current]], |
| 124 | + changed_tasks: list[TaskStateTracker], |
127 | 125 | ) -> SortedTasks: |
128 | 126 | started_tasks = [ |
129 | | - current |
130 | | - for previous, current in changed_tasks |
131 | | - if current.state in RUNNING_STATES |
| 127 | + tracker.current |
| 128 | + for tracker in changed_tasks |
| 129 | + if tracker.current.state in RUNNING_STATES |
132 | 130 | or ( |
133 | | - previous.state in WAITING_FOR_START_STATES |
134 | | - and current.state in COMPLETED_STATES |
| 131 | + tracker.previous.state in WAITING_FOR_START_STATES |
| 132 | + and tracker.current.state in COMPLETED_STATES |
135 | 133 | ) |
136 | 134 | ] |
137 | 135 |
|
138 | 136 | completed_tasks = [ |
139 | | - current for _, current in changed_tasks if current.state in COMPLETED_STATES |
| 137 | + tracker |
| 138 | + for tracker in changed_tasks |
| 139 | + if tracker.current.state in COMPLETED_STATES |
140 | 140 | ] |
141 | 141 |
|
142 | 142 | waiting_for_resources_tasks = [ |
143 | | - current |
144 | | - for previous, current in changed_tasks |
145 | | - if current.state in WAITING_FOR_START_STATES |
| 143 | + tracker |
| 144 | + for tracker in changed_tasks |
| 145 | + if tracker.current.state in WAITING_FOR_START_STATES |
146 | 146 | ] |
147 | 147 |
|
148 | 148 | lost_tasks = [ |
149 | | - current |
150 | | - for previous, current in changed_tasks |
151 | | - if (current.state is RunningState.UNKNOWN) |
| 149 | + tracker |
| 150 | + for tracker in changed_tasks |
| 151 | + if (tracker.current.state is RunningState.UNKNOWN) |
152 | 152 | and ( |
153 | | - (arrow.utcnow().datetime - previous.modified) |
| 153 | + (arrow.utcnow().datetime - tracker.previous.modified) |
154 | 154 | > _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS |
155 | 155 | ) |
156 | 156 | ] |
157 | 157 | if lost_tasks: |
158 | 158 | _logger.warning( |
159 | 159 | "%s are currently in unknown state. TIP: If they are running in an external cluster and it is not yet ready, that might explain it. But inform @sanderegg nevertheless!", |
160 | | - [t.node_id for t in lost_tasks], |
| 160 | + [t.current.node_id for t in lost_tasks], |
161 | 161 | ) |
162 | 162 |
|
163 | 163 | return SortedTasks( |
@@ -321,6 +321,7 @@ async def _send_running_tasks_heartbeat( |
321 | 321 | def _need_heartbeat(task: CompTaskAtDB) -> bool: |
322 | 322 | if task.state not in RUNNING_STATES: |
323 | 323 | return False |
| 324 | + |
324 | 325 | if task.last_heartbeat is None: |
325 | 326 | assert task.start # nosec |
326 | 327 | return bool( |
@@ -362,14 +363,14 @@ async def _get_changed_tasks_from_backend( |
362 | 363 | user_id: UserID, |
363 | 364 | processing_tasks: list[CompTaskAtDB], |
364 | 365 | comp_run: CompRunsAtDB, |
365 | | - ) -> tuple[list[tuple[_Previous, _Current]], list[CompTaskAtDB]]: |
| 366 | + ) -> tuple[list[TaskStateTracker], list[CompTaskAtDB]]: |
366 | 367 | tasks_backend_status = await self._get_tasks_status( |
367 | 368 | user_id, processing_tasks, comp_run |
368 | 369 | ) |
369 | 370 |
|
370 | 371 | return ( |
371 | 372 | [ |
372 | | - ( |
| 373 | + TaskStateTracker( |
373 | 374 | task, |
374 | 375 | task.model_copy(update={"state": backend_state}), |
375 | 376 | ) |
@@ -502,16 +503,16 @@ async def _process_started_tasks( |
502 | 503 | ) |
503 | 504 |
|
504 | 505 | async def _process_waiting_tasks( |
505 | | - self, tasks: list[CompTaskAtDB], run_id: PositiveInt |
| 506 | + self, tasks: list[TaskStateTracker], run_id: PositiveInt |
506 | 507 | ) -> None: |
507 | 508 | comp_tasks_repo = CompTasksRepository(self.db_engine) |
508 | 509 | await asyncio.gather( |
509 | 510 | *( |
510 | 511 | comp_tasks_repo.update_project_tasks_state( |
511 | | - t.project_id, |
| 512 | + t.current.project_id, |
512 | 513 | run_id, |
513 | | - [t.node_id], |
514 | | - t.state, |
| 514 | + [t.current.node_id], |
| 515 | + t.current.state, |
515 | 516 | ) |
516 | 517 | for t in tasks |
517 | 518 | ) |
@@ -602,7 +603,7 @@ async def _stop_tasks( |
602 | 603 | async def _process_completed_tasks( |
603 | 604 | self, |
604 | 605 | user_id: UserID, |
605 | | - tasks: list[CompTaskAtDB], |
| 606 | + tasks: list[TaskStateTracker], |
606 | 607 | iteration: Iteration, |
607 | 608 | comp_run: CompRunsAtDB, |
608 | 609 | ) -> None: |
|
0 commit comments