Skip to content

Commit a55f2a1

Browse files
authored
Merge pull request #21050 from nsoranzo/job_runners_type_annot
Improve type annotation of job runners
2 parents aa67b87 + d216ec4 commit a55f2a1

File tree

17 files changed

+299
-178
lines changed

17 files changed

+299
-178
lines changed

lib/galaxy/jobs/__init__.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Callable,
2727
Optional,
2828
TYPE_CHECKING,
29+
Union,
2930
)
3031

3132
import yaml
@@ -1019,7 +1020,7 @@ def __init__(
10191020
app: MinimalManagerApp,
10201021
use_persisted_destination: bool = False,
10211022
tool: Optional["Tool"] = None,
1022-
):
1023+
) -> None:
10231024
self.job_id = job.id
10241025
self.session_id = job.session_id
10251026
self.user_id = job.user_id
@@ -1029,7 +1030,7 @@ def __init__(
10291030
self.extra_filenames: list[str] = []
10301031
self.environment_variables: list[dict[str, str]] = []
10311032
self.interactivetools: list[dict[str, Any]] = []
1032-
self.command_line = None
1033+
self.command_line: Union[str, None] = None
10331034
self.version_command_line = None
10341035
self._dependency_shell_commands = None
10351036
# Tool versioning variables
@@ -2826,26 +2827,19 @@ class TaskWrapper(JobWrapper):
28262827

28272828
is_task = True
28282829

2829-
def __init__(self, task, queue):
2830+
def __init__(self, task: Task, queue: "BaseJobHandlerQueue") -> None:
28302831
self.task_id = task.id
28312832
super().__init__(task.job, queue)
2832-
if task.prepare_input_files_cmd is not None:
2833-
self.prepare_input_files_cmds = [task.prepare_input_files_cmd]
2834-
else:
2835-
self.prepare_input_files_cmds = None
2833+
self.prepare_input_files_cmds = (
2834+
[task.prepare_input_files_cmd] if task.prepare_input_files_cmd is not None else None
2835+
)
28362836
self.status = task.states.NEW
28372837

28382838
def can_split(self):
28392839
# Should the job handler split this job up? TaskWrapper should
28402840
# always return False as the job has already been split.
28412841
return False
28422842

2843-
def get_job(self):
2844-
if self.job_id:
2845-
return self.sa_session.get(Job, self.job_id)
2846-
else:
2847-
return None
2848-
28492843
def get_task(self):
28502844
return self.sa_session.get(Task, self.task_id)
28512845

lib/galaxy/jobs/runners/__init__.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
)
1818
from typing import (
1919
Any,
20+
Generic,
2021
Optional,
2122
TYPE_CHECKING,
23+
TypeVar,
2224
Union,
2325
)
2426

@@ -325,7 +327,7 @@ def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
325327
def stop_job(self, job_wrapper):
326328
raise NotImplementedError()
327329

328-
def recover(self, job, job_wrapper):
330+
def recover(self, job: model.Job, job_wrapper: "MinimalJobWrapper") -> None:
329331
raise NotImplementedError()
330332

331333
def build_command_line(
@@ -591,9 +593,15 @@ def _handle_runner_state(self, runner_state, job_state: "JobState"):
591593
except Exception:
592594
log.exception("Caught exception in runner state handler")
593595

594-
def fail_job(self, job_state: "JobState", exception=False, message="Job failed", full_status=None):
596+
def fail_job(
597+
self,
598+
job_state: "JobState",
599+
exception: bool = False,
600+
message: str = "Job failed",
601+
full_status: Union[dict[str, Any], None] = None,
602+
) -> None:
595603
job = job_state.job_wrapper.get_job()
596-
if getattr(job_state, "stop_job", True) and job.state != model.Job.states.NEW:
604+
if job_state.stop_job and job.state != model.Job.states.NEW:
597605
self.stop_job(job_state.job_wrapper)
598606
job_state.job_wrapper.reclaim_ownership()
599607
self._handle_runner_state("failure", job_state)
@@ -705,13 +713,14 @@ class JobState:
705713

706714
runner_states = runner_states
707715

708-
def __init__(self, job_wrapper: "JobWrapper", job_destination: "JobDestination"):
716+
def __init__(self, job_wrapper: "MinimalJobWrapper", job_destination: "JobDestination") -> None:
709717
self.runner_state_handled = False
710718
self.job_wrapper = job_wrapper
711719
self.job_destination = job_destination
712720
self.runner_state = None
713721
self.redact_email_in_job_name = True
714722
self._exit_code_file = None
723+
self.stop_job = True
715724
if self.job_wrapper:
716725
self.redact_email_in_job_name = self.job_wrapper.app.config.redact_email_in_job_name
717726

@@ -765,23 +774,26 @@ class AsynchronousJobState(JobState):
765774
to communicate with distributed resource manager.
766775
"""
767776

777+
old_state: Union["JobStateEnum", None]
778+
768779
def __init__(
769780
self,
781+
job_wrapper: "MinimalJobWrapper",
782+
job_destination: "JobDestination",
783+
*,
770784
files_dir=None,
771-
job_wrapper=None,
772785
job_id: Union[str, None] = None,
773786
job_file=None,
774787
output_file=None,
775788
error_file=None,
776789
exit_code_file=None,
777790
job_name=None,
778-
job_destination=None,
779-
):
791+
) -> None:
780792
super().__init__(job_wrapper, job_destination)
781-
self.old_state: Union[JobStateEnum, None] = None
793+
self.old_state = None
782794
self._running = False
783795
self.check_count = 0
784-
self.start_time = None
796+
self.start_time: Union[datetime.datetime, None] = None
785797

786798
# job_id is the DRM's job id, not the Galaxy job id
787799
self.job_id = job_id
@@ -796,11 +808,11 @@ def __init__(
796808
self.set_defaults(files_dir)
797809

798810
@property
799-
def running(self):
811+
def running(self) -> bool:
800812
return self._running
801813

802814
@running.setter
803-
def running(self, is_running):
815+
def running(self, is_running: bool) -> None:
804816
self._running = is_running
805817
# This will be invalid for job recovery
806818
if self.start_time is None:
@@ -834,22 +846,28 @@ def init_job_stream_files(self):
834846
pass
835847

836848

837-
class AsynchronousJobRunner(BaseJobRunner, Monitors):
849+
T = TypeVar("T", bound=AsynchronousJobState)
850+
851+
852+
class AsynchronousJobRunner(BaseJobRunner, Monitors, Generic[T]):
838853
"""Parent class for any job runner that runs jobs asynchronously (e.g. via
839854
a distributed resource manager). Provides general methods for having a
840855
thread to monitor the state of asynchronous jobs and submitting those jobs
841856
to the correct methods (queue, finish, cleanup) at appropriate times.
842857
"""
843858

859+
monitor_queue: Queue[T]
860+
watched: list[T]
861+
844862
def __init__(self, app: "GalaxyManagerApplication", nworkers: int, **kwargs) -> None:
845863
super().__init__(app, nworkers, **kwargs)
846864
# 'watched' and 'queue' are both used to keep track of jobs to watch.
847865
# 'queue' is used to add new watched jobs, and can be called from
848866
# any thread (usually by the 'queue_job' method). 'watched' must only
849867
# be modified by the monitor thread, which will move items from 'queue'
850868
# to 'watched' and then manage the watched jobs.
851-
self.watched: list[AsynchronousJobState] = []
852-
self.monitor_queue: Queue[AsynchronousJobState] = Queue()
869+
self.watched = []
870+
self.monitor_queue = Queue()
853871

854872
def _init_monitor_thread(self):
855873
name = f"{self.runner_name}.monitor_thread"
@@ -892,7 +910,7 @@ def monitor(self):
892910
# Sleep a bit before the next state check
893911
time.sleep(self.app.config.job_runner_monitor_sleep)
894912

895-
def monitor_job(self, job_state: AsynchronousJobState) -> None:
913+
def monitor_job(self, job_state: T) -> None:
896914
self.monitor_queue.put(job_state)
897915

898916
def shutdown(self):
@@ -903,7 +921,7 @@ def shutdown(self):
903921
self.shutdown_monitor()
904922
super().shutdown()
905923

906-
def check_watched_items(self):
924+
def check_watched_items(self) -> None:
907925
"""
908926
This method is responsible for iterating over self.watched and handling
909927
state changes and updating self.watched with a new list of watched job
@@ -919,7 +937,7 @@ def check_watched_items(self):
919937
self.watched = new_watched
920938

921939
# Subclasses should implement this unless they override check_watched_items altogether.
922-
def check_watched_item(self, job_state: AsynchronousJobState) -> Union[AsynchronousJobState, None]:
940+
def check_watched_item(self, job_state: T) -> Union[T, None]:
923941
raise NotImplementedError()
924942

925943
def _collect_job_output(self, job_id: int, external_job_id: Optional[str], job_state: JobState):
@@ -943,7 +961,7 @@ def _collect_job_output(self, job_id: int, external_job_id: Optional[str], job_s
943961
which_try += 1
944962
return collect_output_success, stdout, stderr
945963

946-
def finish_job(self, job_state: AsynchronousJobState):
964+
def finish_job(self, job_state: T) -> None:
947965
"""
948966
Get the output/error for a finished job, pass to `job_wrapper.finish`
949967
and cleanup all the job's temporary files.

lib/galaxy/jobs/runners/aws.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import time
1010
from queue import Empty
1111
from typing import (
12+
Any,
1213
TYPE_CHECKING,
14+
Union,
1315
)
1416

1517
from galaxy import model
@@ -81,7 +83,7 @@ def _add_resource_requirements(destination_params):
8183
return rval
8284

8385

84-
class AWSBatchJobRunner(AsynchronousJobRunner):
86+
class AWSBatchJobRunner(AsynchronousJobRunner[AsynchronousJobState]):
8587
"""
8688
This runner uses container only. It requires that an AWS EFS is mounted as a local drive
8789
and all Galaxy job-related paths, such as objects, job_directory, tool_directory and so
@@ -213,7 +215,7 @@ def __init__(self, app, nworkers, **kwargs):
213215
)
214216
self._batch_client = session.client("batch")
215217

216-
def queue_job(self, job_wrapper):
218+
def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
217219
log.debug(f"Starting queue_job for job {job_wrapper.get_id_tag()}")
218220
if not self.prepare_job(job_wrapper, include_metadata=False, modify_command_for_container=False):
219221
log.debug(f"Not ready {job_wrapper.get_id_tag()}")
@@ -225,11 +227,11 @@ def queue_job(self, job_wrapper):
225227
job_name, job_id = self._submit_job(job_def, job_wrapper, destination_params)
226228
job_wrapper.set_external_id(job_id)
227229
ajs = AsynchronousJobState(
228-
files_dir=job_wrapper.working_directory,
229230
job_wrapper=job_wrapper,
231+
job_destination=job_destination,
232+
files_dir=job_wrapper.working_directory,
230233
job_name=job_name,
231234
job_id=job_id,
232-
job_destination=job_destination,
233235
)
234236
self.monitor_queue.put(ajs)
235237

@@ -395,16 +397,16 @@ def stop_job(self, job_wrapper):
395397
msg = "Job {name!r} is terminated"
396398
log.debug(msg.format(name=job_name))
397399

398-
def recover(self, job, job_wrapper):
400+
def recover(self, job: model.Job, job_wrapper: "MinimalJobWrapper") -> None:
399401
msg = "(name!r/runner!r) is still in {state!s} state, adding to the runner monitor queue"
400402
job_id = job.get_job_runner_external_id()
401403
job_name = self.JOB_NAME_PREFIX + job_wrapper.get_id_tag()
402404
ajs = AsynchronousJobState(
403-
files_dir=job_wrapper.working_directory,
404405
job_wrapper=job_wrapper,
406+
job_destination=job_wrapper.job_destination,
407+
files_dir=job_wrapper.working_directory,
405408
job_id=str(job_id),
406409
job_name=job_name,
407-
job_destination=job_wrapper.job_destination,
408410
)
409411
if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
410412
log.debug(msg.format(name=job.id, runner=job.job_runner_name, state=job.state))
@@ -417,9 +419,9 @@ def recover(self, job, job_wrapper):
417419
ajs.running = False
418420
self.monitor_queue.put(ajs)
419421

420-
def fail_job(self, job_state: JobState, exception=False, message="Job failed", full_status=None):
422+
def fail_job(self, job_state: JobState, exception: bool = False, message: str = "Job failed", full_status: Union[dict[str, Any], None] = None) -> None:
421423
job = job_state.job_wrapper.get_job()
422-
if getattr(job_state, "stop_job", True) and job.state != model.Job.states.NEW:
424+
if job_state.stop_job and job.state != model.Job.states.NEW:
423425
self.stop_job(job_state.job_wrapper)
424426
job_state.job_wrapper.reclaim_ownership()
425427
self._handle_runner_state("failure", job_state)
@@ -460,7 +462,7 @@ def monitor(self):
460462
# Sleep a bit before the next state check
461463
time.sleep(max(self.app.config.job_runner_monitor_sleep, self.MIN_QUERY_INTERVAL))
462464

463-
def check_watched_items(self):
465+
def check_watched_items(self) -> None:
464466
done: set[str] = set()
465467
self.check_watched_items_by_batch(0, len(self.watched), done)
466468
self.watched = [ajs for ajs in self.watched if ajs.job_id not in done]

lib/galaxy/jobs/runners/chronos.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import functools
22
import logging
33
import os
4-
from typing import Union
4+
from typing import (
5+
TYPE_CHECKING,
6+
Union,
7+
)
58

69
from galaxy import model
710
from galaxy.jobs.runners import (
@@ -10,6 +13,9 @@
1013
)
1114
from galaxy.util import unicodify
1215

16+
if TYPE_CHECKING:
17+
from galaxy.jobs import MinimalJobWrapper
18+
1319
CHRONOS_IMPORT_MSG = (
1420
"The Python 'chronos' package is required to use "
1521
"this feature, please install it or correct the "
@@ -82,7 +88,7 @@ def _add_galaxy_environment_variables(cpus, memory):
8288
return [{"name": "GALAXY_SLOTS", "value": cpus}, {"name": "GALAXY_MEMORY_MB", "value": memory}]
8389

8490

85-
class ChronosJobRunner(AsynchronousJobRunner):
91+
class ChronosJobRunner(AsynchronousJobRunner[AsynchronousJobState]):
8692
runner_name = "ChronosRunner"
8793
RUNNER_PARAM_SPEC_KEY = "runner_param_specs"
8894
JOB_NAME_PREFIX = "galaxy-chronos-"
@@ -148,7 +154,7 @@ def __init__(self, app, nworkers, **kwargs):
148154
)
149155

150156
@handle_exception_call
151-
def queue_job(self, job_wrapper):
157+
def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
152158
LOGGER.debug(f"Starting queue_job for job {job_wrapper.get_id_tag()}")
153159
if not self.prepare_job(job_wrapper, include_metadata=False, modify_command_for_container=False):
154160
LOGGER.debug(f"Not ready {job_wrapper.get_id_tag()}")
@@ -158,10 +164,10 @@ def queue_job(self, job_wrapper):
158164
job_name = chronos_job_spec["name"]
159165
self._chronos_client.add(chronos_job_spec)
160166
ajs = AsynchronousJobState(
161-
files_dir=job_wrapper.working_directory,
162167
job_wrapper=job_wrapper,
163-
job_id=job_name,
164168
job_destination=job_destination,
169+
files_dir=job_wrapper.working_directory,
170+
job_id=job_name,
165171
)
166172
self.monitor_queue.put(ajs)
167173

@@ -178,16 +184,15 @@ def stop_job(self, job_wrapper):
178184
msg = "Job {name!r} not found. It cannot be terminated."
179185
LOGGER.error(msg.format(name=job_name))
180186

181-
def recover(self, job, job_wrapper):
187+
def recover(self, job: model.Job, job_wrapper: "MinimalJobWrapper") -> None:
182188
msg = "(name!r/runner!r) is still in {state!s} state, adding to the runner monitor queue"
183189
job_id = job.get_job_runner_external_id()
184190
ajs = AsynchronousJobState(
185-
files_dir=job_wrapper.working_directory,
186191
job_wrapper=job_wrapper,
187-
job_id=self.JOB_NAME_PREFIX + str(job_id),
188192
job_destination=job_wrapper.job_destination,
193+
files_dir=job_wrapper.working_directory,
194+
job_id=self.JOB_NAME_PREFIX + str(job_id),
189195
)
190-
ajs.command_line = job.command_line
191196
if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
192197
LOGGER.debug(msg.format(name=job.id, runner=job.job_runner_external_id, state=job.state))
193198
ajs.old_state = model.Job.states.RUNNING
@@ -241,14 +246,14 @@ def _mark_as_active(self, job_state: AsynchronousJobState) -> AsynchronousJobSta
241246
def _mark_as_failed(self, job_state: AsynchronousJobState, reason: str) -> None:
242247
_write_logfile(job_state.error_file, reason)
243248
job_state.running = False
244-
job_state.stop_job = True
249+
job_state.stop_job = False
245250
job_state.job_wrapper.change_state(model.Job.states.ERROR)
246251
job_state.fail_message = reason
247252
self.mark_as_failed(job_state)
248253
return None
249254

250255
@handle_exception_call
251-
def finish_job(self, job_state):
256+
def finish_job(self, job_state: AsynchronousJobState) -> None:
252257
super().finish_job(job_state)
253258
self._chronos_client.delete(job_state.job_id)
254259

0 commit comments

Comments
 (0)