Skip to content

Commit dfecb8a

Browse files
Ryan Cali
authored and committed
Added dict for atlocation, digest tracking, added tests
1 parent 2912ee6 commit dfecb8a

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

pydra/engine/audit.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,19 +176,25 @@ def audit_task(self, task):
176176
label = task.name
177177
entity_label = type(label)
178178

179-
command = task.cmdline if hasattr(task.inputs, "executable") else None
179+
if hasattr(task.inputs, "executable"):
180+
command = task.cmdline
181+
# assume function task
182+
else:
183+
command = None
184+
185+
path_hash_dict = {}
180186

181187
attr_list = attr_fields(task.inputs)
182188
for attrs in attr_list:
183189
if attrs.type in [File, Directory]:
184190
input_name = attrs.name
185191
input_path = os.path.abspath(getattr(task.inputs, input_name))
186192
file_hash = hash_file(input_path)
193+
path_hash_dict[input_path] = file_hash
187194

188-
else:
189-
input_name = attrs.name
190-
input_path = None
191-
file_hash = None
195+
# get the hash for the output
196+
input_paths = list(path_hash_dict.keys())
197+
input_paths_hash = list(path_hash_dict.values())
192198

193199
if command is not None:
194200
cmd_name = command.split()[0]
@@ -218,13 +224,16 @@ def audit_task(self, task):
218224
}
219225
entity_id = f"uid:{gen_uuid()}"
220226
entity_message = {
221-
"@id": entity_id,
227+
"@id": entity_id,
222228
"Label": print(entity_label),
223-
"AtLocation": input_path,
224-
"GeneratedBy": "test",
229+
"AtLocation": input_paths, #
230+
"GeneratedBy": "test",
225231
"@type": "input",
226-
"digest": file_hash,
232+
"digest": input_paths_hash
227233
}
228234

235+
236+
229237
self.audit_message(start_message, AuditFlag.PROV)
230238
self.audit_message(entity_message, AuditFlag.PROV)
239+

pydra/engine/tests/test_task.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,8 +1008,6 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
10081008
funky()
10091009
message_path = tmpdir / funky.checksum / "messages"
10101010
print(message_path)
1011-
# go through each jsonld file in message_path and check if the label field exists
1012-
json_content = []
10131011

10141012
for file in glob(str(message_path) + "/*.jsonld"):
10151013
with open(file, "r") as f:
@@ -1023,7 +1021,7 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
10231021
assert None == data["Label"]
10241022
# placeholder for atlocation until
10251023
# new test is added
1026-
assert None == data["AtLocation"]
1024+
assert [] == data["AtLocation"]
10271025

10281026
# assert data["Type"] == "input"
10291027

@@ -1072,13 +1070,19 @@ def test_audit_shellcommandtask(tmpdir):
10721070

10731071

10741072
def test_audit_shellcommandtask_file(tmpdir):
1073+
import shutil
10751074
# create test.txt file with "This is a test" in it in the tmpdir
10761075
with open(tmpdir / "test.txt", "w") as f:
10771076
f.write("This is a test.")
1077+
# make a copy of the test.txt file in the tmpdir and name it test2.txt
1078+
shutil.copy(tmpdir / "test.txt", tmpdir / "test2.txt")
1079+
10781080

10791081
cmd = "cat"
10801082
file_in = tmpdir / "test.txt"
1083+
file_in_2 = tmpdir / "test2.txt"
10811084
test_file_hash = hash_file(file_in)
1085+
test_file_hash_2 = hash_file(file_in_2)
10821086
my_input_spec = SpecInfo(
10831087
name="Input",
10841088
fields=[
@@ -1093,13 +1097,26 @@ def test_audit_shellcommandtask_file(tmpdir):
10931097
"mandatory": True,
10941098
},
10951099
),
1100+
),
1101+
(
1102+
"in_file_2",
1103+
attr.ib(
1104+
type=File,
1105+
metadata={
1106+
"position": 2,
1107+
"argstr": "",
1108+
"help_string": "text",
1109+
"mandatory": True,
1110+
},
1111+
),
10961112
)
10971113
],
10981114
bases=(ShellSpec,),
10991115
)
11001116
shelly = ShellCommandTask(
11011117
name="shelly",
11021118
in_file=file_in,
1119+
in_file_2=file_in_2,
11031120
input_spec=my_input_spec,
11041121
executable=cmd,
11051122
audit_flags=AuditFlag.PROV,
@@ -1113,9 +1130,9 @@ def test_audit_shellcommandtask_file(tmpdir):
11131130
data = json.load(f)
11141131
print(file_in)
11151132
if "AtLocation" in data:
1116-
assert data["AtLocation"] == str(file_in)
1133+
assert data["AtLocation"] == [file_in, file_in_2]
11171134
if "digest" in data:
1118-
assert test_file_hash == data["digest"]
1135+
assert data["digest"] == [test_file_hash, test_file_hash_2]
11191136

11201137

11211138
def test_audit_shellcommandtask_version(tmpdir):

0 commit comments

Comments
 (0)