10 changes: 10 additions & 0 deletions packages/models-library/src/models_library/rabbitmq_messages.py
@@ -316,3 +316,13 @@ class WalletCreditsLimitReachedMessage(RabbitMessageBase):

     def routing_key(self) -> str | None:
         return f"{self.wallet_id}.{self.credits_limit}"
+
+
+class ComputationalPipelineStatusMessage(RabbitMessageBase, ProjectMessageBase):
+    channel_name: Literal["io.simcore.service.computation.pipeline-status"] = (
+        "io.simcore.service.computation.pipeline-status"
+    )
+    run_result: RunningState
+
+    def routing_key(self) -> str | None:
+        return f"{self.project_id}"
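The new ComputationalPipelineStatusMessage is what the scheduler changes below publish through publish_pipeline_scheduling_state, a helper this PR imports from `...utils.rabbitmq` but whose body is not part of the diff. A minimal sketch of a plausible implementation, following the publish(channel_name, message) pattern visible elsewhere in this PR; import paths and the helper body are assumptions:

```python
# Hypothetical sketch -- the real helper lives in the service's utils/rabbitmq
# module, which this PR imports but this diff does not show.
from models_library.projects import ProjectID
from models_library.projects_state import RunningState
from models_library.rabbitmq_messages import ComputationalPipelineStatusMessage
from models_library.users import UserID
from servicelib.rabbitmq import RabbitMQClient


async def publish_pipeline_scheduling_state(
    rabbitmq_client: RabbitMQClient,
    user_id: UserID,
    project_id: ProjectID,
    state: RunningState,
) -> None:
    # routing_key() resolves to f"{project_id}", so consumers can bind per project
    message = ComputationalPipelineStatusMessage(
        user_id=user_id, project_id=project_id, run_result=state
    )
    await rabbitmq_client.publish(message.channel_name, message)
```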
@@ -93,16 +93,20 @@


 async def _check_pipeline_not_running_or_raise_409(
-    comp_tasks_repo: CompTasksRepository, computation: ComputationCreate
+    comp_runs_repo: CompRunsRepository,
+    computation: ComputationCreate,
 ) -> None:
-    pipeline_state = utils.get_pipeline_state_from_task_states(
-        await comp_tasks_repo.list_computational_tasks(computation.project_id)
-    )
-    if utils.is_pipeline_running(pipeline_state):
-        raise HTTPException(
-            status_code=status.HTTP_409_CONFLICT,
-            detail=f"Project {computation.project_id} already started, current state is {pipeline_state}",
-        )
+    with contextlib.suppress(ComputationalRunNotFoundError):
+        last_run = await comp_runs_repo.get(
+            user_id=computation.user_id, project_id=computation.project_id
+        )
+        pipeline_state = last_run.result
+
+        if utils.is_pipeline_running(pipeline_state):
+            raise HTTPException(
+                status_code=status.HTTP_409_CONFLICT,
+                detail=f"Project {computation.project_id} already started, current state is {pipeline_state}",
+            )
 
 
 async def _check_pipeline_startable(
@@ -302,7 +306,7 @@ async def create_or_update_or_start_computation(  # noqa: PLR0913 # pylint: disa
     project: ProjectAtDB = await project_repo.get_project(computation.project_id)
 
     # check if current state allow to modify the computation
-    await _check_pipeline_not_running_or_raise_409(comp_tasks_repo, computation)
+    await _check_pipeline_not_running_or_raise_409(comp_runs_repo, computation)
 
     # create the complete DAG graph
     complete_dag = create_complete_dag(project.workbench)
@@ -353,20 +357,14 @@ async def create_or_update_or_start_computation(  # noqa: PLR0913 # pylint: disa
         projects_metadata_repo=projects_metadata_repo,
     )
 
-    # filter the tasks by the effective pipeline
-    filtered_tasks = [
-        t
-        for t in comp_tasks
-        if f"{t.node_id}" in set(minimal_computational_dag.nodes())
-    ]
-    pipeline_state = utils.get_pipeline_state_from_task_states(filtered_tasks)
-
     # get run details if any
     last_run: CompRunsAtDB | None = None
+    pipeline_state = RunningState.NOT_STARTED
     with contextlib.suppress(ComputationalRunNotFoundError):
         last_run = await comp_runs_repo.get(
             user_id=computation.user_id, project_id=computation.project_id
         )
+        pipeline_state = last_run.result
 
     return ComputationGet(
         id=computation.project_id,
Expand Down Expand Up @@ -449,21 +447,10 @@ async def get_computation(
     # check that project actually exists
     await project_repo.get_project(project_id)
 
-    pipeline_dag, all_tasks, filtered_tasks = await analyze_pipeline(
+    pipeline_dag, all_tasks, _filtered_tasks = await analyze_pipeline(
         project_id, comp_pipelines_repo, comp_tasks_repo
     )
 
-    pipeline_state: RunningState = utils.get_pipeline_state_from_task_states(
-        filtered_tasks
-    )
-
-    _logger.debug(
-        "Computational task status by %s for %s has %s",
-        f"{user_id=}",
-        f"{project_id=}",
-        f"{pipeline_state=}",
-    )
-
     # create the complete DAG graph
     complete_dag = create_complete_dag_from_tasks(all_tasks)
     pipeline_details = await compute_pipeline_details(
@@ -472,8 +459,17 @@ async def get_computation(

     # get run details if any
     last_run: CompRunsAtDB | None = None
+    pipeline_state = RunningState.NOT_STARTED
     with contextlib.suppress(ComputationalRunNotFoundError):
         last_run = await comp_runs_repo.get(user_id=user_id, project_id=project_id)
+        pipeline_state = last_run.result
+
+    _logger.debug(
+        "Computational task status by %s for %s has %s",
+        f"{user_id=}",
+        f"{project_id=}",
+        f"{pipeline_state=}",
+    )
 
     self_url = request.url.remove_query_params("user_id")
     return ComputationGet(
Expand Down Expand Up @@ -536,23 +532,18 @@ async def stop_computation(
         tasks: list[CompTaskAtDB] = await comp_tasks_repo.list_tasks(project_id)
         # create the complete DAG graph
         complete_dag = create_complete_dag_from_tasks(tasks)
-        # filter the tasks by the effective pipeline
-        filtered_tasks = [
-            t for t in tasks if f"{t.node_id}" in set(pipeline_dag.nodes())
-        ]
-        pipeline_state = utils.get_pipeline_state_from_task_states(filtered_tasks)
-
-        if utils.is_pipeline_running(pipeline_state):
-            await stop_pipeline(
-                request.app, user_id=computation_stop.user_id, project_id=project_id
-            )
-
-        # get run details if any
+        # stop the pipeline if it is running
         last_run: CompRunsAtDB | None = None
+        pipeline_state = RunningState.UNKNOWN
         with contextlib.suppress(ComputationalRunNotFoundError):
             last_run = await comp_runs_repo.get(
                 user_id=computation_stop.user_id, project_id=project_id
             )
+            pipeline_state = last_run.result
+            if utils.is_pipeline_running(last_run.result):
+                await stop_pipeline(
+                    request.app, user_id=computation_stop.user_id, project_id=project_id
+                )
 
         return ComputationGet(
             id=project_id,
Expand Down Expand Up @@ -594,15 +585,20 @@ async def delete_computation(
     comp_tasks_repo: Annotated[
         CompTasksRepository, Depends(get_repository(CompTasksRepository))
     ],
+    comp_runs_repo: Annotated[
+        CompRunsRepository, Depends(get_repository(CompRunsRepository))
+    ],
 ) -> None:
     try:
         # get the project
         project: ProjectAtDB = await project_repo.get_project(project_id)
         # check if current state allow to stop the computation
-        comp_tasks: list[CompTaskAtDB] = await comp_tasks_repo.list_computational_tasks(
-            project_id
-        )
-        pipeline_state = utils.get_pipeline_state_from_task_states(comp_tasks)
+        pipeline_state = RunningState.UNKNOWN
+        with contextlib.suppress(ComputationalRunNotFoundError):
+            last_run = await comp_runs_repo.get(
+                user_id=computation_stop.user_id, project_id=project_id
+            )
+            pipeline_state = last_run.result
         if utils.is_pipeline_running(pipeline_state):
             if not computation_stop.force:
                 raise HTTPException(
Expand Down Expand Up @@ -634,12 +630,10 @@ def return_last_value(retry_state: Any) -> Any:
         before_sleep=before_sleep_log(_logger, logging.INFO),
     )
     async def check_pipeline_stopped() -> bool:
-        comp_tasks: list[CompTaskAtDB] = (
-            await comp_tasks_repo.list_computational_tasks(project_id)
-        )
-        pipeline_state = utils.get_pipeline_state_from_task_states(
-            comp_tasks,
+        last_run = await comp_runs_repo.get(
+            user_id=computation_stop.user_id, project_id=project_id
         )
+        pipeline_state = last_run.result
         return utils.is_pipeline_stopped(pipeline_state)
 
     # wait for the pipeline to be stopped
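For context, check_pipeline_stopped above is wrapped by a tenacity retry whose decorator is only partly visible in the hunk (the before_sleep=... line and the return_last_value callback). A self-contained sketch of that polling pattern; the stop/wait values and the placeholder body are assumptions, not taken from the PR:

```python
import logging
from typing import Any

from tenacity import before_sleep_log, retry, retry_if_result, stop_after_delay, wait_fixed

_logger = logging.getLogger(__name__)


def return_last_value(retry_state: Any) -> Any:
    """Instead of raising RetryError on timeout, hand back the last attempt's result."""
    return retry_state.outcome.result()


@retry(
    stop=stop_after_delay(20),  # assumed timeout, not taken from this diff
    wait=wait_fixed(1),
    retry=retry_if_result(lambda result: result is False),  # keep polling while not stopped
    retry_error_callback=return_last_value,
    before_sleep=before_sleep_log(_logger, logging.INFO),
)
async def check_pipeline_stopped() -> bool:
    # placeholder: the real implementation reads the last run from the DB and
    # returns utils.is_pipeline_stopped(last_run.result)
    return False
```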
@@ -1,3 +1,4 @@
+import contextlib
 import logging
 from typing import Final

@@ -13,10 +14,11 @@
 from servicelib.utils import limited_gather
 from sqlalchemy.ext.asyncio import AsyncEngine
 
+from ...core.errors import ComputationalRunNotFoundError
 from ...models.comp_pipelines import CompPipelineAtDB
 from ...models.comp_runs import RunMetadataDict
 from ...models.comp_tasks import CompTaskAtDB
-from ...utils.rabbitmq import publish_project_log
+from ...utils.rabbitmq import publish_pipeline_scheduling_state, publish_project_log
 from ..db import get_db_engine
 from ..db.repositories.comp_pipelines import CompPipelinesRepository
 from ..db.repositories.comp_runs import CompRunsRepository
@@ -57,6 +59,18 @@ async def run_new_pipeline(
         )
         return
 
+    with contextlib.suppress(ComputationalRunNotFoundError):
+        # if the run already exists and is scheduled, do not schedule again.
+        last_run = await CompRunsRepository.instance(db_engine).get(
+            user_id=user_id, project_id=project_id
+        )
+        if last_run.result.is_running():
+            _logger.warning(
+                "run for project %s is already running. not scheduling it again.",
+                f"{project_id=}",
+            )
+            return
+
     new_run = await CompRunsRepository.instance(db_engine).create(
         user_id=user_id,
         project_id=project_id,
@@ -92,6 +106,9 @@ async def run_new_pipeline(
         log=f"Project pipeline scheduled using {'on-demand clusters' if use_on_demand_clusters else 'pre-defined clusters'}, starting soon...",
         log_level=logging.INFO,
     )
+    await publish_pipeline_scheduling_state(
+        rabbitmq_client, user_id, project_id, new_run.result
+    )
 
 
 async def stop_pipeline(
@@ -128,8 +145,7 @@ async def _get_pipeline_at_db(
     project_id: ProjectID, db_engine: AsyncEngine
 ) -> CompPipelineAtDB:
     comp_pipeline_repo = CompPipelinesRepository.instance(db_engine)
-    pipeline_at_db = await comp_pipeline_repo.get_pipeline(project_id)
-    return pipeline_at_db
+    return await comp_pipeline_repo.get_pipeline(project_id)
 
 
 async def _get_pipeline_tasks_at_db(
@@ -16,9 +16,10 @@ async def request_pipeline_scheduling(
     project_id: ProjectID,
     iteration: Iteration,
 ) -> None:
-    # NOTE: we should use the transaction and the asyncpg engine here to ensure 100% consistency
-    # https://github.com/ITISFoundation/osparc-simcore/issues/6818
-    # async with transaction_context(get_asyncpg_engine(app)) as connection:
+    # NOTE: it is important that the DB is set up first before scheduling, in case the worker already schedules before we change the DB
+    await CompRunsRepository.instance(db_engine).mark_for_scheduling(
+        user_id=user_id, project_id=project_id, iteration=iteration
+    )
     await rabbitmq_client.publish(
         SchedulePipelineRabbitMessage.get_channel_name(),
         SchedulePipelineRabbitMessage(
@@ -27,6 +28,3 @@ async def request_pipeline_scheduling(
             iteration=iteration,
         ),
     )
-    await CompRunsRepository.instance(db_engine).mark_for_scheduling(
-        user_id=user_id, project_id=project_id, iteration=iteration
-    )
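The reorder matters because the worker may consume the SchedulePipelineRabbitMessage the instant it is published, so the comp_runs row must already be marked. A sketch of the interleaving the old order allowed (timeline invented):

```python
# Old order (publish first, mark after) -- a possible interleaving:
#
#   t0  API:    rabbitmq_client.publish(SchedulePipelineRabbitMessage, ...)
#   t1  worker: consumes the message, reads the comp_runs row
#   t2  worker: row is not yet marked for scheduling -> wrong scheduling decision
#   t3  API:    mark_for_scheduling(...)  # too late, the worker already looked
#
# New order (mark first, publish after): by the time the message can be consumed,
# the row is guaranteed to be marked, whichever way the two actors interleave.
```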
@@ -52,6 +52,7 @@
 from ...models.comp_tasks import CompTaskAtDB
 from ...utils.computations import get_pipeline_state_from_task_states
 from ...utils.rabbitmq import (
+    publish_pipeline_scheduling_state,
     publish_project_log,
     publish_service_resource_tracking_heartbeat,
     publish_service_resource_tracking_started,
@@ -208,10 +209,13 @@ async def _update_run_result_from_tasks(
         project_id: ProjectID,
         iteration: Iteration,
         pipeline_tasks: dict[NodeIDStr, CompTaskAtDB],
+        current_result: RunningState,
     ) -> RunningState:
         pipeline_state_from_tasks = get_pipeline_state_from_task_states(
             list(pipeline_tasks.values()),
         )
+        if pipeline_state_from_tasks == current_result:
+            return pipeline_state_from_tasks
         _logger.debug(
             "pipeline %s is currently in %s",
             f"{user_id=}_{project_id=}_{iteration=}",
@@ -238,17 +242,35 @@ async def _set_run_result(
             final_state=(run_result in COMPLETED_STATES),
         )
 
-    async def _set_schedule_done(
+        if run_result in COMPLETED_STATES:
+            # send event to notify the pipeline is done
+            await publish_project_log(
+                self.rabbitmq_client,
+                user_id=user_id,
+                project_id=project_id,
+                log=f"Pipeline run for iteration {iteration} is done with {run_result.value} state",
+                log_level=logging.INFO,
+            )
+        await publish_pipeline_scheduling_state(
+            self.rabbitmq_client, user_id, project_id, run_result
+        )
+
+    async def _set_processing_done(
         self,
         user_id: UserID,
         project_id: ProjectID,
         iteration: Iteration,
     ) -> None:
-        await CompRunsRepository.instance(self.db_engine).mark_as_processed(
-            user_id=user_id,
-            project_id=project_id,
-            iteration=iteration,
-        )
+        with log_context(
+            _logger,
+            logging.DEBUG,
+            msg=f"mark pipeline run for {iteration=} for {user_id=} and {project_id=} as processed",
+        ):
+            await CompRunsRepository.instance(self.db_engine).mark_as_processed(
+                user_id=user_id,
+                project_id=project_id,
+                iteration=iteration,
+            )
 
     async def _set_states_following_failed_to_aborted(
         self, project_id: ProjectID, dag: nx.DiGraph, run_id: PositiveInt
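The log_context wrapper used above is servicelib's logging helper; judging from the call site's signature (logger, level, msg=...), it brackets the wrapped block with begin/end records so slow DB updates become visible. A minimal usage sketch, with the import path assumed:

```python
# Minimal sketch of the pattern; the import path is an assumption.
import logging

from servicelib.logging_utils import log_context

logger = logging.getLogger(__name__)

with log_context(logger, logging.DEBUG, msg="mark pipeline run as processed"):
    ...  # the wrapped operation; entry/exit lines are logged around it
```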
@@ -622,7 +644,7 @@ async def apply(
             )
             # 3. do we want to stop the pipeline now?
             if comp_run.cancelled:
-                await self._schedule_tasks_to_stop(
+                comp_tasks = await self._schedule_tasks_to_stop(
                     user_id, project_id, comp_tasks, comp_run
                 )
             else:
@@ -653,7 +675,7 @@ async def apply(

             # 6. Update the run result
             pipeline_result = await self._update_run_result_from_tasks(
-                user_id, project_id, iteration, comp_tasks
+                user_id, project_id, iteration, comp_tasks, comp_run.result
             )
 
             # 7. Are we done scheduling that pipeline?
@@ -702,28 +724,37 @@ async def apply(
         except ComputationalBackendNotConnectedError:
             _logger.exception("Computational backend is not connected!")
         finally:
-            await self._set_schedule_done(user_id, project_id, iteration)
+            await self._set_processing_done(user_id, project_id, iteration)
 
     async def _schedule_tasks_to_stop(
         self,
         user_id: UserID,
         project_id: ProjectID,
         comp_tasks: dict[NodeIDStr, CompTaskAtDB],
         comp_run: CompRunsAtDB,
-    ) -> None:
-        # get any running task and stop them
+    ) -> dict[NodeIDStr, CompTaskAtDB]:
+        # NOTE: tasks that were not yet started can be marked as ABORTED straight away,
+        # while tasks that are already processing need some time to stop
+        # and must be stopped in the backend
+        tasks_instantly_stopeable = [
+            t for t in comp_tasks.values() if t.state in TASK_TO_START_STATES
+        ]
         comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
         await (
             comp_tasks_repo.mark_project_published_waiting_for_cluster_tasks_as_aborted(
                 project_id, comp_run.run_id
             )
         )
+        for task in tasks_instantly_stopeable:
+            comp_tasks[f"{task.node_id}"].state = RunningState.ABORTED
         # stop any remaining running task, these are already submitted
         if tasks_to_stop := [
             t for t in comp_tasks.values() if t.state in PROCESSING_STATES
         ]:
             await self._stop_tasks(user_id, tasks_to_stop, comp_run)
 
+        return comp_tasks
+
     async def _schedule_tasks_to_start(  # noqa: C901
         self,
         user_id: UserID,
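For intuition, _schedule_tasks_to_stop now splits tasks into two buckets: not-yet-started tasks flip to ABORTED immediately, while already-processing ones get a stop request to the backend and transition later. A toy, self-contained illustration (the state sets are invented stand-ins for TASK_TO_START_STATES and PROCESSING_STATES):

```python
from enum import Enum


class State(str, Enum):
    PUBLISHED = "PUBLISHED"  # queued, never started
    STARTED = "STARTED"      # running on the backend
    SUCCESS = "SUCCESS"      # terminal

# stand-ins for TASK_TO_START_STATES / PROCESSING_STATES
TO_START = {State.PUBLISHED}
PROCESSING = {State.STARTED}

tasks = {"n1": State.PUBLISHED, "n2": State.STARTED, "n3": State.SUCCESS}

# bucket 1: never started -> can be flipped to ABORTED right away (DB + in-memory)
instantly_stoppable = [n for n, s in tasks.items() if s in TO_START]
# bucket 2: running -> ask the backend to stop; the state changes later
needs_backend_stop = [n for n, s in tasks.items() if s in PROCESSING]

assert instantly_stoppable == ["n1"]
assert needs_backend_stop == ["n2"]
```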