1717)
1818from typing import (
1919 Any ,
20+ Generic ,
2021 Optional ,
2122 TYPE_CHECKING ,
23+ TypeVar ,
2224 Union ,
2325)
2426
@@ -325,7 +327,7 @@ def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
325327 def stop_job (self , job_wrapper ):
326328 raise NotImplementedError ()
327329
328- def recover (self , job , job_wrapper ) :
330+ def recover (self , job : model . Job , job_wrapper : "MinimalJobWrapper" ) -> None :
329331 raise NotImplementedError ()
330332
331333 def build_command_line (
@@ -591,9 +593,15 @@ def _handle_runner_state(self, runner_state, job_state: "JobState"):
591593 except Exception :
592594 log .exception ("Caught exception in runner state handler" )
593595
594- def fail_job (self , job_state : "JobState" , exception = False , message = "Job failed" , full_status = None ):
596+ def fail_job (
597+ self ,
598+ job_state : "JobState" ,
599+ exception : bool = False ,
600+ message : str = "Job failed" ,
601+ full_status : Union [dict [str , Any ], None ] = None ,
602+ ) -> None :
595603 job = job_state .job_wrapper .get_job ()
596- if getattr ( job_state , " stop_job" , True ) and job .state != model .Job .states .NEW :
604+ if job_state . stop_job and job .state != model .Job .states .NEW :
597605 self .stop_job (job_state .job_wrapper )
598606 job_state .job_wrapper .reclaim_ownership ()
599607 self ._handle_runner_state ("failure" , job_state )
@@ -705,13 +713,14 @@ class JobState:
705713
706714 runner_states = runner_states
707715
708- def __init__ (self , job_wrapper : "JobWrapper " , job_destination : "JobDestination" ):
716+ def __init__ (self , job_wrapper : "MinimalJobWrapper " , job_destination : "JobDestination" ) -> None :
709717 self .runner_state_handled = False
710718 self .job_wrapper = job_wrapper
711719 self .job_destination = job_destination
712720 self .runner_state = None
713721 self .redact_email_in_job_name = True
714722 self ._exit_code_file = None
723+ self .stop_job = True
715724 if self .job_wrapper :
716725 self .redact_email_in_job_name = self .job_wrapper .app .config .redact_email_in_job_name
717726
@@ -765,23 +774,26 @@ class AsynchronousJobState(JobState):
765774 to communicate with distributed resource manager.
766775 """
767776
777+ old_state : Union ["JobStateEnum" , None ]
778+
768779 def __init__ (
769780 self ,
781+ job_wrapper : "MinimalJobWrapper" ,
782+ job_destination : "JobDestination" ,
783+ * ,
770784 files_dir = None ,
771- job_wrapper = None ,
772785 job_id : Union [str , None ] = None ,
773786 job_file = None ,
774787 output_file = None ,
775788 error_file = None ,
776789 exit_code_file = None ,
777790 job_name = None ,
778- job_destination = None ,
779- ):
791+ ) -> None :
780792 super ().__init__ (job_wrapper , job_destination )
781- self .old_state : Union [ JobStateEnum , None ] = None
793+ self .old_state = None
782794 self ._running = False
783795 self .check_count = 0
784- self .start_time = None
796+ self .start_time : Union [ datetime . datetime , None ] = None
785797
786798 # job_id is the DRM's job id, not the Galaxy job id
787799 self .job_id = job_id
@@ -796,11 +808,11 @@ def __init__(
796808 self .set_defaults (files_dir )
797809
798810 @property
799- def running (self ):
811+ def running (self ) -> bool :
800812 return self ._running
801813
802814 @running .setter
803- def running (self , is_running ) :
815+ def running (self , is_running : bool ) -> None :
804816 self ._running = is_running
805817 # This will be invalid for job recovery
806818 if self .start_time is None :
@@ -834,22 +846,28 @@ def init_job_stream_files(self):
834846 pass
835847
836848
837- class AsynchronousJobRunner (BaseJobRunner , Monitors ):
849+ T = TypeVar ("T" , bound = AsynchronousJobState )
850+
851+
852+ class AsynchronousJobRunner (BaseJobRunner , Monitors , Generic [T ]):
838853 """Parent class for any job runner that runs jobs asynchronously (e.g. via
839854 a distributed resource manager). Provides general methods for having a
840855 thread to monitor the state of asynchronous jobs and submitting those jobs
841856 to the correct methods (queue, finish, cleanup) at appropriate times..
842857 """
843858
859+ monitor_queue : Queue [T ]
860+ watched : list [T ]
861+
844862 def __init__ (self , app : "GalaxyManagerApplication" , nworkers : int , ** kwargs ) -> None :
845863 super ().__init__ (app , nworkers , ** kwargs )
846864 # 'watched' and 'queue' are both used to keep track of jobs to watch.
847865 # 'queue' is used to add new watched jobs, and can be called from
848866 # any thread (usually by the 'queue_job' method). 'watched' must only
849867 # be modified by the monitor thread, which will move items from 'queue'
850868 # to 'watched' and then manage the watched jobs.
851- self .watched : list [ AsynchronousJobState ] = []
852- self .monitor_queue : Queue [ AsynchronousJobState ] = Queue ()
869+ self .watched = []
870+ self .monitor_queue = Queue ()
853871
854872 def _init_monitor_thread (self ):
855873 name = f"{ self .runner_name } .monitor_thread"
@@ -892,7 +910,7 @@ def monitor(self):
892910 # Sleep a bit before the next state check
893911 time .sleep (self .app .config .job_runner_monitor_sleep )
894912
895- def monitor_job (self , job_state : AsynchronousJobState ) -> None :
913+ def monitor_job (self , job_state : T ) -> None :
896914 self .monitor_queue .put (job_state )
897915
898916 def shutdown (self ):
@@ -903,7 +921,7 @@ def shutdown(self):
903921 self .shutdown_monitor ()
904922 super ().shutdown ()
905923
906- def check_watched_items (self ):
924+ def check_watched_items (self ) -> None :
907925 """
908926 This method is responsible for iterating over self.watched and handling
909927 state changes and updating self.watched with a new list of watched job
@@ -919,7 +937,7 @@ def check_watched_items(self):
919937 self .watched = new_watched
920938
921939 # Subclasses should implement this unless they override check_watched_items all together.
922- def check_watched_item (self , job_state : AsynchronousJobState ) -> Union [AsynchronousJobState , None ]:
940+ def check_watched_item (self , job_state : T ) -> Union [T , None ]:
923941 raise NotImplementedError ()
924942
925943 def _collect_job_output (self , job_id : int , external_job_id : Optional [str ], job_state : JobState ):
@@ -943,7 +961,7 @@ def _collect_job_output(self, job_id: int, external_job_id: Optional[str], job_s
943961 which_try += 1
944962 return collect_output_success , stdout , stderr
945963
946- def finish_job (self , job_state : AsynchronousJobState ) :
964+ def finish_job (self , job_state : T ) -> None :
947965 """
948966 Get the output/error for a finished job, pass to `job_wrapper.finish`
949967 and cleanup all the job's temporary files.
0 commit comments