Skip to content

Commit cd2a137

Browse files
committed
2 parents 54bb764 + a5dfe42 commit cd2a137

File tree

2 files changed

+60
-50
lines changed

2 files changed

+60
-50
lines changed

pydra/engine/audit.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import attr
66
from ..utils.messenger import send_message, make_message, gen_uuid, now, AuditFlag
77
from .helpers import ensure_list, gather_runtime_info, hash_file
8+
from .specs import attr_fields, File, Directory
89

910

1011
class Audit:
@@ -173,22 +174,24 @@ def audit_task(self, task):
173174
import subprocess as sp
174175

175176
label = task.name
176-
entity_label = type(label)
177177

178-
if hasattr(task.inputs, "executable"):
179-
command = task.cmdline
180-
# assume function task
181-
else:
182-
command = None
183-
184-
if hasattr(task.inputs, "in_file"):
185-
input_file = task.inputs.in_file
186-
file_hash = hash_file(input_file)
187-
at_location = os.path.abspath(input_file)
188-
else:
189-
file_hash = None
190-
at_location = None
191-
input_file = None
178+
command = task.cmdline if hasattr(task.inputs, "executable") else None
179+
attr_list = attr_fields(task.inputs)
180+
for attrs in attr_list:
181+
if attrs.type in [File, Directory]:
182+
input_name = attrs.name
183+
input_path = os.path.abspath(getattr(task.inputs, input_name))
184+
file_hash = hash_file(input_path)
185+
entity_id = f"uid:{gen_uuid()}"
186+
entity_message = {
187+
"@id": entity_id,
188+
"Label": input_name,
189+
"AtLocation": input_path,
190+
"GeneratedBy": None,
191+
"@type": "input",
192+
"digest": file_hash,
193+
}
194+
self.audit_message(entity_message, AuditFlag.PROV)
192195

193196
if command is not None:
194197
cmd_name = command.split()[0]
@@ -217,18 +220,4 @@ def audit_task(self, task):
217220
"AssociatedWith": version_cmd,
218221
}
219222

220-
entity_message = {
221-
"@id": self.aid,
222-
"Label": print(entity_label),
223-
"AtLocation": at_location,
224-
"GeneratedBy": "test", # if not part of workflow, this will be none
225-
"@type": "input",
226-
"digest": file_hash, # hash value under helpers.py
227-
}
228-
229-
# new code to be added here for i/o tracking - WIP
230-
231223
self.audit_message(start_message, AuditFlag.PROV)
232-
self.audit_message(entity_message, AuditFlag.PROV)
233-
# add more fields according to BEP208 doc
234-
# with every field, check in tests

pydra/engine/tests/test_task.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,9 +1007,6 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
10071007
funky.cache_dir = tmpdir
10081008
funky()
10091009
message_path = tmpdir / funky.checksum / "messages"
1010-
print(message_path)
1011-
# go through each jsonld file in message_path and check if the label field exists
1012-
json_content = []
10131010

10141011
for file in glob(str(message_path) + "/*.jsonld"):
10151012
with open(file, "r") as f:
@@ -1021,12 +1018,6 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
10211018
if "@type" in data:
10221019
if data["@type"] == "input":
10231020
assert None == data["Label"]
1024-
# placeholder for atlocation until
1025-
# new test is added
1026-
assert None == data["AtLocation"]
1027-
1028-
# assert data["Type"] == "input"
1029-
10301021
if "AssociatedWith" in data:
10311022
assert None == data["AssociatedWith"]
10321023

@@ -1072,13 +1063,27 @@ def test_audit_shellcommandtask(tmpdir):
10721063

10731064

10741065
def test_audit_shellcommandtask_file(tmpdir):
1066+
# sourcery skip: use-fstring-for-concatenation
1067+
import glob
1068+
import shutil
1069+
10751070
# create test.txt file with "This is a test" in it in the tmpdir
1076-
with open(tmpdir / "test.txt", "w") as f:
1077-
f.write("This is a test.")
1071+
# create txt file in cwd
1072+
with open("test.txt", "w") as f:
1073+
f.write("This is a test")
1074+
1075+
with open("test2.txt", "w") as f:
1076+
f.write("This is a test")
1077+
1078+
# copy the test.txt file to the tmpdir
1079+
shutil.copy("test.txt", tmpdir)
1080+
shutil.copy("test2.txt", tmpdir)
10781081

10791082
cmd = "cat"
10801083
file_in = tmpdir / "test.txt"
1084+
file_in_2 = tmpdir / "test2.txt"
10811085
test_file_hash = hash_file(file_in)
1086+
test_file_hash_2 = hash_file(file_in_2)
10821087
my_input_spec = SpecInfo(
10831088
name="Input",
10841089
fields=[
@@ -1093,29 +1098,45 @@ def test_audit_shellcommandtask_file(tmpdir):
10931098
"mandatory": True,
10941099
},
10951100
),
1096-
)
1101+
),
1102+
(
1103+
"in_file_2",
1104+
attr.ib(
1105+
type=File,
1106+
metadata={
1107+
"position": 2,
1108+
"argstr": "",
1109+
"help_string": "text",
1110+
"mandatory": True,
1111+
},
1112+
),
1113+
),
10971114
],
10981115
bases=(ShellSpec,),
10991116
)
11001117
shelly = ShellCommandTask(
11011118
name="shelly",
11021119
in_file=file_in,
1120+
in_file_2=file_in_2,
11031121
input_spec=my_input_spec,
11041122
executable=cmd,
11051123
audit_flags=AuditFlag.PROV,
1106-
messengers=PrintMessenger(),
1124+
messengers=FileMessenger(),
11071125
)
11081126
shelly.cache_dir = tmpdir
11091127
shelly()
11101128
message_path = tmpdir / shelly.checksum / "messages"
11111129
for file in glob.glob(str(message_path) + "/*.jsonld"):
1112-
with open(file, "r") as f:
1113-
data = json.load(f)
1114-
print(file_in)
1115-
if "AtLocation" in data:
1116-
assert data["AtLocation"] == str(file_in)
1117-
if "digest" in data:
1118-
assert test_file_hash == data["digest"]
1130+
with open(file, "r") as x:
1131+
data = json.load(x)
1132+
if "@type" in data:
1133+
if data["@type"] == "input":
1134+
if data["Label"] == "in_file":
1135+
assert data["AtLocation"] == str(file_in)
1136+
assert data["digest"] == test_file_hash
1137+
if data["Label"] == "in_file_2":
1138+
assert data["AtLocation"] == str(file_in_2)
1139+
assert data["digest"] == test_file_hash_2
11191140

11201141

11211142
def test_audit_shellcommandtask_version(tmpdir):
@@ -1247,7 +1268,7 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
12471268
from glob import glob
12481269

12491270
assert len(glob(str(tmpdir / funky.checksum / "proc*.log"))) == 1
1250-
assert len(glob(str(message_path / "*.jsonld"))) == 8
1271+
assert len(glob(str(message_path / "*.jsonld"))) == 7
12511272

12521273
# commented out to speed up testing
12531274
collect_messages(tmpdir / funky.checksum, message_path, ld_op="compact")

0 commit comments

Comments
 (0)