Skip to content

Commit 8c8a79c

Browse files
authored
Merge pull request #584 from NicolasGensollen/add-graph-checksums-if-missing
[ENH] Add `_graph_checksums` to `input_spec` if missing
2 parents 9053ba7 + a7374aa commit 8c8a79c

File tree

2 files changed

+61
-31
lines changed

2 files changed

+61
-31
lines changed

pydra/engine/core.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,64 @@ def _reset(self):
797797
task._reset()
798798

799799

800+
def _sanitize_input_spec(
801+
input_spec: ty.Union[SpecInfo, ty.List[str]],
802+
wf_name: str,
803+
) -> SpecInfo:
804+
"""Makes sure the provided input specifications are valid.
805+
806+
If the input specification is a list of strings, this will
807+
build a proper SpecInfo object out of it.
808+
809+
Parameters
810+
----------
811+
input_spec : SpecInfo or List[str]
812+
Input specification to be sanitized.
813+
814+
wf_name : str
815+
The name of the workflow for which the input specifications
816+
are sanitized.
817+
818+
Returns
819+
-------
820+
input_spec : SpecInfo
821+
Sanitized input specifications.
822+
823+
Raises
824+
------
825+
ValueError
826+
If provided `input_spec` is None.
827+
"""
828+
graph_checksum_input = ("_graph_checksums", ty.Any)
829+
if input_spec:
830+
if isinstance(input_spec, SpecInfo):
831+
if not any([x == BaseSpec for x in input_spec.bases]):
832+
raise ValueError("Provided SpecInfo must have BaseSpec as it's base.")
833+
if "_graph_checksums" not in {f[0] for f in input_spec.fields}:
834+
input_spec.fields.insert(0, graph_checksum_input)
835+
return input_spec
836+
else:
837+
return SpecInfo(
838+
name="Inputs",
839+
fields=[graph_checksum_input]
840+
+ [
841+
(
842+
nm,
843+
attr.ib(
844+
type=ty.Any,
845+
metadata={
846+
"help_string": f"{nm} input from {wf_name} workflow"
847+
},
848+
),
849+
)
850+
for nm in input_spec
851+
],
852+
bases=(BaseSpec,),
853+
)
854+
else:
855+
raise ValueError(f"Empty input_spec provided to Workflow {wf_name}.")
856+
857+
800858
class Workflow(TaskBase):
801859
"""A composite task with structure of computational graph."""
802860

@@ -806,7 +864,7 @@ def __init__(
806864
audit_flags: AuditFlag = AuditFlag.NONE,
807865
cache_dir=None,
808866
cache_locations=None,
809-
input_spec: ty.Optional[ty.Union[ty.List[ty.Text], SpecInfo, BaseSpec]] = None,
867+
input_spec: ty.Optional[ty.Union[ty.List[ty.Text], SpecInfo]] = None,
810868
cont_dim=None,
811869
messenger_args=None,
812870
messengers=None,
@@ -842,35 +900,7 @@ def __init__(
842900
TODO
843901
844902
"""
845-
if input_spec:
846-
if isinstance(input_spec, BaseSpec):
847-
self.input_spec = input_spec
848-
elif isinstance(input_spec, SpecInfo):
849-
if not any([x == BaseSpec for x in input_spec.bases]):
850-
raise ValueError(
851-
"Provided SpecInfo must have BaseSpec as it's base."
852-
)
853-
self.input_spec = input_spec
854-
else:
855-
self.input_spec = SpecInfo(
856-
name="Inputs",
857-
fields=[("_graph_checksums", ty.Any)]
858-
+ [
859-
(
860-
nm,
861-
attr.ib(
862-
type=ty.Any,
863-
metadata={
864-
"help_string": f"{nm} input from {name} workflow"
865-
},
866-
),
867-
)
868-
for nm in input_spec
869-
],
870-
bases=(BaseSpec,),
871-
)
872-
else:
873-
raise ValueError("Empty input_spec provided to Workflow")
903+
self.input_spec = _sanitize_input_spec(input_spec, name)
874904

875905
self.output_spec = output_spec
876906

pydra/engine/tests/test_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_wf_specinfo_input_spec():
5151
name="workflow",
5252
input_spec=input_spec,
5353
)
54-
for x in ["a", "b"]:
54+
for x in ["a", "b", "_graph_checksums"]:
5555
assert hasattr(wf.inputs, x)
5656
assert wf.inputs.a == ""
5757
assert wf.inputs.b == {"foo": 1, "bar": False}

0 commit comments

Comments
 (0)