Skip to content

Commit 86671ad

Browse files
authored
Merge pull request #573 from NicolasGensollen/fix-SpecInfo-in-Worflow-constructor
[FIX] Enable `Workflow` constructor to receive `SpecInfo` objects for `input_spec` parameter
2 parents 219c721 + 0c77ad5 commit 86671ad

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

pydra/engine/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,12 @@ def __init__(
843843
if input_spec:
844844
if isinstance(input_spec, BaseSpec):
845845
self.input_spec = input_spec
846+
elif isinstance(input_spec, SpecInfo):
847+
if not any([x == BaseSpec for x in input_spec.bases]):
848+
raise ValueError(
849+
"Provided SpecInfo must have BaseSpec as it's base."
850+
)
851+
self.input_spec = input_spec
846852
else:
847853
self.input_spec = SpecInfo(
848854
name="Inputs",

pydra/engine/tests/test_workflow.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,44 @@
3030
from ..submitter import Submitter
3131
from ..core import Workflow
3232
from ... import mark
33+
from ..specs import SpecInfo, BaseSpec, ShellSpec
3334

3435

3536
def test_wf_no_input_spec():
3637
with pytest.raises(ValueError, match="Empty input_spec"):
3738
Workflow(name="workflow")
3839

3940

41+
def test_wf_specinfo_input_spec():
42+
input_spec = SpecInfo(
43+
name="Input",
44+
fields=[
45+
("a", str, "", {"mandatory": True}),
46+
("b", dict, {"foo": 1, "bar": False}, {"mandatory": False}),
47+
],
48+
bases=(BaseSpec,),
49+
)
50+
wf = Workflow(
51+
name="workflow",
52+
input_spec=input_spec,
53+
)
54+
for x in ["a", "b"]:
55+
assert hasattr(wf.inputs, x)
56+
assert wf.inputs.a == ""
57+
assert wf.inputs.b == {"foo": 1, "bar": False}
58+
bad_input_spec = SpecInfo(
59+
name="Input",
60+
fields=[
61+
("a", str, {"mandatory": True}),
62+
],
63+
bases=(ShellSpec,),
64+
)
65+
with pytest.raises(
66+
ValueError, match="Provided SpecInfo must have BaseSpec as it's base."
67+
):
68+
Workflow(name="workflow", input_spec=bad_input_spec)
69+
70+
4071
def test_wf_name_conflict1():
4172
"""raise error when workflow name conflicts with a class attribute or method"""
4273
with pytest.raises(ValueError) as excinfo1:

0 commit comments

Comments
 (0)