Commit 6f12146

ENH: Add the ability to bring your own worker (#733)
* add the ability to bring your own (BYO) worker, i.e. without needing to monkey-patch the pydra.engine.workers.WORKERS dict
* changed BYO plugins to be classes rather than instances
* touch up
* added test to catch missing patch lines
* touch up
1 parent 1858668 commit 6f12146
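
For orientation, a minimal sketch of the pattern this commit enables, adapted from the test added below; the MyWorker class and the "my_worker" name are illustrative, not part of the commit:

import os

from pydra import mark
from pydra.engine.submitter import Submitter
from pydra.engine.workers import SerialWorker


class MyWorker(SerialWorker):
    """Custom worker passed directly to Submitter; no WORKERS entry needed."""

    plugin_name = "my_worker"  # required; Submitter raises ValueError if missing


@mark.task
def identity(x: int) -> int:
    return x


task = identity(x=1)

# The worker class itself is passed as the plugin; extra kwargs go to its __init__.
with Submitter(plugin=MyWorker) as sub:
    result = task(submitter=sub)
# result.output.out == 1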

File tree

3 files changed: +116, -29 lines

    pydra/engine/submitter.py
    pydra/engine/tests/test_submitter.py
    pydra/engine/workers.py
pydra/engine/submitter.py

Lines changed: 20 additions & 9 deletions
@@ -1,9 +1,10 @@
 """Handle execution backends."""

 import asyncio
+import typing as ty
 import pickle
 from uuid import uuid4
-from .workers import WORKERS
+from .workers import Worker, WORKERS
 from .core import is_workflow
 from .helpers import get_open_loop, load_and_run_async

@@ -16,24 +17,34 @@
 class Submitter:
     """Send a task to the execution backend."""

-    def __init__(self, plugin="cf", **kwargs):
+    def __init__(self, plugin: ty.Union[str, ty.Type[Worker]] = "cf", **kwargs):
         """
         Initialize task submission.

         Parameters
         ----------
-        plugin : :obj:`str`
-            The identifier of the execution backend.
+        plugin : :obj:`str` or :obj:`ty.Type[pydra.engine.core.Worker]`
+            Either the identifier of the execution backend or the worker class itself.
             Default is ``cf`` (Concurrent Futures).
+        **kwargs
+            Additional keyword arguments to pass to the worker.

         """
         self.loop = get_open_loop()
         self._own_loop = not self.loop.is_running()
-        self.plugin = plugin
-        try:
-            self.worker = WORKERS[self.plugin](**kwargs)
-        except KeyError:
-            raise NotImplementedError(f"No worker for {self.plugin}")
+        if isinstance(plugin, str):
+            self.plugin = plugin
+            try:
+                worker_cls = WORKERS[self.plugin]
+            except KeyError:
+                raise NotImplementedError(f"No worker for '{self.plugin}' plugin")
+        else:
+            try:
+                self.plugin = plugin.plugin_name
+            except AttributeError:
+                raise ValueError("Worker class must have a 'plugin_name' str attribute")
+            worker_cls = plugin
+        self.worker = worker_cls(**kwargs)
         self.worker.loop = self.loop

     def __call__(self, runnable, cache_locations=None, rerun=False, environment=None):
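
In short, after this change both call forms below are accepted (small illustration; ConcurrentFuturesWorker is the existing built-in "cf" worker):

from pydra.engine.submitter import Submitter
from pydra.engine.workers import ConcurrentFuturesWorker

Submitter(plugin="cf")                     # string: worker class looked up in WORKERS
Submitter(plugin=ConcurrentFuturesWorker)  # class: plugin name read from plugin_name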

pydra/engine/tests/test_submitter.py

Lines changed: 62 additions & 1 deletion
@@ -2,6 +2,8 @@
 import re
 import subprocess as sp
 import time
+import os
+from unittest.mock import patch

 import pytest

@@ -12,8 +14,9 @@
     gen_basic_wf_with_threadcount,
     gen_basic_wf_with_threadcount_concurrent,
 )
-from ..core import Workflow
+from ..core import Workflow, TaskBase
 from ..submitter import Submitter
+from ..workers import SerialWorker
 from ... import mark
 from pathlib import Path
 from datetime import datetime
@@ -612,3 +615,61 @@ def alter_input(x):
 @mark.task
 def to_tuple(x, y):
     return (x, y)
+
+
+class BYOAddVarWorker(SerialWorker):
+    """A dummy worker that adds 1 to the output of the task"""
+
+    plugin_name = "byo_add_env_var"
+
+    def __init__(self, add_var, **kwargs):
+        super().__init__(**kwargs)
+        self.add_var = add_var
+
+    async def exec_serial(self, runnable, rerun=False, environment=None):
+        if isinstance(runnable, TaskBase):
+            with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}):
+                result = runnable._run(rerun, environment=environment)
+            return result
+        else:  # it could be tuple that includes pickle files with tasks and inputs
+            return super().exec_serial(runnable, rerun, environment)
+
+
+@mark.task
+def add_env_var_task(x: int) -> int:
+    return x + int(os.environ.get("BYO_ADD_VAR", 0))
+
+
+def test_byo_worker():
+
+    task1 = add_env_var_task(x=1)
+
+    with Submitter(plugin=BYOAddVarWorker, add_var=10) as sub:
+        assert sub.plugin == "byo_add_env_var"
+        result = task1(submitter=sub)
+
+    assert result.output.out == 11
+
+    task2 = add_env_var_task(x=2)
+
+    with Submitter(plugin="serial") as sub:
+        result = task2(submitter=sub)
+
+    assert result.output.out == 2
+
+
+def test_bad_builtin_worker():
+
+    with pytest.raises(NotImplementedError, match="No worker for 'bad-worker' plugin"):
+        Submitter(plugin="bad-worker")
+
+
+def test_bad_byo_worker():
+
+    class BadWorker:
+        pass
+
+    with pytest.raises(
+        ValueError, match="Worker class must have a 'plugin_name' str attribute"
+    ):
+        Submitter(plugin=BadWorker)

pydra/engine/workers.py

Lines changed: 34 additions & 19 deletions
@@ -128,6 +128,8 @@ async def fetch_finished(self, futures):
 class SerialWorker(Worker):
     """A worker to execute linearly."""

+    plugin_name = "serial"
+
     def __init__(self, **kwargs):
         """Initialize worker."""
         logger.debug("Initialize SerialWorker")
@@ -157,6 +159,8 @@ async def fetch_finished(self, futures):
 class ConcurrentFuturesWorker(Worker):
     """A worker to execute in parallel using Python's concurrent futures."""

+    plugin_name = "cf"
+
     def __init__(self, n_procs=None):
         """Initialize Worker."""
         super().__init__()
@@ -192,6 +196,7 @@ def close(self):
 class SlurmWorker(DistributedWorker):
     """A worker to execute tasks on SLURM systems."""

+    plugin_name = "slurm"
     _cmd = "sbatch"
     _sacct_re = re.compile(
         "(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
@@ -367,6 +372,8 @@ async def _verify_exit_code(self, jobid):
 class SGEWorker(DistributedWorker):
     """A worker to execute tasks on SLURM systems."""

+    plugin_name = "sge"
+
     _cmd = "qsub"
     _sacct_re = re.compile(
         "(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
@@ -860,6 +867,8 @@ class DaskWorker(Worker):
     This is an experimental implementation with limited testing.
     """

+    plugin_name = "dask"
+
     def __init__(self, **kwargs):
         """Initialize Worker."""
         super().__init__()
@@ -898,7 +907,7 @@ def close(self):
 class PsijWorker(Worker):
     """A worker to execute tasks using PSI/J."""

-    def __init__(self, subtype, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initialize PsijWorker.

@@ -915,15 +924,6 @@ def __init__(self, subtype, **kwargs):
         logger.debug("Initialize PsijWorker")
         self.psij = psij

-        # Check if the provided subtype is valid
-        valid_subtypes = ["local", "slurm"]
-        if subtype not in valid_subtypes:
-            raise ValueError(
-                f"Invalid 'subtype' provided. Available options: {', '.join(valid_subtypes)}"
-            )
-
-        self.subtype = subtype
-
     def run_el(self, interface, rerun=False, **kwargs):
         """Run a task."""
         return self.exec_psij(interface, rerun=rerun)
@@ -1039,14 +1039,29 @@ def close(self):
         pass


+class PsijLocalWorker(PsijWorker):
+    """A worker to execute tasks using PSI/J on the local machine."""
+
+    subtype = "local"
+    plugin_name = f"psij-{subtype}"
+
+
+class PsijSlurmWorker(PsijWorker):
+    """A worker to execute tasks using PSI/J using SLURM."""
+
+    subtype = "slurm"
+    plugin_name = f"psij-{subtype}"
+
+
 WORKERS = {
-    "serial": SerialWorker,
-    "cf": ConcurrentFuturesWorker,
-    "slurm": SlurmWorker,
-    "dask": DaskWorker,
-    "sge": SGEWorker,
-    **{
-        "psij-" + subtype: lambda subtype=subtype: PsijWorker(subtype=subtype)
-        for subtype in ["local", "slurm"]
-    },
+    w.plugin_name: w
+    for w in (
+        SerialWorker,
+        ConcurrentFuturesWorker,
+        SlurmWorker,
+        DaskWorker,
+        SGEWorker,
+        PsijLocalWorker,
+        PsijSlurmWorker,
+    )
 }
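
A small check of the new registration scheme (sketch, assuming pydra.engine.workers imports as above): the PSI/J plugin strings now map to dedicated subclasses instead of a lambda around PsijWorker.

from pydra.engine.workers import WORKERS, PsijLocalWorker, PsijSlurmWorker

# Each entry is keyed by the worker class's plugin_name attribute.
assert WORKERS["psij-local"] is PsijLocalWorker
assert WORKERS["psij-slurm"] is PsijSlurmWorker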
