@@ -413,7 +413,61 @@ async def _parse_dynamic_instances(
413413 return dynamic_instances
414414
415415
416- async def summary (state : AppState , user_id : int | None , wallet_id : int | None ) -> bool :
416+ def _print_summary_as_json (
417+ dynamic_instances : list [DynamicInstance ],
418+ computational_clusters : list [ComputationalCluster ],
419+ ) -> None :
420+ result = {
421+ "dynamic_instances" : [
422+ {
423+ "name" : instance .name ,
424+ "ec2_instance_id" : instance .ec2_instance .instance_id ,
425+ "running_services" : [
426+ {
427+ "user_id" : service .user_id ,
428+ "project_id" : service .project_id ,
429+ "node_id" : service .node_id ,
430+ "service_name" : service .service_name ,
431+ "service_version" : service .service_version ,
432+ "created_at" : service .created_at .isoformat (),
433+ "needs_manual_intervention" : service .needs_manual_intervention ,
434+ }
435+ for service in instance .running_services
436+ ],
437+ "disk_space" : instance .disk_space .human_readable (),
438+ }
439+ for instance in dynamic_instances
440+ ],
441+ "computational_clusters" : [
442+ {
443+ "primary" : {
444+ "name" : cluster .primary .name ,
445+ "ec2_instance_id" : cluster .primary .ec2_instance .instance_id ,
446+ "user_id" : cluster .primary .user_id ,
447+ "wallet_id" : cluster .primary .wallet_id ,
448+ "disk_space" : cluster .primary .disk_space .human_readable (),
449+ "last_heartbeat" : cluster .primary .last_heartbeat .isoformat (),
450+ },
451+ "workers" : [
452+ {
453+ "name" : worker .name ,
454+ "ec2_instance_id" : worker .ec2_instance .instance_id ,
455+ "disk_space" : worker .disk_space .human_readable (),
456+ }
457+ for worker in cluster .workers
458+ ],
459+ "datasets" : cluster .datasets ,
460+ "tasks" : cluster .task_states_to_tasks ,
461+ }
462+ for cluster in computational_clusters
463+ ],
464+ }
465+ rich .print_json (json .dumps (result ))
466+
467+
468+ async def summary (
469+ state : AppState , user_id : int | None , wallet_id : int | None , * , output_json : bool
470+ ) -> bool :
417471 # get all the running instances
418472 assert state .ec2_resource_autoscaling
419473 dynamic_instances = await ec2 .list_dynamic_instances_from_ec2 (
@@ -422,19 +476,6 @@ async def summary(state: AppState, user_id: int | None, wallet_id: int | None) -
422476 dynamic_autoscaled_instances = await _parse_dynamic_instances (
423477 state , dynamic_instances , state .ssh_key_path , user_id , wallet_id
424478 )
425- _print_dynamic_instances (
426- dynamic_autoscaled_instances ,
427- state .environment ,
428- state .ec2_resource_autoscaling .meta .client .meta .region_name ,
429- )
430-
431- time_threshold = arrow .utcnow ().shift (minutes = - 30 ).datetime
432-
433- dynamic_services_in_error = any (
434- service .needs_manual_intervention and service .created_at < time_threshold
435- for instance in dynamic_autoscaled_instances
436- for service in instance .running_services
437- )
438479
439480 assert state .ec2_resource_clusters_keeper
440481 computational_instances = await ec2 .list_computational_instances_from_ec2 (
@@ -443,10 +484,27 @@ async def summary(state: AppState, user_id: int | None, wallet_id: int | None) -
443484 computational_clusters = await _parse_computational_clusters (
444485 state , computational_instances , state .ssh_key_path , user_id , wallet_id
445486 )
446- _print_computational_clusters (
447- computational_clusters ,
448- state .environment ,
449- state .ec2_resource_clusters_keeper .meta .client .meta .region_name ,
487+
488+ if output_json :
489+ _print_summary_as_json (dynamic_autoscaled_instances , computational_clusters )
490+
491+ if not output_json :
492+ _print_dynamic_instances (
493+ dynamic_autoscaled_instances ,
494+ state .environment ,
495+ state .ec2_resource_autoscaling .meta .client .meta .region_name ,
496+ )
497+ _print_computational_clusters (
498+ computational_clusters ,
499+ state .environment ,
500+ state .ec2_resource_clusters_keeper .meta .client .meta .region_name ,
501+ )
502+
503+ time_threshold = arrow .utcnow ().shift (minutes = - 30 ).datetime
504+ dynamic_services_in_error = any (
505+ service .needs_manual_intervention and service .created_at < time_threshold
506+ for instance in dynamic_autoscaled_instances
507+ for service in instance .running_services
450508 )
451509
452510 return not dynamic_services_in_error
@@ -504,29 +562,48 @@ async def _list_computational_clusters(
504562 )
505563
506564
507- async def cancel_jobs ( # noqa: C901, PLR0912
508- state : AppState , user_id : int , wallet_id : int | None , * , force : bool
async def _cancel_all_jobs(
    state: AppState,
    the_cluster: ComputationalCluster,
    *,
    task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]],
    abort_in_db: bool,
) -> None:
    """Cancel every job in *task_to_dask_job* on the given cluster.

    For each (DB task, dask task) pair:
      - if the dask scheduler knows the job (state != "unknown"), ask the
        scheduler to cancel it;
      - if on top of that no DB task refers to it, remove the orphaned job
        from the scheduler entirely;
      - if a DB task exists and *abort_in_db* is set, also mark the job as
        aborted in the database.

    Arguments:
        state: application state giving access to dask/db helpers.
        the_cluster: the computational cluster hosting the jobs.
        task_to_dask_job: pairs of DB-side and scheduler-side views of each job.
        abort_in_db: when True, aborted jobs are also flagged in the database.
    """
    rich.print("cancelling all tasks")
    for db_task, scheduler_task in task_to_dask_job:
        scheduler_knows_job = (
            scheduler_task is not None and scheduler_task.state != "unknown"
        )
        if scheduler_knows_job:
            await dask.trigger_job_cancellation_in_scheduler(
                state,
                the_cluster,
                scheduler_task.job_id,
            )
            if db_task is None:
                # we need to clear it of the cluster
                await dask.remove_job_from_scheduler(
                    state,
                    the_cluster,
                    scheduler_task.job_id,
                )
        if db_task is not None and abort_in_db:
            await db.abort_job_in_db(state, db_task.project_id, db_task.node_id)

    rich.print("cancelled all tasks")
522591
523- the_cluster = computational_clusters [0 ]
524- rich .print (f"{ the_cluster .task_states_to_tasks = } " )
525592
526- for job_state , job_ids in the_cluster .task_states_to_tasks .items ():
527- for job_id in job_ids :
528- job_id_to_dask_state [job_id ] = job_state
async def _get_job_id_to_dask_state_from_cluster(
    cluster: ComputationalCluster,
) -> dict[TaskId, TaskState]:
    """Invert the cluster's state->job-ids mapping into a job-id->state dict.

    Each job id listed under a given dask state in
    ``cluster.task_states_to_tasks`` becomes a key mapped to that state.
    """
    return {
        job_id: job_state
        for job_state, job_ids in cluster.task_states_to_tasks.items()
        for job_id in job_ids
    }
529601
602+
603+ async def _get_db_task_to_dask_job (
604+ computational_tasks : list [ComputationalTask ],
605+ job_id_to_dask_state : dict [TaskId , TaskState ],
606+ ) -> list [tuple [ComputationalTask | None , DaskTask | None ]]:
530607 task_to_dask_job : list [tuple [ComputationalTask | None , DaskTask | None ]] = []
531608 for task in computational_tasks :
532609 dask_task = None
@@ -539,6 +616,32 @@ async def cancel_jobs( # noqa: C901, PLR0912
539616 # keep the jobs still in the cluster
540617 for job_id , dask_state in job_id_to_dask_state .items ():
541618 task_to_dask_job .append ((None , DaskTask (job_id = job_id , state = dask_state )))
619+ return task_to_dask_job
620+
621+
622+ async def cancel_jobs ( # noqa: C901, PLR0912
623+ state : AppState , user_id : int , wallet_id : int | None , * , abort_in_db : bool
624+ ) -> None :
625+ # get the theory
626+ computational_tasks = await db .list_computational_tasks_from_db (state , user_id )
627+
628+ # get the reality
629+ computational_clusters = await _list_computational_clusters (
630+ state , user_id , wallet_id
631+ )
632+
633+ if computational_clusters :
634+ assert (
635+ len (computational_clusters ) == 1
636+ ), "too many clusters found! TIP: fix this code or something weird is playing out"
637+
638+ the_cluster = computational_clusters [0 ]
639+ rich .print (f"{ the_cluster .task_states_to_tasks = } " )
640+
641+ job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster (the_cluster )
642+ task_to_dask_job : list [tuple [ComputationalTask | None , DaskTask | None ]] = (
643+ await _get_db_task_to_dask_job (computational_tasks , job_id_to_dask_state )
644+ )
542645
543646 if not task_to_dask_job :
544647 rich .print ("[red]nothing found![/red]" )
@@ -554,27 +657,12 @@ async def cancel_jobs( # noqa: C901, PLR0912
554657 if response == "none" :
555658 rich .print ("[yellow]not cancelling anything[/yellow]" )
556659 elif response == "all" :
557- rich .print ("cancelling all tasks" )
558- for comp_task , dask_task in task_to_dask_job :
559- if dask_task is not None and dask_task .state != "unknown" :
560- await dask .trigger_job_cancellation_in_scheduler (
561- state ,
562- the_cluster ,
563- dask_task .job_id ,
564- )
565- if comp_task is None :
566- # we need to clear it of the cluster
567- await dask .remove_job_from_scheduler (
568- state ,
569- the_cluster ,
570- dask_task .job_id ,
571- )
572- if comp_task is not None and force :
573- await db .abort_job_in_db (
574- state , comp_task .project_id , comp_task .node_id
575- )
576-
577- rich .print ("cancelled all tasks" )
660+ await _cancel_all_jobs (
661+ state ,
662+ the_cluster ,
663+ task_to_dask_job = task_to_dask_job ,
664+ abort_in_db = abort_in_db ,
665+ )
578666 else :
579667 try :
580668 # Split the response and handle ranges
@@ -597,7 +685,7 @@ async def cancel_jobs( # noqa: C901, PLR0912
597685 state , the_cluster , dask_task .job_id
598686 )
599687
600- if comp_task is not None and force :
688+ if comp_task is not None and abort_in_db :
601689 await db .abort_job_in_db (
602690 state , comp_task .project_id , comp_task .node_id
603691 )
@@ -616,7 +704,7 @@ async def cancel_jobs( # noqa: C901, PLR0912
616704
617705
618706async def trigger_cluster_termination (
619- state : AppState , user_id : int , wallet_id : int
707+ state : AppState , user_id : int , wallet_id : int , * , force : bool
620708) -> None :
621709 assert state .ec2_resource_clusters_keeper
622710 computational_instances = await ec2 .list_computational_instances_from_ec2 (
@@ -635,8 +723,20 @@ async def trigger_cluster_termination(
635723 state .environment ,
636724 state .ec2_resource_clusters_keeper .meta .client .meta .region_name ,
637725 )
638- if typer .confirm ("Are you sure you want to trigger termination of that cluster?" ):
726+ if (force is True ) or typer .confirm (
727+ "Are you sure you want to trigger termination of that cluster?"
728+ ):
639729 the_cluster = computational_clusters [0 ]
730+
731+ computational_tasks = await db .list_computational_tasks_from_db (state , user_id )
732+ job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster (the_cluster )
733+ task_to_dask_job : list [tuple [ComputationalTask | None , DaskTask | None ]] = (
734+ await _get_db_task_to_dask_job (computational_tasks , job_id_to_dask_state )
735+ )
736+ await _cancel_all_jobs (
737+ state , the_cluster , task_to_dask_job = task_to_dask_job , abort_in_db = force
738+ )
739+
640740 new_heartbeat_tag : TagTypeDef = {
641741 "Key" : "last_heartbeat" ,
642742 "Value" : f"{ arrow .utcnow ().datetime - datetime .timedelta (hours = 1 )} " ,
0 commit comments