
Commit cbb4adf

trigger cluster termination is now able to force close

1 parent eb7eac5 commit cbb4adf

File tree

2 files changed: +85 −41 lines changed

  • scripts/maintenance/computational-clusters/autoscaled_monitor/cli.py
  • scripts/maintenance/computational-clusters/autoscaled_monitor/core.py

scripts/maintenance/computational-clusters/autoscaled_monitor/cli.py

Lines changed: 2 additions & 1 deletion
@@ -177,6 +177,7 @@ def cancel_jobs(
 def trigger_cluster_termination(
     user_id: Annotated[int, typer.Option(help="the user ID")],
     wallet_id: Annotated[int, typer.Option(help="the wallet ID")],
+    force: Annotated[bool, typer.Option(help="will not ask for confirmation")] = False,
 ) -> None:
     """this will set the Heartbeat tag on the primary machine to 1 hour, thus ensuring the
     clusters-keeper will properly terminate that cluster.
@@ -185,7 +186,7 @@ def trigger_cluster_termination(
     user_id -- the user ID
     wallet_id -- the wallet ID
     """
-    asyncio.run(api.trigger_cluster_termination(state, user_id, wallet_id))
+    asyncio.run(api.trigger_cluster_termination(state, user_id, wallet_id, force=force))
 
 
 @app.command()
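
With the new option in place, Typer renders the bool as a --force/--no-force switch, so passing --force lets the command skip the confirmation prompt. A minimal, self-contained sketch of the pattern (a toy script, not the monitor's actual wiring; only the option declarations are taken from the diff above):

from typing import Annotated

import typer

app = typer.Typer()


@app.command()
def trigger_cluster_termination(
    user_id: Annotated[int, typer.Option(help="the user ID")],
    wallet_id: Annotated[int, typer.Option(help="the wallet ID")],
    force: Annotated[bool, typer.Option(help="will not ask for confirmation")] = False,
) -> None:
    # a bool option defaulting to False becomes --force/--no-force on the CLI
    typer.echo(f"{user_id=} {wallet_id=} {force=}")


if __name__ == "__main__":
    app()

Invoked as, e.g., python monitor.py --user-id 123 --wallet-id 456 --force (script name and IDs illustrative).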

scripts/maintenance/computational-clusters/autoscaled_monitor/core.py

Lines changed: 83 additions & 40 deletions
@@ -446,6 +446,7 @@ def _print_summary_as_json(
             "user_id": cluster.primary.user_id,
             "wallet_id": cluster.primary.wallet_id,
             "disk_space": cluster.primary.disk_space.human_readable(),
+            "last_heartbeat": cluster.primary.last_heartbeat.isoformat(),
         },
         "workers": [
             {
@@ -561,6 +562,63 @@ async def _list_computational_clusters(
     )
 
 
+async def _cancel_all_jobs(
+    state: AppState,
+    the_cluster: ComputationalCluster,
+    *,
+    task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]],
+    abort_in_db: bool,
+) -> None:
+    rich.print("cancelling all tasks")
+    for comp_task, dask_task in task_to_dask_job:
+        if dask_task is not None and dask_task.state != "unknown":
+            await dask.trigger_job_cancellation_in_scheduler(
+                state,
+                the_cluster,
+                dask_task.job_id,
+            )
+            if comp_task is None:
+                # we need to clear it of the cluster
+                await dask.remove_job_from_scheduler(
+                    state,
+                    the_cluster,
+                    dask_task.job_id,
+                )
+        if comp_task is not None and abort_in_db:
+            await db.abort_job_in_db(state, comp_task.project_id, comp_task.node_id)
+
+    rich.print("cancelled all tasks")
+
+
+async def _get_job_id_to_dask_state_from_cluster(
+    cluster: ComputationalCluster,
+) -> dict[TaskId, TaskState]:
+    job_id_to_dask_state: dict[TaskId, TaskState] = {}
+    for job_state, job_ids in cluster.task_states_to_tasks.items():
+        for job_id in job_ids:
+            job_id_to_dask_state[job_id] = job_state
+    return job_id_to_dask_state
+
+
+async def _get_db_task_to_dask_job(
+    computational_tasks: list[ComputationalTask],
+    job_id_to_dask_state: dict[TaskId, TaskState],
+) -> list[tuple[ComputationalTask | None, DaskTask | None]]:
+    task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = []
+    for task in computational_tasks:
+        dask_task = None
+        if task.job_id:
+            dask_task = DaskTask(
+                job_id=task.job_id,
+                state=job_id_to_dask_state.pop(task.job_id, None) or "unknown",
+            )
+        task_to_dask_job.append((task, dask_task))
+    # keep the jobs still in the cluster
+    for job_id, dask_state in job_id_to_dask_state.items():
+        task_to_dask_job.append((None, DaskTask(job_id=job_id, state=dask_state)))
+    return task_to_dask_job
+
+
 async def cancel_jobs(  # noqa: C901, PLR0912
     state: AppState, user_id: int, wallet_id: int | None, *, force: bool
 ) -> None:
@@ -571,7 +629,7 @@ async def cancel_jobs(  # noqa: C901, PLR0912
     computational_clusters = await _list_computational_clusters(
         state, user_id, wallet_id
     )
-    job_id_to_dask_state: dict[TaskId, TaskState] = {}
+
     if computational_clusters:
         assert (
             len(computational_clusters) == 1
@@ -580,22 +638,10 @@ async def cancel_jobs(  # noqa: C901, PLR0912
         the_cluster = computational_clusters[0]
         rich.print(f"{the_cluster.task_states_to_tasks=}")
 
-        for job_state, job_ids in the_cluster.task_states_to_tasks.items():
-            for job_id in job_ids:
-                job_id_to_dask_state[job_id] = job_state
-
-        task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = []
-        for task in computational_tasks:
-            dask_task = None
-            if task.job_id:
-                dask_task = DaskTask(
-                    job_id=task.job_id,
-                    state=job_id_to_dask_state.pop(task.job_id, None) or "unknown",
-                )
-            task_to_dask_job.append((task, dask_task))
-        # keep the jobs still in the cluster
-        for job_id, dask_state in job_id_to_dask_state.items():
-            task_to_dask_job.append((None, DaskTask(job_id=job_id, state=dask_state)))
+        job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster(the_cluster)
+        task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = (
+            await _get_db_task_to_dask_job(computational_tasks, job_id_to_dask_state)
+        )
 
     if not task_to_dask_job:
         rich.print("[red]nothing found![/red]")
@@ -611,27 +657,12 @@ async def cancel_jobs(  # noqa: C901, PLR0912
         if response == "none":
             rich.print("[yellow]not cancelling anything[/yellow]")
         elif response == "all":
-            rich.print("cancelling all tasks")
-            for comp_task, dask_task in task_to_dask_job:
-                if dask_task is not None and dask_task.state != "unknown":
-                    await dask.trigger_job_cancellation_in_scheduler(
-                        state,
-                        the_cluster,
-                        dask_task.job_id,
-                    )
-                    if comp_task is None:
-                        # we need to clear it of the cluster
-                        await dask.remove_job_from_scheduler(
-                            state,
-                            the_cluster,
-                            dask_task.job_id,
-                        )
-                if comp_task is not None and force:
-                    await db.abort_job_in_db(
-                        state, comp_task.project_id, comp_task.node_id
-                    )
-
-            rich.print("cancelled all tasks")
+            await _cancel_all_jobs(
+                state,
+                the_cluster,
+                task_to_dask_job=task_to_dask_job,
+                abort_in_db=force,
+            )
         else:
             try:
                 # Split the response and handle ranges
@@ -673,7 +704,7 @@ async def cancel_jobs(  # noqa: C901, PLR0912
 
 
 async def trigger_cluster_termination(
-    state: AppState, user_id: int, wallet_id: int
+    state: AppState, user_id: int, wallet_id: int, *, force: bool
 ) -> None:
     assert state.ec2_resource_clusters_keeper
     computational_instances = await ec2.list_computational_instances_from_ec2(
@@ -692,8 +723,20 @@ async def trigger_cluster_termination(
         state.environment,
         state.ec2_resource_clusters_keeper.meta.client.meta.region_name,
     )
-    if typer.confirm("Are you sure you want to trigger termination of that cluster?"):
+    if (force is True) or typer.confirm(
+        "Are you sure you want to trigger termination of that cluster?"
+    ):
         the_cluster = computational_clusters[0]
+
+        computational_tasks = await db.list_computational_tasks_from_db(state, user_id)
+        job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster(the_cluster)
+        task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = (
+            await _get_db_task_to_dask_job(computational_tasks, job_id_to_dask_state)
+        )
+        await _cancel_all_jobs(
+            state, the_cluster, task_to_dask_job=task_to_dask_job, abort_in_db=force
+        )
+
         new_heartbeat_tag: TagTypeDef = {
             "Key": "last_heartbeat",
             "Value": f"{arrow.utcnow().datetime - datetime.timedelta(hours=1)}",

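The hunk stops at the new_heartbeat_tag construction: stamping last_heartbeat one hour in the past makes the primary look stale, so the clusters-keeper terminates the cluster through its normal sweep rather than an ad-hoc kill. A rough sketch of that backdating trick (synchronous boto3 for brevity; the script itself drives EC2 through its AppState resources, and backdate_heartbeat with its parameters is hypothetical):

import datetime

import arrow
import boto3


def backdate_heartbeat(instance_id: str, region_name: str) -> None:
    # a last_heartbeat timestamp one hour old marks the cluster as stale,
    # which is the condition the clusters-keeper acts on
    stale = f"{arrow.utcnow().datetime - datetime.timedelta(hours=1)}"
    boto3.resource("ec2", region_name=region_name).Instance(instance_id).create_tags(
        Tags=[{"Key": "last_heartbeat", "Value": stale}]
    )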