Skip to content

Commit e57bd57

Browse files
authored
Merge pull request #258 from djarecka/fix/wf_set_output
[ fix] adding wf._connection to the checksum (closes #253, closes #252)
2 parents 194a3ba + 5ce0fe7 commit e57bd57

File tree

4 files changed

+382
-14
lines changed

4 files changed

+382
-14
lines changed

pydra/engine/core.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,15 @@ def version(self):
219219

220220
@property
221221
def checksum(self):
222-
"""Calculate a unique checksum of this task."""
223-
# if checksum is called before run the _graph_checksums is not ready
224-
if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING:
225-
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
226-
222+
""" Calculates the unique checksum of the task.
223+
Used to create specific directory name for task that are run;
224+
and to create nodes checksums needed for graph checksums
225+
(before the tasks have inputs etc.)
226+
"""
227227
input_hash = self.inputs.hash
228228
if self.state is None:
229229
self._checksum = create_checksum(self.__class__.__name__, input_hash)
230230
else:
231-
# including splitter in the hash
232231
splitter_hash = hash_function(self.state.splitter)
233232
self._checksum = create_checksum(
234233
self.__class__.__name__, hash_function([input_hash, splitter_hash])
@@ -237,10 +236,9 @@ def checksum(self):
237236

238237
def checksum_states(self, state_index=None):
239238
"""
240-
Calculate a checksum for the specific state or all of the states.
241-
239+
Calculate a checksum for the specific state or all of the states of the task.
242240
Replaces lists in the inputs fields with a specific values for states.
243-
Can be used only for tasks with a state.
241+
Used to recreate names of the task directories.
244242
245243
Parameters
246244
----------
@@ -259,7 +257,14 @@ def checksum_states(self, state_index=None):
259257
getattr(inputs_copy, key.split(".")[1])[ind],
260258
)
261259
input_hash = inputs_copy.hash
262-
checksum_ind = create_checksum(self.__class__.__name__, input_hash)
260+
if is_workflow(self):
261+
con_hash = hash_function(self._connections)
262+
hash_list = [input_hash, con_hash]
263+
checksum_ind = create_checksum(
264+
self.__class__.__name__, self._checksum_wf(input_hash)
265+
)
266+
else:
267+
checksum_ind = create_checksum(self.__class__.__name__, input_hash)
263268
return checksum_ind
264269
else:
265270
checksum_list = []
@@ -753,6 +758,41 @@ def graph_sorted(self):
753758
"""Get a sorted graph representation of the workflow."""
754759
return self.graph.sorted_nodes
755760

761+
@property
def checksum(self):
    """Compute, store and return the checksum of this workflow.

    The checksum is used to name the directory of a task that is run,
    and to build the node checksums needed for the graph checksums
    (before the tasks have their inputs set).

    Returns
    -------
    The value produced by ``create_checksum`` for the class name and the
    workflow hash; it is also stored on ``self._checksum``.
    """
    # The checksum can be requested before the workflow is run, in which
    # case the per-node graph checksums have not been filled in yet.
    if is_workflow(self):
        if self.inputs._graph_checksums is attr.NOTHING:
            self.inputs._graph_checksums = [
                node.checksum for node in self.graph_sorted
            ]

    inputs_hash = self.inputs.hash
    # The splitter only contributes to the hash when the workflow has a state.
    use_splitter = bool(self.state)
    workflow_hash = self._checksum_wf(inputs_hash, with_splitter=use_splitter)
    self._checksum = create_checksum(self.__class__.__name__, workflow_hash)
    return self._checksum
783+
784+
def _checksum_wf(self, input_hash, with_splitter=False):
    """Create the hash value used for a workflow checksum.

    Parameters
    ----------
    input_hash
        Hash of the workflow inputs.
    with_splitter : bool
        If True and the workflow has a state, the hash of the state's
        splitter is included as well.

    Returns
    -------
    The ``hash_function`` value of the list made of the input hash, the
    hash of ``self._connections`` and, optionally, the splitter hash.
    """
    parts = [input_hash, hash_function(self._connections)]
    if with_splitter and self.state:
        # including splitter in the hash
        parts.append(hash_function(self.state.splitter))
    return hash_function(parts)
795+
756796
def add(self, task):
757797
"""
758798
Add a task to the workflow.
@@ -887,18 +927,29 @@ def set_output(self, connections):
887927
TODO
888928
889929
"""
930+
if self._connections is None:
931+
self._connections = []
890932
if isinstance(connections, tuple) and len(connections) == 2:
891-
self._connections = [connections]
933+
new_connections = [connections]
892934
elif isinstance(connections, list) and all(
893935
[len(el) == 2 for el in connections]
894936
):
895-
self._connections = connections
937+
new_connections = connections
896938
elif isinstance(connections, dict):
897-
self._connections = list(connections.items())
939+
new_connections = list(connections.items())
898940
else:
899941
raise Exception(
900942
"Connections can be a 2-elements tuple, a list of these tuples, or dictionary"
901943
)
944+
# checking if a new output name is already in the connections
945+
connection_names = [name for name, _ in self._connections]
946+
new_names = [name for name, _ in new_connections]
947+
if set(connection_names).intersection(new_names):
948+
raise Exception(
949+
f"output name {set(connection_names).intersection(new_names)} is already set"
950+
)
951+
952+
self._connections += new_connections
902953
fields = [(name, ty.Any) for name, _ in self._connections]
903954
self.output_spec = SpecInfo(name="Output", fields=fields, bases=(BaseSpec,))
904955
logger.info("Added %s to %s", self.output_spec, self)

pydra/engine/tests/test_boutiques.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@no_win
2323
@need_bosh_docker
24-
@pytest.mark.flaky(reruns=2) # need for travis
24+
@pytest.mark.flaky(reruns=3) # need for travis
2525
@pytest.mark.parametrize(
2626
"maskfile", ["test_brain.nii.gz", "test_brain", "test_brain.nii"]
2727
)
@@ -45,6 +45,7 @@ def test_boutiques_1(maskfile, plugin, results_function):
4545

4646
@no_win
4747
@need_bosh_docker
48+
@pytest.mark.flaky(reruns=3)
4849
def test_boutiques_spec_1():
4950
""" testing spec: providing input/output fields names"""
5051
btask = BoshTask(
@@ -69,6 +70,7 @@ def test_boutiques_spec_1():
6970

7071
@no_win
7172
@need_bosh_docker
73+
@pytest.mark.flaky(reruns=3)
7274
def test_boutiques_spec_2():
7375
""" testing spec: providing partial input/output fields names"""
7476
btask = BoshTask(
@@ -91,6 +93,7 @@ def test_boutiques_spec_2():
9193

9294
@no_win
9395
@need_bosh_docker
96+
@pytest.mark.flaky(reruns=3)
9497
@pytest.mark.parametrize(
9598
"maskfile", ["test_brain.nii.gz", "test_brain", "test_brain.nii"]
9699
)
@@ -121,6 +124,7 @@ def test_boutiques_wf_1(maskfile, plugin):
121124

122125
@no_win
123126
@need_bosh_docker
127+
@pytest.mark.flaky(reruns=3)
124128
@pytest.mark.parametrize(
125129
"maskfile", ["test_brain.nii.gz", "test_brain", "test_brain.nii"]
126130
)

0 commit comments

Comments
 (0)