Skip to content

Commit ff01e4c

Browse files
authored
Merge pull request #698 from tclose/hash-change-guards
Hash change guards
2 parents 0e66136 + ff281aa commit ff01e4c

File tree

8 files changed

+286
-92
lines changed

8 files changed

+286
-92
lines changed

pydra/engine/core.py

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sys
1010
from pathlib import Path
1111
import typing as ty
12-
from copy import deepcopy
12+
from copy import deepcopy, copy
1313
from uuid import uuid4
1414
from filelock import SoftFileLock
1515
import shutil
@@ -281,13 +281,15 @@ def checksum_states(self, state_index=None):
281281
282282
"""
283283
if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING:
284-
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
284+
self.inputs._graph_checksums = {
285+
nd.name: nd.checksum for nd in self.graph_sorted
286+
}
285287

286288
if state_index is not None:
287-
inputs_copy = deepcopy(self.inputs)
289+
inputs_copy = copy(self.inputs)
288290
for key, ind in self.state.inputs_ind[state_index].items():
289291
val = self._extract_input_el(
290-
inputs=inputs_copy, inp_nm=key.split(".")[1], ind=ind
292+
inputs=self.inputs, inp_nm=key.split(".")[1], ind=ind
291293
)
292294
setattr(inputs_copy, key.split(".")[1], val)
293295
# setting files_hash again in case it was cleaned by setting specific element
@@ -462,13 +464,25 @@ def __call__(
462464
return res
463465

464466
def _modify_inputs(self):
465-
"""Update and preserve a Task's original inputs"""
467+
"""This method modifies the inputs of the task ahead of its execution:
468+
- links/copies upstream files and directories into the destination tasks
469+
working directory as required select state array values corresponding to
470+
state index (it will try to leave them where they are unless specified or
471+
they are on different file systems)
472+
- resolve template values (e.g. output_file_template)
473+
- deepcopy all inputs to guard against in-place changes during the task's
474+
execution (they will be replaced after the task's execution with the
475+
original inputs to ensure the tasks checksums are consistent)
476+
"""
466477
orig_inputs = {
467-
k: deepcopy(v) for k, v in attr.asdict(self.inputs, recurse=False).items()
478+
k: v
479+
for k, v in attr.asdict(self.inputs, recurse=False).items()
480+
if not k.startswith("_")
468481
}
469482
map_copyfiles = {}
470-
for fld in attr_fields(self.inputs):
471-
value = getattr(self.inputs, fld.name)
483+
input_fields = attr.fields(type(self.inputs))
484+
for name, value in orig_inputs.items():
485+
fld = getattr(input_fields, name)
472486
copy_mode, copy_collation = parse_copyfile(
473487
fld, default_collation=self.DEFAULT_COPY_COLLATION
474488
)
@@ -483,12 +497,22 @@ def _modify_inputs(self):
483497
supported_modes=self.SUPPORTED_COPY_MODES,
484498
)
485499
if value is not copied_value:
486-
map_copyfiles[fld.name] = copied_value
500+
map_copyfiles[name] = copied_value
487501
modified_inputs = template_update(
488502
self.inputs, self.output_dir, map_copyfiles=map_copyfiles
489503
)
490-
if modified_inputs:
491-
self.inputs = attr.evolve(self.inputs, **modified_inputs)
504+
assert all(m in orig_inputs for m in modified_inputs), (
505+
"Modified inputs contain fields not present in original inputs. "
506+
"This is likely a bug."
507+
)
508+
for name, orig_value in orig_inputs.items():
509+
try:
510+
value = modified_inputs[name]
511+
except KeyError:
512+
# Ensure we pass a copy not the original just in case inner
513+
# attributes are modified during execution
514+
value = deepcopy(orig_value)
515+
setattr(self.inputs, name, value)
492516
return orig_inputs
493517

494518
def _populate_filesystem(self, checksum, output_dir):
@@ -548,13 +572,14 @@ def _run(self, rerun=False, environment=None, **kwargs):
548572
save(output_dir, result=result, task=self)
549573
# removing the additional file with the checksum
550574
(self.cache_dir / f"{self.uid}_info.json").unlink()
551-
# # function etc. shouldn't change anyway, so removing
552-
orig_inputs = {
553-
k: v for k, v in orig_inputs.items() if not k.startswith("_")
554-
}
555-
self.inputs = attr.evolve(self.inputs, **orig_inputs)
575+
# Restore original values to inputs
576+
for field_name, field_value in orig_inputs.items():
577+
setattr(self.inputs, field_name, field_value)
556578
os.chdir(cwd)
557579
self.hooks.post_run(self, result)
580+
# Check for any changes to the input hashes that have occurred during the execution
581+
# of the task
582+
self._check_for_hash_changes()
558583
return result
559584

560585
def _collect_outputs(self, output_dir):
@@ -816,8 +841,8 @@ def result(self, state_index=None, return_inputs=False):
816841
817842
Returns
818843
-------
819-
result :
820-
844+
result : Result
845+
the result of the task
821846
"""
822847
# TODO: check if result is available in load_result and
823848
# return a future if not
@@ -884,6 +909,47 @@ def _reset(self):
884909
for task in self.graph.nodes:
885910
task._reset()
886911

912+
def _check_for_hash_changes(self):
913+
hash_changes = self.inputs.hash_changes()
914+
details = ""
915+
for changed in hash_changes:
916+
field = getattr(attr.fields(type(self.inputs)), changed)
917+
val = getattr(self.inputs, changed)
918+
field_type = type(val)
919+
if issubclass(field.type, FileSet):
920+
details += (
921+
f"- {changed}: value passed to the {field.type} field is of type "
922+
f"{field_type} ('{val}'). If it is intended to contain output data "
923+
"then the type of the field in the interface class should be changed "
924+
"to `pathlib.Path`. Otherwise, if the field is intended to be an "
925+
"input field but it gets altered by the task in some way, then the "
926+
"'copyfile' flag should be set to 'copy' in the field metadata of "
927+
"the task interface class so copies of the files/directories in it "
928+
"are passed to the task instead.\n"
929+
)
930+
else:
931+
details += (
932+
f"- {changed}: the {field_type} object passed to the {field.type}"
933+
f"field appears to have an unstable hash. This could be due to "
934+
"a stochastic/non-thread-safe attribute(s) of the object\n\n"
935+
f"The {field.type}.__bytes_repr__() method can be implemented to "
936+
"bespoke hashing methods based only on the stable attributes for "
937+
f"the `{field_type.__module__}.{field_type.__name__}` type. "
938+
f"See pydra/utils/hash.py for examples. Value: {val}\n"
939+
)
940+
if hash_changes:
941+
raise RuntimeError(
942+
f"Input field hashes have changed during the execution of the "
943+
f"'{self.name}' {type(self).__name__}.\n\n{details}"
944+
)
945+
logger.debug(
946+
"Input values and hashes for '%s' %s node:\n%s\n%s",
947+
self.name,
948+
type(self).__name__,
949+
self.inputs,
950+
self.inputs._hashes,
951+
)
952+
887953
SUPPORTED_COPY_MODES = FileSet.CopyMode.any
888954
DEFAULT_COPY_COLLATION = FileSet.CopyCollation.any
889955

@@ -1076,7 +1142,9 @@ def checksum(self):
10761142
"""
10771143
# if checksum is called before run the _graph_checksums is not ready
10781144
if is_workflow(self) and self.inputs._graph_checksums is attr.NOTHING:
1079-
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
1145+
self.inputs._graph_checksums = {
1146+
nd.name: nd.checksum for nd in self.graph_sorted
1147+
}
10801148

10811149
input_hash = self.inputs.hash
10821150
if not self.state:
@@ -1256,8 +1324,9 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
12561324
(self.cache_dir / f"{self.uid}_info.json").unlink()
12571325
os.chdir(cwd)
12581326
self.hooks.post_run(self, result)
1259-
if result is None:
1260-
raise Exception("This should never happen, please open new issue")
1327+
# Check for any changes to the input hashes that have occurred during the execution
1328+
# of the task
1329+
self._check_for_hash_changes()
12611330
return result
12621331

12631332
async def _run_task(self, submitter, rerun=False):

pydra/engine/specs.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
import pydra
1717
from .helpers_file import template_update_single
18-
from ..utils.hash import hash_function
18+
from ..utils.hash import hash_function, Cache
1919

2020
# from ..utils.misc import add_exc_note
2121

@@ -73,21 +73,22 @@ class SpecInfo:
7373
class BaseSpec:
7474
"""The base dataclass specs for all inputs and outputs."""
7575

76-
# def __attrs_post_init__(self):
77-
# self.files_hash = {
78-
# field.name: {}
79-
# for field in attr_fields(
80-
# self, exclude_names=("_graph_checksums", "bindings", "files_hash")
81-
# )
82-
# if field.metadata.get("output_file_template") is None
83-
# }
84-
8576
def collect_additional_outputs(self, inputs, output_dir, outputs):
8677
"""Get additional outputs."""
8778
return {}
8879

8980
@property
9081
def hash(self):
82+
hsh, self._hashes = self._compute_hashes()
83+
return hsh
84+
85+
def hash_changes(self):
86+
"""Detects any changes in the hashed values between the current inputs and the
87+
previously calculated values"""
88+
_, new_hashes = self._compute_hashes()
89+
return [k for k, v in new_hashes.items() if v != self._hashes[k]]
90+
91+
def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
9192
"""Compute a basic hash for any given set of fields."""
9293
inp_dict = {}
9394
for field in attr_fields(
@@ -101,10 +102,13 @@ def hash(self):
101102
if "container_path" in field.metadata:
102103
continue
103104
inp_dict[field.name] = getattr(self, field.name)
104-
inp_hash = hash_function(inp_dict)
105+
hash_cache = Cache({})
106+
field_hashes = {
107+
k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items()
108+
}
105109
if hasattr(self, "_graph_checksums"):
106-
inp_hash = hash_function((inp_hash, self._graph_checksums))
107-
return inp_hash
110+
field_hashes["_graph_checksums"] = self._graph_checksums
111+
return hash_function(sorted(field_hashes.items())), field_hashes
108112

109113
def retrieve_values(self, wf, state_index: ty.Optional[int] = None):
110114
"""Get values contained by this spec."""
@@ -984,8 +988,21 @@ def get_value(
984988
if result is None:
985989
raise RuntimeError(
986990
f"Could not find results of '{node.name}' node in a sub-directory "
987-
f"named '{node.checksum}' in any of the cache locations:\n"
991+
f"named '{node.checksum}' in any of the cache locations.\n"
988992
+ "\n".join(str(p) for p in set(node.cache_locations))
993+
+ f"\n\nThis is likely due to hash changes in '{self.name}' node inputs. "
994+
f"Current values and hashes: {self.inputs}, "
995+
f"{self.inputs._hashes}\n\n"
996+
"Set loglevel to 'debug' in order to track hash changes "
997+
"throughout the execution of the workflow.\n\n "
998+
"These issues may have been caused by `bytes_repr()` methods "
999+
"that don't return stable hash values for specific object "
1000+
"types across multiple processes (see bytes_repr() "
1001+
'"singledispatch" function in pydra/utils/hash.py). '
1002+
"You may need to implement a specific `bytes_repr()` "
1003+
'"singledispatch" overloads or `__bytes_repr__()` '
1004+
"dunder methods to handle one or more types in "
1005+
"your interface inputs."
9891006
)
9901007
_, split_depth = TypeParser.strip_splits(self.type)
9911008

pydra/engine/submitter.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -183,24 +183,55 @@ async def expand_workflow(self, wf, rerun=False):
183183
# don't block the event loop!
184184
await asyncio.sleep(1)
185185
if ii > 60:
186-
blocked = _list_blocked_tasks(graph_copy)
187-
# get_runnable_tasks(graph_copy) # Uncomment to debug `get_runnable_tasks`
188-
raise Exception(
189-
"graph is not empty, but not able to get more tasks "
190-
"- something may have gone wrong when retrieving the results "
191-
"of predecessor tasks. This could be caused by a file-system "
192-
"error or a bug in the internal workflow logic, but is likely "
193-
"to be caused by the hash of an upstream node being unstable."
194-
" \n\nHash instability can be caused by an input of the node being "
195-
"modified in place, or by pseudo-random ordering of `set` or "
196-
"`frozenset` inputs (or nested attributes of inputs) in the hash "
197-
"calculation. To ensure that sets are hashed consistently you can "
198-
"you can try set the environment variable PYTHONHASHSEED=0 for "
199-
"all processes, but it is best to try to identify where the set "
200-
"objects are occurring and manually hash their sorted elements. "
201-
"(or use list objects instead)"
202-
"\n\nBlocked tasks\n-------------\n" + "\n".join(blocked)
186+
msg = (
187+
f"Graph of '{wf}' workflow is not empty, but not able to get "
188+
"more tasks - something has gone wrong when retrieving the "
189+
"results predecessors:\n\n"
203190
)
191+
# Get blocked tasks and the predecessors they are waiting on
192+
outstanding = {
193+
t: [
194+
p for p in graph_copy.predecessors[t.name] if not p.done
195+
]
196+
for t in graph_copy.sorted_nodes
197+
}
198+
199+
hashes_have_changed = False
200+
for task, waiting_on in outstanding.items():
201+
if not waiting_on:
202+
continue
203+
msg += f"- '{task.name}' node blocked due to\n"
204+
for pred in waiting_on:
205+
if (
206+
pred.checksum
207+
!= wf.inputs._graph_checksums[pred.name]
208+
):
209+
msg += (
210+
f" - hash changes in '{pred.name}' node inputs. "
211+
f"Current values and hashes: {pred.inputs}, "
212+
f"{pred.inputs._hashes}\n"
213+
)
214+
hashes_have_changed = True
215+
elif pred not in outstanding:
216+
msg += (
217+
f" - undiagnosed issues in '{pred.name}' node, "
218+
"potentially related to file-system access issues "
219+
)
220+
msg += "\n"
221+
if hashes_have_changed:
222+
msg += (
223+
"Set loglevel to 'debug' in order to track hash changes "
224+
"throughout the execution of the workflow.\n\n "
225+
"These issues may have been caused by `bytes_repr()` methods "
226+
"that don't return stable hash values for specific object "
227+
"types across multiple processes (see bytes_repr() "
228+
'"singledispatch" function in pydra/utils/hash.py). '
229+
"You may need to implement a specific `bytes_repr()` "
230+
'"singledispatch" overloads or `__bytes_repr__()` '
231+
"dunder methods to handle one or more types in "
232+
"your interface inputs."
233+
)
234+
raise RuntimeError(msg)
204235
for task in tasks:
205236
# grab inputs if needed
206237
logger.debug(f"Retrieving inputs for {task}")

pydra/engine/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,8 @@ def command_args(self, root=None):
337337
raise NotImplementedError
338338

339339
modified_inputs = template_update(self.inputs, output_dir=self.output_dir)
340-
if modified_inputs is not None:
341-
self.inputs = attr.evolve(self.inputs, **modified_inputs)
340+
for field_name, field_value in modified_inputs.items():
341+
setattr(self.inputs, field_name, field_value)
342342

343343
pos_args = [] # list for (position, command arg)
344344
self._positions_provided = []

0 commit comments

Comments
 (0)