Skip to content

Commit 2f1b484

Browse files
authored
šŸ›Improvements on pipeline cancellation and ensure pipeline state is consistent (#7996)
1 parent 6b1b81e commit 2f1b484

File tree

10 files changed

+383
-73
lines changed

10 files changed

+383
-73
lines changed

packages/models-library/src/models_library/rabbitmq_messages.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,13 @@ class WalletCreditsLimitReachedMessage(RabbitMessageBase):
316316

317317
def routing_key(self) -> str | None:
318318
return f"{self.wallet_id}.{self.credits_limit}"
319+
320+
321+
class ComputationalPipelineStatusMessage(RabbitMessageBase, ProjectMessageBase):
322+
channel_name: Literal["io.simcore.service.computation.pipeline-status"] = (
323+
"io.simcore.service.computation.pipeline-status"
324+
)
325+
run_result: RunningState
326+
327+
def routing_key(self) -> str | None:
328+
return f"{self.project_id}"

services/director-v2/src/simcore_service_director_v2/api/routes/computations.py

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,20 @@
9393

9494

9595
async def _check_pipeline_not_running_or_raise_409(
96-
comp_tasks_repo: CompTasksRepository, computation: ComputationCreate
96+
comp_runs_repo: CompRunsRepository,
97+
computation: ComputationCreate,
9798
) -> None:
98-
pipeline_state = utils.get_pipeline_state_from_task_states(
99-
await comp_tasks_repo.list_computational_tasks(computation.project_id)
100-
)
101-
if utils.is_pipeline_running(pipeline_state):
102-
raise HTTPException(
103-
status_code=status.HTTP_409_CONFLICT,
104-
detail=f"Project {computation.project_id} already started, current state is {pipeline_state}",
99+
with contextlib.suppress(ComputationalRunNotFoundError):
100+
last_run = await comp_runs_repo.get(
101+
user_id=computation.user_id, project_id=computation.project_id
105102
)
103+
pipeline_state = last_run.result
104+
105+
if utils.is_pipeline_running(pipeline_state):
106+
raise HTTPException(
107+
status_code=status.HTTP_409_CONFLICT,
108+
detail=f"Project {computation.project_id} already started, current state is {pipeline_state}",
109+
)
106110

107111

108112
async def _check_pipeline_startable(
@@ -302,7 +306,7 @@ async def create_or_update_or_start_computation( # noqa: PLR0913 # pylint: disa
302306
project: ProjectAtDB = await project_repo.get_project(computation.project_id)
303307

304308
# check if current state allow to modify the computation
305-
await _check_pipeline_not_running_or_raise_409(comp_tasks_repo, computation)
309+
await _check_pipeline_not_running_or_raise_409(comp_runs_repo, computation)
306310

307311
# create the complete DAG graph
308312
complete_dag = create_complete_dag(project.workbench)
@@ -353,20 +357,14 @@ async def create_or_update_or_start_computation( # noqa: PLR0913 # pylint: disa
353357
projects_metadata_repo=projects_metadata_repo,
354358
)
355359

356-
# filter the tasks by the effective pipeline
357-
filtered_tasks = [
358-
t
359-
for t in comp_tasks
360-
if f"{t.node_id}" in set(minimal_computational_dag.nodes())
361-
]
362-
pipeline_state = utils.get_pipeline_state_from_task_states(filtered_tasks)
363-
364360
# get run details if any
365361
last_run: CompRunsAtDB | None = None
362+
pipeline_state = RunningState.NOT_STARTED
366363
with contextlib.suppress(ComputationalRunNotFoundError):
367364
last_run = await comp_runs_repo.get(
368365
user_id=computation.user_id, project_id=computation.project_id
369366
)
367+
pipeline_state = last_run.result
370368

371369
return ComputationGet(
372370
id=computation.project_id,
@@ -449,21 +447,10 @@ async def get_computation(
449447
# check that project actually exists
450448
await project_repo.get_project(project_id)
451449

452-
pipeline_dag, all_tasks, filtered_tasks = await analyze_pipeline(
450+
pipeline_dag, all_tasks, _filtered_tasks = await analyze_pipeline(
453451
project_id, comp_pipelines_repo, comp_tasks_repo
454452
)
455453

456-
pipeline_state: RunningState = utils.get_pipeline_state_from_task_states(
457-
filtered_tasks
458-
)
459-
460-
_logger.debug(
461-
"Computational task status by %s for %s has %s",
462-
f"{user_id=}",
463-
f"{project_id=}",
464-
f"{pipeline_state=}",
465-
)
466-
467454
# create the complete DAG graph
468455
complete_dag = create_complete_dag_from_tasks(all_tasks)
469456
pipeline_details = await compute_pipeline_details(
@@ -472,8 +459,17 @@ async def get_computation(
472459

473460
# get run details if any
474461
last_run: CompRunsAtDB | None = None
462+
pipeline_state = RunningState.NOT_STARTED
475463
with contextlib.suppress(ComputationalRunNotFoundError):
476464
last_run = await comp_runs_repo.get(user_id=user_id, project_id=project_id)
465+
pipeline_state = last_run.result
466+
467+
_logger.debug(
468+
"Computational task status by %s for %s has %s",
469+
f"{user_id=}",
470+
f"{project_id=}",
471+
f"{pipeline_state=}",
472+
)
477473

478474
self_url = request.url.remove_query_params("user_id")
479475
return ComputationGet(
@@ -536,23 +532,18 @@ async def stop_computation(
536532
tasks: list[CompTaskAtDB] = await comp_tasks_repo.list_tasks(project_id)
537533
# create the complete DAG graph
538534
complete_dag = create_complete_dag_from_tasks(tasks)
539-
# filter the tasks by the effective pipeline
540-
filtered_tasks = [
541-
t for t in tasks if f"{t.node_id}" in set(pipeline_dag.nodes())
542-
]
543-
pipeline_state = utils.get_pipeline_state_from_task_states(filtered_tasks)
544-
545-
if utils.is_pipeline_running(pipeline_state):
546-
await stop_pipeline(
547-
request.app, user_id=computation_stop.user_id, project_id=project_id
548-
)
549-
550-
# get run details if any
535+
# stop the pipeline if it is running
551536
last_run: CompRunsAtDB | None = None
537+
pipeline_state = RunningState.UNKNOWN
552538
with contextlib.suppress(ComputationalRunNotFoundError):
553539
last_run = await comp_runs_repo.get(
554540
user_id=computation_stop.user_id, project_id=project_id
555541
)
542+
pipeline_state = last_run.result
543+
if utils.is_pipeline_running(last_run.result):
544+
await stop_pipeline(
545+
request.app, user_id=computation_stop.user_id, project_id=project_id
546+
)
556547

557548
return ComputationGet(
558549
id=project_id,
@@ -594,15 +585,20 @@ async def delete_computation(
594585
comp_tasks_repo: Annotated[
595586
CompTasksRepository, Depends(get_repository(CompTasksRepository))
596587
],
588+
comp_runs_repo: Annotated[
589+
CompRunsRepository, Depends(get_repository(CompRunsRepository))
590+
],
597591
) -> None:
598592
try:
599593
# get the project
600594
project: ProjectAtDB = await project_repo.get_project(project_id)
601595
# check if current state allow to stop the computation
602-
comp_tasks: list[CompTaskAtDB] = await comp_tasks_repo.list_computational_tasks(
603-
project_id
604-
)
605-
pipeline_state = utils.get_pipeline_state_from_task_states(comp_tasks)
596+
pipeline_state = RunningState.UNKNOWN
597+
with contextlib.suppress(ComputationalRunNotFoundError):
598+
last_run = await comp_runs_repo.get(
599+
user_id=computation_stop.user_id, project_id=project_id
600+
)
601+
pipeline_state = last_run.result
606602
if utils.is_pipeline_running(pipeline_state):
607603
if not computation_stop.force:
608604
raise HTTPException(
@@ -634,12 +630,10 @@ def return_last_value(retry_state: Any) -> Any:
634630
before_sleep=before_sleep_log(_logger, logging.INFO),
635631
)
636632
async def check_pipeline_stopped() -> bool:
637-
comp_tasks: list[CompTaskAtDB] = (
638-
await comp_tasks_repo.list_computational_tasks(project_id)
639-
)
640-
pipeline_state = utils.get_pipeline_state_from_task_states(
641-
comp_tasks,
633+
last_run = await comp_runs_repo.get(
634+
user_id=computation_stop.user_id, project_id=project_id
642635
)
636+
pipeline_state = last_run.result
643637
return utils.is_pipeline_stopped(pipeline_state)
644638

645639
# wait for the pipeline to be stopped

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_manager.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import logging
23
from typing import Final
34

@@ -13,10 +14,11 @@
1314
from servicelib.utils import limited_gather
1415
from sqlalchemy.ext.asyncio import AsyncEngine
1516

17+
from ...core.errors import ComputationalRunNotFoundError
1618
from ...models.comp_pipelines import CompPipelineAtDB
1719
from ...models.comp_runs import RunMetadataDict
1820
from ...models.comp_tasks import CompTaskAtDB
19-
from ...utils.rabbitmq import publish_project_log
21+
from ...utils.rabbitmq import publish_pipeline_scheduling_state, publish_project_log
2022
from ..db import get_db_engine
2123
from ..db.repositories.comp_pipelines import CompPipelinesRepository
2224
from ..db.repositories.comp_runs import CompRunsRepository
@@ -57,6 +59,18 @@ async def run_new_pipeline(
5759
)
5860
return
5961

62+
with contextlib.suppress(ComputationalRunNotFoundError):
63+
# if the run already exists and is scheduled, do not schedule again.
64+
last_run = await CompRunsRepository.instance(db_engine).get(
65+
user_id=user_id, project_id=project_id
66+
)
67+
if last_run.result.is_running():
68+
_logger.warning(
69+
"run for project %s is already running. not scheduling it again.",
70+
f"{project_id=}",
71+
)
72+
return
73+
6074
new_run = await CompRunsRepository.instance(db_engine).create(
6175
user_id=user_id,
6276
project_id=project_id,
@@ -92,6 +106,9 @@ async def run_new_pipeline(
92106
log=f"Project pipeline scheduled using {'on-demand clusters' if use_on_demand_clusters else 'pre-defined clusters'}, starting soon...",
93107
log_level=logging.INFO,
94108
)
109+
await publish_pipeline_scheduling_state(
110+
rabbitmq_client, user_id, project_id, new_run.result
111+
)
95112

96113

97114
async def stop_pipeline(
@@ -128,8 +145,7 @@ async def _get_pipeline_at_db(
128145
project_id: ProjectID, db_engine: AsyncEngine
129146
) -> CompPipelineAtDB:
130147
comp_pipeline_repo = CompPipelinesRepository.instance(db_engine)
131-
pipeline_at_db = await comp_pipeline_repo.get_pipeline(project_id)
132-
return pipeline_at_db
148+
return await comp_pipeline_repo.get_pipeline(project_id)
133149

134150

135151
async def _get_pipeline_tasks_at_db(

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_publisher.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ async def request_pipeline_scheduling(
1616
project_id: ProjectID,
1717
iteration: Iteration,
1818
) -> None:
19-
# NOTE: we should use the transaction and the asyncpg engine here to ensure 100% consistency
20-
# https://github.com/ITISFoundation/osparc-simcore/issues/6818
21-
# async with transaction_context(get_asyncpg_engine(app)) as connection:
19+
# NOTE: it is important that the DB is set up first before scheduling, in case the worker already schedules before we change the DB
20+
await CompRunsRepository.instance(db_engine).mark_for_scheduling(
21+
user_id=user_id, project_id=project_id, iteration=iteration
22+
)
2223
await rabbitmq_client.publish(
2324
SchedulePipelineRabbitMessage.get_channel_name(),
2425
SchedulePipelineRabbitMessage(
@@ -27,6 +28,3 @@ async def request_pipeline_scheduling(
2728
iteration=iteration,
2829
),
2930
)
30-
await CompRunsRepository.instance(db_engine).mark_for_scheduling(
31-
user_id=user_id, project_id=project_id, iteration=iteration
32-
)

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from ...models.comp_tasks import CompTaskAtDB
5353
from ...utils.computations import get_pipeline_state_from_task_states
5454
from ...utils.rabbitmq import (
55+
publish_pipeline_scheduling_state,
5556
publish_project_log,
5657
publish_service_resource_tracking_heartbeat,
5758
publish_service_resource_tracking_started,
@@ -208,10 +209,13 @@ async def _update_run_result_from_tasks(
208209
project_id: ProjectID,
209210
iteration: Iteration,
210211
pipeline_tasks: dict[NodeIDStr, CompTaskAtDB],
212+
current_result: RunningState,
211213
) -> RunningState:
212214
pipeline_state_from_tasks = get_pipeline_state_from_task_states(
213215
list(pipeline_tasks.values()),
214216
)
217+
if pipeline_state_from_tasks == current_result:
218+
return pipeline_state_from_tasks
215219
_logger.debug(
216220
"pipeline %s is currently in %s",
217221
f"{user_id=}_{project_id=}_{iteration=}",
@@ -238,17 +242,35 @@ async def _set_run_result(
238242
final_state=(run_result in COMPLETED_STATES),
239243
)
240244

241-
async def _set_schedule_done(
245+
if run_result in COMPLETED_STATES:
246+
# send event to notify the pipeline is done
247+
await publish_project_log(
248+
self.rabbitmq_client,
249+
user_id=user_id,
250+
project_id=project_id,
251+
log=f"Pipeline run {run_result.value} for iteration {iteration} is done with {run_result.value} state",
252+
log_level=logging.INFO,
253+
)
254+
await publish_pipeline_scheduling_state(
255+
self.rabbitmq_client, user_id, project_id, run_result
256+
)
257+
258+
async def _set_processing_done(
242259
self,
243260
user_id: UserID,
244261
project_id: ProjectID,
245262
iteration: Iteration,
246263
) -> None:
247-
await CompRunsRepository.instance(self.db_engine).mark_as_processed(
248-
user_id=user_id,
249-
project_id=project_id,
250-
iteration=iteration,
251-
)
264+
with log_context(
265+
_logger,
266+
logging.DEBUG,
267+
msg=f"mark pipeline run for {iteration=} for {user_id=} and {project_id=} as processed",
268+
):
269+
await CompRunsRepository.instance(self.db_engine).mark_as_processed(
270+
user_id=user_id,
271+
project_id=project_id,
272+
iteration=iteration,
273+
)
252274

253275
async def _set_states_following_failed_to_aborted(
254276
self, project_id: ProjectID, dag: nx.DiGraph, run_id: PositiveInt
@@ -622,7 +644,7 @@ async def apply(
622644
)
623645
# 3. do we want to stop the pipeline now?
624646
if comp_run.cancelled:
625-
await self._schedule_tasks_to_stop(
647+
comp_tasks = await self._schedule_tasks_to_stop(
626648
user_id, project_id, comp_tasks, comp_run
627649
)
628650
else:
@@ -653,7 +675,7 @@ async def apply(
653675

654676
# 6. Update the run result
655677
pipeline_result = await self._update_run_result_from_tasks(
656-
user_id, project_id, iteration, comp_tasks
678+
user_id, project_id, iteration, comp_tasks, comp_run.result
657679
)
658680

659681
# 7. Are we done scheduling that pipeline?
@@ -702,28 +724,37 @@ async def apply(
702724
except ComputationalBackendNotConnectedError:
703725
_logger.exception("Computational backend is not connected!")
704726
finally:
705-
await self._set_schedule_done(user_id, project_id, iteration)
727+
await self._set_processing_done(user_id, project_id, iteration)
706728

707729
async def _schedule_tasks_to_stop(
708730
self,
709731
user_id: UserID,
710732
project_id: ProjectID,
711733
comp_tasks: dict[NodeIDStr, CompTaskAtDB],
712734
comp_run: CompRunsAtDB,
713-
) -> None:
714-
# get any running task and stop them
735+
) -> dict[NodeIDStr, CompTaskAtDB]:
736+
# NOTE: tasks that were not yet started can be marked as ABORTED straight away;
737+
# the tasks that are already processing need some time to stop
738+
# and we need to stop them in the backend
739+
tasks_instantly_stopeable = [
740+
t for t in comp_tasks.values() if t.state in TASK_TO_START_STATES
741+
]
715742
comp_tasks_repo = CompTasksRepository.instance(self.db_engine)
716743
await (
717744
comp_tasks_repo.mark_project_published_waiting_for_cluster_tasks_as_aborted(
718745
project_id, comp_run.run_id
719746
)
720747
)
748+
for task in tasks_instantly_stopeable:
749+
comp_tasks[f"{task.node_id}"].state = RunningState.ABORTED
721750
# stop any remaining running task, these are already submitted
722751
if tasks_to_stop := [
723752
t for t in comp_tasks.values() if t.state in PROCESSING_STATES
724753
]:
725754
await self._stop_tasks(user_id, tasks_to_stop, comp_run)
726755

756+
return comp_tasks
757+
727758
async def _schedule_tasks_to_start( # noqa: C901
728759
self,
729760
user_id: UserID,

0 commit comments

Comments
(0)