Commit 5a1948d

task now reverts to old state to prevent an issue in the heartbeat
1 parent b58912d commit 5a1948d
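
In short: the dask scheduler now keeps both the previous and the current DB snapshot of each task (the new TaskStateTracker dataclass), and when a result is not yet final it writes back the task's previous state instead of forcing RunningState.STARTED, which, per the commit message, avoids an issue with the heartbeat. Below is a minimal, self-contained sketch of that decision only; State, TaskRow, and state_to_persist are simplified stand-ins, not the real RunningState, CompTaskAtDB, or scheduler code.

# Sketch only: State, TaskRow, and state_to_persist are simplified stand-ins,
# not the real RunningState / CompTaskAtDB / scheduler implementation.
from dataclasses import dataclass
from enum import Enum, auto


class State(Enum):  # stand-in for RunningState
    PENDING = auto()
    STARTED = auto()
    SUCCESS = auto()


@dataclass(frozen=True, slots=True)
class TaskRow:  # stand-in for CompTaskAtDB
    node_id: str
    state: State


@dataclass(frozen=True, slots=True)
class TaskStateTracker:  # same shape as the dataclass added in this commit
    previous: TaskRow
    current: TaskRow


def state_to_persist(
    tracker: TaskStateTracker, *, task_completed: bool, final_state: State
) -> State:
    # Old behaviour: write STARTED whenever the task was not yet done.
    # New behaviour (this commit): fall back to the state the task had before.
    return final_state if task_completed else tracker.previous.state


if __name__ == "__main__":
    tracker = TaskStateTracker(
        previous=TaskRow("node-1", State.PENDING),
        current=TaskRow("node-1", State.PENDING),
    )
    # Result not final yet -> keep the previous state instead of forcing STARTED
    assert (
        state_to_persist(tracker, task_completed=False, final_state=State.SUCCESS)
        is State.PENDING
    )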

3 files changed (+51 additions, -39 deletions)
services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py

Lines changed: 10 additions & 2 deletions
@@ -1,19 +1,27 @@
+from dataclasses import dataclass
 from typing import Literal
 
 from models_library.projects import ProjectID
 from models_library.rabbitmq_messages import RabbitMessageBase
 from models_library.users import UserID
 
 from ...models.comp_runs import Iteration
+from ...models.comp_tasks import CompTaskAtDB
 
 
 class SchedulePipelineRabbitMessage(RabbitMessageBase):
-    channel_name: Literal[
+    channel_name: Literal["simcore.services.director-v2.scheduling"] = (
         "simcore.services.director-v2.scheduling"
-    ] = "simcore.services.director-v2.scheduling"
+    )
     user_id: UserID
     project_id: ProjectID
     iteration: Iteration
 
     def routing_key(self) -> str | None: # pylint: disable=no-self-use # abstract
         return None
+
+
+@dataclass(frozen=True, slots=True)
+class TaskStateTracker:
+    previous: CompTaskAtDB
+    current: CompTaskAtDB

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py

Lines changed: 13 additions & 18 deletions
@@ -63,6 +63,7 @@
 from ..db.repositories.comp_pipelines import CompPipelinesRepository
 from ..db.repositories.comp_runs import CompRunsRepository
 from ..db.repositories.comp_tasks import CompTasksRepository
+from ._models import TaskStateTracker
 from ._publisher import request_pipeline_scheduling
 from ._utils import (
     COMPLETED_STATES,
@@ -76,12 +77,6 @@
 _logger = logging.getLogger(__name__)
 
 
-@dataclass(frozen=True, slots=True)
-class TaskStateTracker:
-    previous: CompTaskAtDB
-    current: CompTaskAtDB
-
-
 _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta(
     seconds=30
 )
@@ -120,9 +115,9 @@ async def _async_cb() -> None:
 @dataclass(frozen=True, slots=True)
 class SortedTasks:
     started: list[CompTaskAtDB]
-    completed: list[CompTaskAtDB]
-    waiting: list[CompTaskAtDB]
-    potentially_lost: list[CompTaskAtDB]
+    completed: list[TaskStateTracker]
+    waiting: list[TaskStateTracker]
+    potentially_lost: list[TaskStateTracker]
 
 
 async def _triage_changed_tasks(
@@ -139,19 +134,19 @@ async def _triage_changed_tasks(
     ]
 
     completed_tasks = [
-        tracker.current
+        tracker
         for tracker in changed_tasks
         if tracker.current.state in COMPLETED_STATES
     ]
 
     waiting_for_resources_tasks = [
-        tracker.current
+        tracker
         for tracker in changed_tasks
         if tracker.current.state in WAITING_FOR_START_STATES
    ]
 
     lost_tasks = [
-        tracker.current
+        tracker
         for tracker in changed_tasks
         if (tracker.current.state is RunningState.UNKNOWN)
         and (
@@ -162,7 +157,7 @@
     if lost_tasks:
         _logger.warning(
             "%s are currently in unknown state. TIP: If they are running in an external cluster and it is not yet ready, that might explain it. But inform @sanderegg nevertheless!",
-            [t.node_id for t in lost_tasks],
+            [t.current.node_id for t in lost_tasks],
         )
 
     return SortedTasks(
@@ -508,16 +503,16 @@ async def _process_started_tasks(
         )
 
     async def _process_waiting_tasks(
-        self, tasks: list[CompTaskAtDB], run_id: PositiveInt
+        self, tasks: list[TaskStateTracker], run_id: PositiveInt
     ) -> None:
         comp_tasks_repo = CompTasksRepository(self.db_engine)
         await asyncio.gather(
             *(
                 comp_tasks_repo.update_project_tasks_state(
-                    t.project_id,
+                    t.current.project_id,
                     run_id,
-                    [t.node_id],
-                    t.state,
+                    [t.current.node_id],
+                    t.current.state,
                 )
                 for t in tasks
             )
@@ -608,7 +603,7 @@ async def _stop_tasks(
     async def _process_completed_tasks(
         self,
         user_id: UserID,
-        tasks: list[CompTaskAtDB],
+        tasks: list[TaskStateTracker],
         iteration: Iteration,
         comp_run: CompRunsAtDB,
     ) -> None:

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 28 additions & 19 deletions
@@ -55,6 +55,7 @@
 from ._constants import (
     MAX_CONCURRENT_PIPELINE_SCHEDULING,
 )
+from ._models import TaskStateTracker
 from ._scheduler_base import BaseCompScheduler
 from ._utils import (
     WAITING_FOR_START_STATES,
@@ -293,7 +294,7 @@ async def _stop_tasks(
     async def _process_completed_tasks(
         self,
         user_id: UserID,
-        tasks: list[CompTaskAtDB],
+        tasks: list[TaskStateTracker],
         iteration: Iteration,
         comp_run: CompRunsAtDB,
     ) -> None:
@@ -306,13 +307,20 @@ async def _process_completed_tasks(
             run_metadata=comp_run.metadata,
         ) as client:
             tasks_results = await asyncio.gather(
-                *[client.get_task_result(t.job_id or "undefined") for t in tasks],
+                *[
+                    client.get_task_result(t.current.job_id or "undefined")
+                    for t in tasks
+                ],
                 return_exceptions=True,
             )
             async for future in limited_as_completed(
                 (
                     self._process_task_result(
-                        task, result, comp_run.metadata, iteration, comp_run.run_id
+                        task,
+                        result,
+                        comp_run.metadata,
+                        iteration,
+                        comp_run.run_id,
                     )
                     for task, result in zip(tasks, tasks_results, strict=True)
                 ),
@@ -467,31 +475,30 @@ async def _handle_task_error(
 
     async def _process_task_result(
         self,
-        task: CompTaskAtDB,
+        task: TaskStateTracker,
         result: BaseException | TaskOutputData,
         run_metadata: RunMetadataDict,
         iteration: Iteration,
         run_id: PositiveInt,
     ) -> tuple[bool, str]:
         """Returns True and the job ID if the task was successfully processed and can be released from the Dask cluster."""
         _logger.debug("received %s result: %s", f"{task=}", f"{result=}")
-
-        assert task.job_id # nosec
+        assert task.current.job_id # nosec
         (
             _service_key,
             _service_version,
             user_id,
             project_id,
             node_id,
-        ) = parse_dask_job_id(task.job_id)
+        ) = parse_dask_job_id(task.current.job_id)
 
-        assert task.project_id == project_id # nosec
-        assert task.node_id == node_id # nosec
+        assert task.current.project_id == project_id # nosec
+        assert task.current.node_id == node_id # nosec
         log_error_context = {
             "user_id": user_id,
             "project_id": project_id,
             "node_id": node_id,
-            "job_id": task.job_id,
+            "job_id": task.current.job_id,
         }
 
         if isinstance(result, TaskOutputData):
@@ -500,7 +507,9 @@ async def _process_task_result(
                 simcore_platform_status,
                 task_errors,
                 task_completed,
-            ) = await self._handle_successful_run(task, result, log_error_context)
+            ) = await self._handle_successful_run(
+                task.current, result, log_error_context
+            )
 
         elif isinstance(result, ComputationalBackendTaskResultsNotReadyError):
             (
@@ -509,7 +518,7 @@ async def _process_task_result(
                 task_errors,
                 task_completed,
             ) = await self._handle_computational_retrieval_error(
-                task, user_id, result, log_error_context
+                task.current, user_id, result, log_error_context
             )
         elif isinstance(result, ComputationalBackendNotConnectedError):
             (
@@ -518,15 +527,15 @@ async def _process_task_result(
                 task_errors,
                 task_completed,
             ) = await self._handle_computational_backend_not_connected_error(
-                task, result, log_error_context
+                task.current, result, log_error_context
             )
         else:
             (
                 task_final_state,
                 simcore_platform_status,
                 task_errors,
                 task_completed,
-            ) = await self._handle_task_error(task, result, log_error_context)
+            ) = await self._handle_task_error(task.current, result, log_error_context)
 
         # we need to remove any invalid files in the storage
         await clean_task_output_and_log_files_if_invalid(
@@ -549,21 +558,21 @@ async def _process_task_result(
             simcore_user_agent=run_metadata.get(
                 "simcore_user_agent", UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE
             ),
-            task=task,
+            task=task.current,
             task_final_state=task_final_state,
         )
 
         await CompTasksRepository(self.db_engine).update_project_tasks_state(
-            task.project_id,
+            task.current.project_id,
             run_id,
-            [task.node_id],
-            task_final_state if task_completed else RunningState.STARTED,
+            [task.current.node_id],
+            task_final_state if task_completed else task.previous.state,
             errors=task_errors,
             optional_progress=1 if task_completed else None,
             optional_stopped=arrow.utcnow().datetime if task_completed else None,
         )
 
-        return task_completed, task.job_id
+        return task_completed, task.current.job_id
 
     async def _task_progress_change_handler(
         self, event: tuple[UnixTimestamp, Any]
