9
9
import sys
10
10
from pathlib import Path
11
11
import typing as ty
12
- from copy import deepcopy
12
+ from copy import deepcopy , copy
13
13
from uuid import uuid4
14
14
from filelock import SoftFileLock
15
15
import shutil
@@ -281,13 +281,15 @@ def checksum_states(self, state_index=None):
281
281
282
282
"""
283
283
if is_workflow (self ) and self .inputs ._graph_checksums is attr .NOTHING :
284
- self .inputs ._graph_checksums = [nd .checksum for nd in self .graph_sorted ]
284
+ self .inputs ._graph_checksums = {
285
+ nd .name : nd .checksum for nd in self .graph_sorted
286
+ }
285
287
286
288
if state_index is not None :
287
- inputs_copy = deepcopy (self .inputs )
289
+ inputs_copy = copy (self .inputs )
288
290
for key , ind in self .state .inputs_ind [state_index ].items ():
289
291
val = self ._extract_input_el (
290
- inputs = inputs_copy , inp_nm = key .split ("." )[1 ], ind = ind
292
+ inputs = self . inputs , inp_nm = key .split ("." )[1 ], ind = ind
291
293
)
292
294
setattr (inputs_copy , key .split ("." )[1 ], val )
293
295
# setting files_hash again in case it was cleaned by setting specific element
@@ -462,13 +464,25 @@ def __call__(
462
464
return res
463
465
464
466
def _modify_inputs (self ):
465
- """Update and preserve a Task's original inputs"""
467
+ """This method modifies the inputs of the task ahead of its execution:
468
+ - links/copies upstream files and directories into the destination tasks
469
+ working directory as required select state array values corresponding to
470
+ state index (it will try to leave them where they are unless specified or
471
+ they are on different file systems)
472
+ - resolve template values (e.g. output_file_template)
473
+ - deepcopy all inputs to guard against in-place changes during the task's
474
+ execution (they will be replaced after the task's execution with the
475
+ original inputs to ensure the tasks checksums are consistent)
476
+ """
466
477
orig_inputs = {
467
- k : deepcopy (v ) for k , v in attr .asdict (self .inputs , recurse = False ).items ()
478
+ k : v
479
+ for k , v in attr .asdict (self .inputs , recurse = False ).items ()
480
+ if not k .startswith ("_" )
468
481
}
469
482
map_copyfiles = {}
470
- for fld in attr_fields (self .inputs ):
471
- value = getattr (self .inputs , fld .name )
483
+ input_fields = attr .fields (type (self .inputs ))
484
+ for name , value in orig_inputs .items ():
485
+ fld = getattr (input_fields , name )
472
486
copy_mode , copy_collation = parse_copyfile (
473
487
fld , default_collation = self .DEFAULT_COPY_COLLATION
474
488
)
@@ -483,12 +497,22 @@ def _modify_inputs(self):
483
497
supported_modes = self .SUPPORTED_COPY_MODES ,
484
498
)
485
499
if value is not copied_value :
486
- map_copyfiles [fld . name ] = copied_value
500
+ map_copyfiles [name ] = copied_value
487
501
modified_inputs = template_update (
488
502
self .inputs , self .output_dir , map_copyfiles = map_copyfiles
489
503
)
490
- if modified_inputs :
491
- self .inputs = attr .evolve (self .inputs , ** modified_inputs )
504
+ assert all (m in orig_inputs for m in modified_inputs ), (
505
+ "Modified inputs contain fields not present in original inputs. "
506
+ "This is likely a bug."
507
+ )
508
+ for name , orig_value in orig_inputs .items ():
509
+ try :
510
+ value = modified_inputs [name ]
511
+ except KeyError :
512
+ # Ensure we pass a copy not the original just in case inner
513
+ # attributes are modified during execution
514
+ value = deepcopy (orig_value )
515
+ setattr (self .inputs , name , value )
492
516
return orig_inputs
493
517
494
518
def _populate_filesystem (self , checksum , output_dir ):
@@ -548,13 +572,14 @@ def _run(self, rerun=False, environment=None, **kwargs):
548
572
save (output_dir , result = result , task = self )
549
573
# removing the additional file with the checksum
550
574
(self .cache_dir / f"{ self .uid } _info.json" ).unlink ()
551
- # # function etc. shouldn't change anyway, so removing
552
- orig_inputs = {
553
- k : v for k , v in orig_inputs .items () if not k .startswith ("_" )
554
- }
555
- self .inputs = attr .evolve (self .inputs , ** orig_inputs )
575
+ # Restore original values to inputs
576
+ for field_name , field_value in orig_inputs .items ():
577
+ setattr (self .inputs , field_name , field_value )
556
578
os .chdir (cwd )
557
579
self .hooks .post_run (self , result )
580
+ # Check for any changes to the input hashes that have occurred during the execution
581
+ # of the task
582
+ self ._check_for_hash_changes ()
558
583
return result
559
584
560
585
def _collect_outputs (self , output_dir ):
@@ -816,8 +841,8 @@ def result(self, state_index=None, return_inputs=False):
816
841
817
842
Returns
818
843
-------
819
- result :
820
-
844
+ result : Result
845
+ the result of the task
821
846
"""
822
847
# TODO: check if result is available in load_result and
823
848
# return a future if not
@@ -884,6 +909,47 @@ def _reset(self):
884
909
for task in self .graph .nodes :
885
910
task ._reset ()
886
911
912
+ def _check_for_hash_changes (self ):
913
+ hash_changes = self .inputs .hash_changes ()
914
+ details = ""
915
+ for changed in hash_changes :
916
+ field = getattr (attr .fields (type (self .inputs )), changed )
917
+ val = getattr (self .inputs , changed )
918
+ field_type = type (val )
919
+ if issubclass (field .type , FileSet ):
920
+ details += (
921
+ f"- { changed } : value passed to the { field .type } field is of type "
922
+ f"{ field_type } ('{ val } '). If it is intended to contain output data "
923
+ "then the type of the field in the interface class should be changed "
924
+ "to `pathlib.Path`. Otherwise, if the field is intended to be an "
925
+ "input field but it gets altered by the task in some way, then the "
926
+ "'copyfile' flag should be set to 'copy' in the field metadata of "
927
+ "the task interface class so copies of the files/directories in it "
928
+ "are passed to the task instead.\n "
929
+ )
930
+ else :
931
+ details += (
932
+ f"- { changed } : the { field_type } object passed to the { field .type } "
933
+ f"field appears to have an unstable hash. This could be due to "
934
+ "a stochastic/non-thread-safe attribute(s) of the object\n \n "
935
+ f"The { field .type } .__bytes_repr__() method can be implemented to "
936
+ "bespoke hashing methods based only on the stable attributes for "
937
+ f"the `{ field_type .__module__ } .{ field_type .__name__ } ` type. "
938
+ f"See pydra/utils/hash.py for examples. Value: { val } \n "
939
+ )
940
+ if hash_changes :
941
+ raise RuntimeError (
942
+ f"Input field hashes have changed during the execution of the "
943
+ f"'{ self .name } ' { type (self ).__name__ } .\n \n { details } "
944
+ )
945
+ logger .debug (
946
+ "Input values and hashes for '%s' %s node:\n %s\n %s" ,
947
+ self .name ,
948
+ type (self ).__name__ ,
949
+ self .inputs ,
950
+ self .inputs ._hashes ,
951
+ )
952
+
887
953
SUPPORTED_COPY_MODES = FileSet .CopyMode .any
888
954
DEFAULT_COPY_COLLATION = FileSet .CopyCollation .any
889
955
@@ -1076,7 +1142,9 @@ def checksum(self):
1076
1142
"""
1077
1143
# if checksum is called before run the _graph_checksums is not ready
1078
1144
if is_workflow (self ) and self .inputs ._graph_checksums is attr .NOTHING :
1079
- self .inputs ._graph_checksums = [nd .checksum for nd in self .graph_sorted ]
1145
+ self .inputs ._graph_checksums = {
1146
+ nd .name : nd .checksum for nd in self .graph_sorted
1147
+ }
1080
1148
1081
1149
input_hash = self .inputs .hash
1082
1150
if not self .state :
@@ -1256,8 +1324,9 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
1256
1324
(self .cache_dir / f"{ self .uid } _info.json" ).unlink ()
1257
1325
os .chdir (cwd )
1258
1326
self .hooks .post_run (self , result )
1259
- if result is None :
1260
- raise Exception ("This should never happen, please open new issue" )
1327
+ # Check for any changes to the input hashes that have occurred during the execution
1328
+ # of the task
1329
+ self ._check_for_hash_changes ()
1261
1330
return result
1262
1331
1263
1332
async def _run_task (self , submitter , rerun = False ):
0 commit comments