@@ -446,6 +446,7 @@ def _print_summary_as_json(
446446 "user_id" : cluster .primary .user_id ,
447447 "wallet_id" : cluster .primary .wallet_id ,
448448 "disk_space" : cluster .primary .disk_space .human_readable (),
449+ "last_heartbeat" : cluster .primary .last_heartbeat .isoformat (),
449450 },
450451 "workers" : [
451452 {
@@ -561,6 +562,63 @@ async def _list_computational_clusters(
561562 )
562563
563564
async def _cancel_all_jobs(
    state: AppState,
    the_cluster: ComputationalCluster,
    *,
    task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]],
    abort_in_db: bool,
) -> None:
    """Cancel every task of *the_cluster* listed in *task_to_dask_job*.

    For each ``(comp_task, dask_task)`` pair:
    - if the dask job is known to the scheduler (state != "unknown"), ask the
      scheduler to cancel it; when there is no matching DB task the job is
      orphaned and is additionally removed from the scheduler outright;
    - if there is a matching DB task and *abort_in_db* is set, the job is also
      marked aborted in the database.

    NOTE(review): pairs where ``dask_task`` is None (DB task with no job_id)
    are intentionally left alone here — confirm against callers.
    """
    rich.print("cancelling all tasks")
    for comp_task, dask_task in task_to_dask_job:
        if dask_task is not None and dask_task.state != "unknown":
            await dask.trigger_job_cancellation_in_scheduler(
                state,
                the_cluster,
                dask_task.job_id,
            )
            if comp_task is None:
                # no DB counterpart: the job only lives in the scheduler,
                # so we need to clear it of the cluster
                await dask.remove_job_from_scheduler(
                    state,
                    the_cluster,
                    dask_task.job_id,
                )
        if comp_task is not None and abort_in_db:
            await db.abort_job_in_db(state, comp_task.project_id, comp_task.node_id)

    rich.print("cancelled all tasks")
591+
592+
async def _get_job_id_to_dask_state_from_cluster(
    cluster: ComputationalCluster,
) -> dict[TaskId, TaskState]:
    """Flatten the cluster's state->job-ids mapping into a job-id->state dict.

    If a job id is listed under several states, the last one seen wins
    (same as successive assignments in insertion order).
    """
    return {
        job_id: task_state
        for task_state, job_ids in cluster.task_states_to_tasks.items()
        for job_id in job_ids
    }
601+
602+
async def _get_db_task_to_dask_job(
    computational_tasks: list[ComputationalTask],
    job_id_to_dask_state: dict[TaskId, TaskState],
) -> list[tuple[ComputationalTask | None, DaskTask | None]]:
    """Pair DB computational tasks with their dask scheduler counterparts.

    Returns one ``(task, dask_task)`` entry per DB task — ``dask_task`` is
    None when the task has no job_id, and its state is "unknown" when the
    scheduler does not know the job — followed by ``(None, dask_task)``
    entries for jobs that exist only in the cluster.

    Fix over the previous version: the input mapping is no longer mutated
    (it was silently drained via ``dict.pop``); a private copy is consumed
    instead, so callers may safely reuse *job_id_to_dask_state* afterwards.
    """
    # work on a copy so the caller's mapping is left untouched
    remaining_dask_states = dict(job_id_to_dask_state)
    task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = []
    for task in computational_tasks:
        dask_task = None
        if task.job_id:
            dask_task = DaskTask(
                job_id=task.job_id,
                # a job the scheduler never reported gets the sentinel state
                state=remaining_dask_states.pop(task.job_id, None) or "unknown",
            )
        task_to_dask_job.append((task, dask_task))
    # whatever is left exists in the cluster but has no DB counterpart
    for job_id, dask_state in remaining_dask_states.items():
        task_to_dask_job.append((None, DaskTask(job_id=job_id, state=dask_state)))
    return task_to_dask_job
620+
621+
564622async def cancel_jobs ( # noqa: C901, PLR0912
565623 state : AppState , user_id : int , wallet_id : int | None , * , force : bool
566624) -> None :
@@ -571,7 +629,7 @@ async def cancel_jobs( # noqa: C901, PLR0912
571629 computational_clusters = await _list_computational_clusters (
572630 state , user_id , wallet_id
573631 )
574- job_id_to_dask_state : dict [ TaskId , TaskState ] = {}
632+
575633 if computational_clusters :
576634 assert (
577635 len (computational_clusters ) == 1
@@ -580,22 +638,10 @@ async def cancel_jobs( # noqa: C901, PLR0912
580638 the_cluster = computational_clusters [0 ]
581639 rich .print (f"{ the_cluster .task_states_to_tasks = } " )
582640
583- for job_state , job_ids in the_cluster .task_states_to_tasks .items ():
584- for job_id in job_ids :
585- job_id_to_dask_state [job_id ] = job_state
586-
587- task_to_dask_job : list [tuple [ComputationalTask | None , DaskTask | None ]] = []
588- for task in computational_tasks :
589- dask_task = None
590- if task .job_id :
591- dask_task = DaskTask (
592- job_id = task .job_id ,
593- state = job_id_to_dask_state .pop (task .job_id , None ) or "unknown" ,
594- )
595- task_to_dask_job .append ((task , dask_task ))
596- # keep the jobs still in the cluster
597- for job_id , dask_state in job_id_to_dask_state .items ():
598- task_to_dask_job .append ((None , DaskTask (job_id = job_id , state = dask_state )))
641+ job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster (the_cluster )
642+ task_to_dask_job : list [tuple [ComputationalTask | None , DaskTask | None ]] = (
643+ await _get_db_task_to_dask_job (computational_tasks , job_id_to_dask_state )
644+ )
599645
600646 if not task_to_dask_job :
601647 rich .print ("[red]nothing found![/red]" )
@@ -611,27 +657,12 @@ async def cancel_jobs( # noqa: C901, PLR0912
611657 if response == "none" :
612658 rich .print ("[yellow]not cancelling anything[/yellow]" )
613659 elif response == "all" :
614- rich .print ("cancelling all tasks" )
615- for comp_task , dask_task in task_to_dask_job :
616- if dask_task is not None and dask_task .state != "unknown" :
617- await dask .trigger_job_cancellation_in_scheduler (
618- state ,
619- the_cluster ,
620- dask_task .job_id ,
621- )
622- if comp_task is None :
623- # we need to clear it of the cluster
624- await dask .remove_job_from_scheduler (
625- state ,
626- the_cluster ,
627- dask_task .job_id ,
628- )
629- if comp_task is not None and force :
630- await db .abort_job_in_db (
631- state , comp_task .project_id , comp_task .node_id
632- )
633-
634- rich .print ("cancelled all tasks" )
660+ await _cancel_all_jobs (
661+ state ,
662+ the_cluster ,
663+ task_to_dask_job = task_to_dask_job ,
664+ abort_in_db = force ,
665+ )
635666 else :
636667 try :
637668 # Split the response and handle ranges
@@ -673,7 +704,7 @@ async def cancel_jobs( # noqa: C901, PLR0912
673704
674705
675706async def trigger_cluster_termination (
676- state : AppState , user_id : int , wallet_id : int
707+ state : AppState , user_id : int , wallet_id : int , * , force : bool
677708) -> None :
678709 assert state .ec2_resource_clusters_keeper
679710 computational_instances = await ec2 .list_computational_instances_from_ec2 (
@@ -692,8 +723,20 @@ async def trigger_cluster_termination(
692723 state .environment ,
693724 state .ec2_resource_clusters_keeper .meta .client .meta .region_name ,
694725 )
695- if typer .confirm ("Are you sure you want to trigger termination of that cluster?" ):
726+ if (force is True ) or typer .confirm (
727+ "Are you sure you want to trigger termination of that cluster?"
728+ ):
696729 the_cluster = computational_clusters [0 ]
730+
731+ computational_tasks = await db .list_computational_tasks_from_db (state , user_id )
732+ job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster (the_cluster )
733+ task_to_dask_job : list [tuple [ComputationalTask | None , DaskTask | None ]] = (
734+ await _get_db_task_to_dask_job (computational_tasks , job_id_to_dask_state )
735+ )
736+ await _cancel_all_jobs (
737+ state , the_cluster , task_to_dask_job = task_to_dask_job , abort_in_db = force
738+ )
739+
697740 new_heartbeat_tag : TagTypeDef = {
698741 "Key" : "last_heartbeat" ,
699742 "Value" : f"{ arrow .utcnow ().datetime - datetime .timedelta (hours = 1 )} " ,
0 commit comments