@@ -1,19 +1,27 @@
+from dataclasses import dataclass
 from typing import Literal
 
 from models_library.projects import ProjectID
 from models_library.rabbitmq_messages import RabbitMessageBase
 from models_library.users import UserID
 
 from ...models.comp_runs import Iteration
+from ...models.comp_tasks import CompTaskAtDB
 
 
 class SchedulePipelineRabbitMessage(RabbitMessageBase):
-    channel_name: Literal[
-        "simcore.services.director-v2.scheduling"
-    ] = "simcore.services.director-v2.scheduling"
+    channel_name: Literal["simcore.services.director-v2.scheduling"] = (
+        "simcore.services.director-v2.scheduling"
+    )
     user_id: UserID
     project_id: ProjectID
     iteration: Iteration
 
     def routing_key(self) -> str | None:  # pylint: disable=no-self-use  # abstract
         return None
+
+
+@dataclass(frozen=True, slots=True)
+class TaskStateTracker:
+    previous: CompTaskAtDB
+    current: CompTaskAtDB
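For orientation (not part of the diff): the new TaskStateTracker keeps the task row as it was last stored next to a copy carrying the state freshly reported by the backend, replacing the bare (previous, current) tuples used in the second file below. A minimal sketch, using a hypothetical _FakeTask stand-in instead of the real CompTaskAtDB Pydantic model:

# Sketch only: _FakeTask, its fields, and this RunningState enum are hypothetical
# stand-ins; the real code uses CompTaskAtDB and model_copy instead of replace.
from dataclasses import dataclass, replace
from enum import Enum, auto


class RunningState(Enum):
    PENDING = auto()
    STARTED = auto()


@dataclass(frozen=True, slots=True)
class _FakeTask:
    node_id: str
    state: RunningState


@dataclass(frozen=True, slots=True)
class TaskStateTracker:  # same shape as the class introduced above
    previous: _FakeTask
    current: _FakeTask


stored = _FakeTask(node_id="node-1", state=RunningState.PENDING)
# pair the stored row with a copy updated to the backend-reported state
tracker = TaskStateTracker(
    previous=stored, current=replace(stored, state=RunningState.STARTED)
)
assert tracker.previous.state is RunningState.PENDING
assert tracker.current.state is RunningState.STARTED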
@@ -63,6 +63,7 @@
 from ..db.repositories.comp_pipelines import CompPipelinesRepository
 from ..db.repositories.comp_runs import CompRunsRepository
 from ..db.repositories.comp_tasks import CompTasksRepository
+from ._models import TaskStateTracker
 from ._publisher import request_pipeline_scheduling
 from ._utils import (
     COMPLETED_STATES,
@@ -76,9 +77,6 @@
 _logger = logging.getLogger(__name__)
 
 
-_Previous = CompTaskAtDB
-_Current = CompTaskAtDB
-
 _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS: Final[datetime.timedelta] = datetime.timedelta(
     seconds=30
 )
@@ -117,47 +115,49 @@ async def _async_cb() -> None:
 @dataclass(frozen=True, slots=True)
 class SortedTasks:
     started: list[CompTaskAtDB]
-    completed: list[CompTaskAtDB]
-    waiting: list[CompTaskAtDB]
-    potentially_lost: list[CompTaskAtDB]
+    completed: list[TaskStateTracker]
+    waiting: list[TaskStateTracker]
+    potentially_lost: list[TaskStateTracker]
 
 
 async def _triage_changed_tasks(
-    changed_tasks: list[tuple[_Previous, _Current]],
+    changed_tasks: list[TaskStateTracker],
 ) -> SortedTasks:
     started_tasks = [
-        current
-        for previous, current in changed_tasks
-        if current.state in RUNNING_STATES
+        tracker.current
+        for tracker in changed_tasks
+        if tracker.current.state in RUNNING_STATES
         or (
-            previous.state in WAITING_FOR_START_STATES
-            and current.state in COMPLETED_STATES
+            tracker.previous.state in WAITING_FOR_START_STATES
+            and tracker.current.state in COMPLETED_STATES
         )
     ]
 
     completed_tasks = [
-        current for _, current in changed_tasks if current.state in COMPLETED_STATES
+        tracker
+        for tracker in changed_tasks
+        if tracker.current.state in COMPLETED_STATES
     ]
 
     waiting_for_resources_tasks = [
-        current
-        for previous, current in changed_tasks
-        if current.state in WAITING_FOR_START_STATES
+        tracker
+        for tracker in changed_tasks
+        if tracker.current.state in WAITING_FOR_START_STATES
     ]
 
     lost_tasks = [
-        current
-        for previous, current in changed_tasks
-        if (current.state is RunningState.UNKNOWN)
+        tracker
+        for tracker in changed_tasks
+        if (tracker.current.state is RunningState.UNKNOWN)
         and (
-            (arrow.utcnow().datetime - previous.modified)
+            (arrow.utcnow().datetime - tracker.previous.modified)
             > _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS
         )
     ]
     if lost_tasks:
         _logger.warning(
             "%s are currently in unknown state. TIP: If they are running in an external cluster and it is not yet ready, that might explain it. But inform @sanderegg nevertheless!",
-            [t.node_id for t in lost_tasks],
+            [t.current.node_id for t in lost_tasks],
         )
 
     return SortedTasks(
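A note on the potentially-lost bucket in the hunk above: a task reported as RunningState.UNKNOWN is only flagged once its row has gone unmodified for longer than _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS (30 s). A minimal sketch of that timing check, with a hypothetical last-modified timestamp standing in for previous.modified:

# Sketch only: last_modified is a made-up value; the real check reads it from the DB row.
import datetime

import arrow

_MAX_WAITING_TIME_FOR_UNKNOWN_TASKS = datetime.timedelta(seconds=30)

last_modified = arrow.utcnow().shift(seconds=-45).datetime
is_potentially_lost = (
    arrow.utcnow().datetime - last_modified
) > _MAX_WAITING_TIME_FOR_UNKNOWN_TASKS
assert is_potentially_lost  # unknown for 45 s, past the 30 s grace period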
@@ -321,6 +321,7 @@ async def _send_running_tasks_heartbeat(
         def _need_heartbeat(task: CompTaskAtDB) -> bool:
             if task.state not in RUNNING_STATES:
                 return False
+
             if task.last_heartbeat is None:
                 assert task.start  # nosec
                 return bool(
@@ -362,14 +363,14 @@ async def _get_changed_tasks_from_backend(
         user_id: UserID,
         processing_tasks: list[CompTaskAtDB],
         comp_run: CompRunsAtDB,
-    ) -> tuple[list[tuple[_Previous, _Current]], list[CompTaskAtDB]]:
+    ) -> tuple[list[TaskStateTracker], list[CompTaskAtDB]]:
         tasks_backend_status = await self._get_tasks_status(
             user_id, processing_tasks, comp_run
         )
 
         return (
             [
-                (
+                TaskStateTracker(
                     task,
                     task.model_copy(update={"state": backend_state}),
                 )
@@ -502,16 +503,16 @@ async def _process_started_tasks(
         )
 
     async def _process_waiting_tasks(
-        self, tasks: list[CompTaskAtDB], run_id: PositiveInt
+        self, tasks: list[TaskStateTracker], run_id: PositiveInt
     ) -> None:
         comp_tasks_repo = CompTasksRepository(self.db_engine)
         await asyncio.gather(
             *(
                 comp_tasks_repo.update_project_tasks_state(
-                    t.project_id,
+                    t.current.project_id,
                     run_id,
-                    [t.node_id],
-                    t.state,
+                    [t.current.node_id],
+                    t.current.state,
                 )
                 for t in tasks
             )
@@ -602,7 +603,7 @@ async def _stop_tasks(
     async def _process_completed_tasks(
         self,
         user_id: UserID,
-        tasks: list[CompTaskAtDB],
+        tasks: list[TaskStateTracker],
         iteration: Iteration,
         comp_run: CompRunsAtDB,
     ) -> None: