Commit b94f185

Merge branch 'master' into hash-change-guards
2 parents: 4e1d4a8 + 0e66136

10 files changed: 538 additions, 94 deletions


pydra/engine/helpers.py

Lines changed: 3 additions & 1 deletion
@@ -263,7 +263,9 @@ def make_klass(spec):
                 **kwargs,
             )
             checker_label = f"'{name}' field of {spec.name}"
-            type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
+            type_checker = TypeParser[newfield.type](
+                newfield.type, label=checker_label, superclass_auto_cast=True
+            )
             if newfield.type in (MultiInputObj, MultiInputFile):
                 converter = attr.converters.pipe(ensure_list, type_checker)
             elif newfield.type in (MultiOutputObj, MultiOutputFile):
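As an aside (not part of the commit), the same kind of checker can be built by hand. A minimal sketch, assuming TypeParser is importable from pydra.utils.typing and that superclass_auto_cast=True relaxes checking so that a value declared with a broader superclass type is cast rather than rejected (a reading of the flag name, not a documented guarantee):

from pathlib import Path
from fileformats.generic import File
from pydra.utils.typing import TypeParser

# Construct a checker the way make_klass() now does; the field and spec names
# ('in_file', ExampleSpec) are illustrative only.
checker = TypeParser[File](
    File,
    label="'in_file' field of ExampleSpec",
    superclass_auto_cast=True,  # assumption about the flag's effect
)

path = Path("data.txt")
path.write_text("hello")
in_file = checker.coerce(path)  # coerce the plain Path into a fileformats File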

pydra/engine/specs.py

Lines changed: 7 additions & 4 deletions
@@ -449,16 +449,19 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
                 ),
             ):
                 raise TypeError(
-                    f"Support for {fld.type} type, required for {fld.name} in {self}, "
+                    f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
                     "has not been implemented in collect_additional_output"
                 )
             # assuming that field should have either default or metadata, but not both
             input_value = getattr(inputs, fld.name, attr.NOTHING)
             if input_value is not attr.NOTHING:
                 if TypeParser.contains_type(FileSet, fld.type):
-                    label = f"output field '{fld.name}' of {self}"
-                    input_value = TypeParser(fld.type, label=label).coerce(input_value)
-                additional_out[fld.name] = input_value
+                    if input_value is not False:
+                        label = f"output field '{fld.name}' of {self}"
+                        input_value = TypeParser(fld.type, label=label).coerce(
+                            input_value
+                        )
+                        additional_out[fld.name] = input_value
             elif (
                 fld.default is None or fld.default == attr.NOTHING
             ) and not fld.metadata:  # TODO: is it right?
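The new `input_value is not False` guard is what lets an output-file field be declared optional by typing it as ty.Union[Path, bool] with a False default: when the field is left at False, coercion to a FileSet is skipped and no output is collected for it (this is exercised by the new shell-task tests further down). A minimal sketch of such a field declaration, with illustrative names ('out_file', '--out'):

import typing as ty
from pathlib import Path

import attr
from pydra.engine.specs import ShellSpec, SpecInfo

# When 'out_file' is left at its False default, no '--out' argument should be
# added to the command and the guard above skips collecting the missing file.
input_spec = SpecInfo(
    name="Input",
    fields=[
        (
            "out_file",
            attr.ib(
                type=ty.Union[Path, bool],
                default=False,
                metadata={
                    "argstr": "--out",
                    "output_file_template": "out.txt",
                    "help_string": "optional output file",
                },
            ),
        ),
    ],
    bases=(ShellSpec,),
)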

pydra/engine/submitter.py

Lines changed: 20 additions & 9 deletions
@@ -1,9 +1,10 @@
 """Handle execution backends."""
 
 import asyncio
+import typing as ty
 import pickle
 from uuid import uuid4
-from .workers import WORKERS
+from .workers import Worker, WORKERS
 from .core import is_workflow
 from .helpers import get_open_loop, load_and_run_async
 
@@ -16,24 +17,34 @@
 class Submitter:
     """Send a task to the execution backend."""
 
-    def __init__(self, plugin="cf", **kwargs):
+    def __init__(self, plugin: ty.Union[str, ty.Type[Worker]] = "cf", **kwargs):
         """
         Initialize task submission.
 
         Parameters
         ----------
-        plugin : :obj:`str`
-            The identifier of the execution backend.
+        plugin : :obj:`str` or :obj:`ty.Type[pydra.engine.core.Worker]`
+            Either the identifier of the execution backend or the worker class itself.
             Default is ``cf`` (Concurrent Futures).
+        **kwargs
+            Additional keyword arguments to pass to the worker.
 
         """
         self.loop = get_open_loop()
         self._own_loop = not self.loop.is_running()
-        self.plugin = plugin
-        try:
-            self.worker = WORKERS[self.plugin](**kwargs)
-        except KeyError:
-            raise NotImplementedError(f"No worker for {self.plugin}")
+        if isinstance(plugin, str):
+            self.plugin = plugin
+            try:
+                worker_cls = WORKERS[self.plugin]
+            except KeyError:
+                raise NotImplementedError(f"No worker for '{self.plugin}' plugin")
+        else:
+            try:
+                self.plugin = plugin.plugin_name
+            except AttributeError:
+                raise ValueError("Worker class must have a 'plugin_name' str attribute")
+            worker_cls = plugin
+        self.worker = worker_cls(**kwargs)
         self.worker.loop = self.loop
 
     def __call__(self, runnable, cache_locations=None, rerun=False, environment=None):
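In practice the new branch means `plugin` can be a worker class as well as a plugin name, provided the class carries a `plugin_name` string attribute; this is exactly what the new tests in test_submitter.py exercise. A minimal sketch based on those tests:

from pydra import mark
from pydra.engine.submitter import Submitter
from pydra.engine.workers import SerialWorker


class MyWorker(SerialWorker):
    # Required by Submitter when a class rather than a plugin string is passed.
    plugin_name = "my_worker"


@mark.task
def double(x: int) -> int:
    return x * 2


task = double(x=3)
with Submitter(plugin=MyWorker) as sub:
    result = task(submitter=sub)
assert result.output.out == 6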

pydra/engine/tests/test_node_task.py

Lines changed: 1 addition & 15 deletions
@@ -133,21 +133,7 @@ def test_task_init_3a(
 
 
 def test_task_init_4():
-    """task with interface and inputs. splitter set using split method"""
-    nn = fun_addtwo(name="NA")
-    nn.split(splitter="a", a=[3, 5])
-    assert np.allclose(nn.inputs.a, [3, 5])
-
-    assert nn.state.splitter == "NA.a"
-    assert nn.state.splitter_rpn == ["NA.a"]
-
-    nn.state.prepare_states(nn.inputs)
-    assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
-    assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}]
-
-
-def test_task_init_4a():
-    """task with a splitter and inputs set in the split method"""
+    """task with interface splitter and inputs set in the split method"""
     nn = fun_addtwo(name="NA")
     nn.split(splitter="a", a=[3, 5])
     assert np.allclose(nn.inputs.a, [3, 5])

pydra/engine/tests/test_shelltask.py

Lines changed: 96 additions & 0 deletions
@@ -4347,6 +4347,102 @@ def change_name(file):
     # res = shelly(plugin="cf")
 
 
+def test_shell_cmd_optional_output_file1(tmp_path):
+    """
+    Test to see that 'unused' doesn't complain about not having an output passed to it
+    """
+    my_cp_spec = SpecInfo(
+        name="Input",
+        fields=[
+            (
+                "input",
+                attr.ib(
+                    type=File, metadata={"argstr": "", "help_string": "input file"}
+                ),
+            ),
+            (
+                "output",
+                attr.ib(
+                    type=Path,
+                    metadata={
+                        "argstr": "",
+                        "output_file_template": "out.txt",
+                        "help_string": "output file",
+                    },
+                ),
+            ),
+            (
+                "unused",
+                attr.ib(
+                    type=ty.Union[Path, bool],
+                    default=False,
+                    metadata={
+                        "argstr": "--not-used",
+                        "output_file_template": "out.txt",
+                        "help_string": "dummy output",
+                    },
+                ),
+            ),
+        ],
+        bases=(ShellSpec,),
+    )
+
+    my_cp = ShellCommandTask(
+        name="my_cp",
+        executable="cp",
+        input_spec=my_cp_spec,
+    )
+    file1 = tmp_path / "file1.txt"
+    file1.write_text("foo")
+    result = my_cp(input=file1, unused=False)
+    assert result.output.output.fspath.read_text() == "foo"
+
+
+def test_shell_cmd_optional_output_file2(tmp_path):
+    """
+    Test to see that 'unused' doesn't complain about not having an output passed to it
+    """
+    my_cp_spec = SpecInfo(
+        name="Input",
+        fields=[
+            (
+                "input",
+                attr.ib(
+                    type=File, metadata={"argstr": "", "help_string": "input file"}
+                ),
+            ),
+            (
+                "output",
+                attr.ib(
+                    type=ty.Union[Path, bool],
+                    default=False,
+                    metadata={
+                        "argstr": "",
+                        "output_file_template": "out.txt",
+                        "help_string": "dummy output",
+                    },
+                ),
+            ),
+        ],
+        bases=(ShellSpec,),
+    )
+
+    my_cp = ShellCommandTask(
+        name="my_cp",
+        executable="cp",
+        input_spec=my_cp_spec,
+    )
+    file1 = tmp_path / "file1.txt"
+    file1.write_text("foo")
+    result = my_cp(input=file1, output=True)
+    assert result.output.output.fspath.read_text() == "foo"
+
+    file2 = tmp_path / "file2.txt"
+    file2.write_text("bar")
+    with pytest.raises(RuntimeError):
+        my_cp(input=file2, output=False)
+
+
 def test_shell_cmd_non_existing_outputs_1(tmp_path):
     """Checking that non existing output files do not return a phantom path,
     but return NOTHING instead"""

pydra/engine/tests/test_submitter.py

Lines changed: 66 additions & 1 deletion
@@ -6,6 +6,8 @@
 import attrs
 import typing as ty
 from random import randint
+import os
+from unittest.mock import patch
 import pytest
 from fileformats.generic import Directory
 from .utils import (
@@ -15,8 +17,9 @@
     gen_basic_wf_with_threadcount,
     gen_basic_wf_with_threadcount_concurrent,
 )
-from ..core import Workflow
+from ..core import Workflow, TaskBase
 from ..submitter import Submitter
+from ..workers import SerialWorker
 from ... import mark
 from pathlib import Path
 from datetime import datetime
@@ -665,3 +668,65 @@ def to_tuple(x, y):
     ):
         with Submitter("cf") as sub:
             result = sub(wf)
+
+@mark.task
+def to_tuple(x, y):
+    return (x, y)
+
+
+class BYOAddVarWorker(SerialWorker):
+    """A dummy worker that adds 1 to the output of the task"""
+
+    plugin_name = "byo_add_env_var"
+
+    def __init__(self, add_var, **kwargs):
+        super().__init__(**kwargs)
+        self.add_var = add_var
+
+    async def exec_serial(self, runnable, rerun=False, environment=None):
+        if isinstance(runnable, TaskBase):
+            with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}):
+                result = runnable._run(rerun, environment=environment)
+            return result
+        else:  # it could be tuple that includes pickle files with tasks and inputs
+            return super().exec_serial(runnable, rerun, environment)
+
+
+@mark.task
+def add_env_var_task(x: int) -> int:
+    return x + int(os.environ.get("BYO_ADD_VAR", 0))
+
+
+def test_byo_worker():
+
+    task1 = add_env_var_task(x=1)
+
+    with Submitter(plugin=BYOAddVarWorker, add_var=10) as sub:
+        assert sub.plugin == "byo_add_env_var"
+        result = task1(submitter=sub)
+
+    assert result.output.out == 11
+
+    task2 = add_env_var_task(x=2)
+
+    with Submitter(plugin="serial") as sub:
+        result = task2(submitter=sub)
+
+    assert result.output.out == 2
+
+
+def test_bad_builtin_worker():
+
+    with pytest.raises(NotImplementedError, match="No worker for 'bad-worker' plugin"):
+        Submitter(plugin="bad-worker")
+
+
+def test_bad_byo_worker():
+
+    class BadWorker:
+        pass
+
+    with pytest.raises(
+        ValueError, match="Worker class must have a 'plugin_name' str attribute"
+    ):
+        Submitter(plugin=BadWorker)
