Skip to content

Commit 54bb764

Browse files
committed
Merge branch 'PROV_io' of https://github.com/rcali21/pydra
2 parents f72193e + 0759ab1 commit 54bb764

File tree

2 files changed

+173
-20
lines changed

2 files changed

+173
-20
lines changed

pydra/engine/audit.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import attr
66
from ..utils.messenger import send_message, make_message, gen_uuid, now, AuditFlag
7-
from .helpers import ensure_list, gather_runtime_info
7+
from .helpers import ensure_list, gather_runtime_info, hash_file
88

99

1010
class Audit:
@@ -170,22 +170,65 @@ def audit_check(self, flag):
170170
return self.audit_flags & flag
171171

172172
def audit_task(self, task):
    """Send the PROV messages describing the start of ``task``.

    Two messages are passed to ``self.audit_message`` with
    ``AuditFlag.PROV``:

    * a ``"task"`` message carrying the task name, its command line
      (``None`` when the task has no ``executable`` input), the start
      time and the first line reported by ``<executable> --version``;
    * an ``"input"`` entity message carrying the absolute location and
      the ``hash_file`` digest of ``task.inputs.in_file`` (both ``None``
      when the task has no ``in_file`` input).

    Parameters
    ----------
    task
        Task being audited. Only ``task.name``, ``task.inputs`` and
        ``task.cmdline`` are read here.
    """
    import subprocess as sp

    label = task.name

    # A task exposing an ``executable`` input has a command line;
    # anything else is assumed to be a function task.
    command = task.cmdline if hasattr(task.inputs, "executable") else None

    if hasattr(task.inputs, "in_file"):
        input_file = task.inputs.in_file
        file_hash = hash_file(input_file)
        at_location = os.path.abspath(input_file)
    else:
        file_hash = None
        at_location = None

    version_cmd = None
    # ``split()`` on an empty/blank command line yields no words, in
    # which case there is nothing to probe for a version.
    cmd_words = command.split() if command is not None else []
    if cmd_words:
        # take the first word of command as the name of the executable
        # (this may not always be the case)
        cmd_name = cmd_words[0]
        try:
            # Argument list instead of a shell string: the executable
            # name is never interpreted by a shell.
            probe = sp.run([cmd_name, "--version"], stdout=sp.PIPE)
            version_lines = probe.stdout.decode(
                "utf-8", errors="replace"
            ).splitlines()
        except OSError:
            # executable missing or not runnable
            version_lines = []

        if version_lines:
            version_cmd = version_lines[0]
        else:
            version_cmd = f"{cmd_name} -- Version unknown"

    start_message = {
        "@id": self.aid,
        "@type": "task",
        "Label": label,
        "Command": command,
        "StartedAtTime": now(),
        "AssociatedWith": version_cmd,
    }

    entity_message = {
        "@id": self.aid,
        # The entity label is not tracked yet. This field previously held
        # the result of a debugging ``print(...)`` call, i.e. always
        # ``None``; keep that value explicitly without writing to stdout.
        "Label": None,
        "AtLocation": at_location,
        "GeneratedBy": "test",  # if not part of workflow, this will be none
        "@type": "input",
        "digest": file_hash,  # hash value under helpers.py
    }

    # new code to be added here for i/o tracking - WIP

    self.audit_message(start_message, AuditFlag.PROV)
    self.audit_message(entity_message, AuditFlag.PROV)
    # add more fields according to BEP208 doc
    # with every field, check in tests

pydra/engine/tests/test_task.py

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,17 @@
1111
from ..core import Workflow
1212
from ..task import AuditFlag, ShellCommandTask, DockerTask, SingularityTask
1313
from ...utils.messenger import FileMessenger, PrintMessenger, collect_messages
14-
from .utils import gen_basic_wf, use_validator
15-
from ..specs import MultiInputObj, MultiOutputObj, SpecInfo, FunctionSpec, BaseSpec
14+
from .utils import gen_basic_wf, use_validator, Submitter
15+
from ..specs import (
16+
MultiInputObj,
17+
MultiOutputObj,
18+
SpecInfo,
19+
FunctionSpec,
20+
BaseSpec,
21+
ShellSpec,
22+
File,
23+
)
24+
from ..helpers import hash_file
1625

1726
no_win = pytest.mark.skipif(
1827
sys.platform.startswith("win"),
@@ -998,15 +1007,30 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
9981007
funky.cache_dir = tmpdir
9991008
funky()
10001009
message_path = tmpdir / funky.checksum / "messages"
1010+
print(message_path)
10011011
# go through each jsonld file in message_path and check if the label field exists
10021012
json_content = []
1013+
10031014
for file in glob(str(message_path) + "/*.jsonld"):
10041015
with open(file, "r") as f:
10051016
data = json.load(f)
1006-
if "label" in data:
1007-
json_content.append(True)
1008-
assert "testfunc" == data["label"]
1009-
assert any(json_content)
1017+
if "@type" in data:
1018+
if "AssociatedWith" in data:
1019+
assert "testfunc" in data["Label"]
1020+
1021+
if "@type" in data:
1022+
if data["@type"] == "input":
1023+
assert None == data["Label"]
1024+
# placeholder for atlocation until
1025+
# new test is added
1026+
assert None == data["AtLocation"]
1027+
1028+
# assert data["Type"] == "input"
1029+
1030+
if "AssociatedWith" in data:
1031+
assert None == data["AssociatedWith"]
1032+
1033+
# assert any(json_content)
10101034

10111035

10121036
def test_audit_shellcommandtask(tmpdir):
@@ -1025,20 +1049,106 @@ def test_audit_shellcommandtask(tmpdir):
10251049
shelly()
10261050
message_path = tmpdir / shelly.checksum / "messages"
10271051
# go through each jsonld file in message_path and check if the label field exists
1028-
label_content = []
1052+
10291053
command_content = []
10301054

10311055
for file in glob(str(message_path) + "/*.jsonld"):
10321056
with open(file, "r") as f:
10331057
data = json.load(f)
1034-
if "label" in data:
1035-
label_content.append(True)
1036-
if "command" in data:
1058+
1059+
if "@type" in data:
1060+
if "AssociatedWith" in data:
1061+
assert "shelly" in data["Label"]
1062+
1063+
if "@type" in data:
1064+
if data["@type"] == "input":
1065+
assert data["Label"] == None
1066+
1067+
if "Command" in data:
10371068
command_content.append(True)
1038-
assert "ls -l" == data["command"]
1069+
assert "ls -l" == data["Command"]
1070+
1071+
assert any(command_content)
1072+
1073+
1074+
def test_audit_shellcommandtask_file(tmpdir):
    """Check the input-entity audit message of a shell task with ``in_file``.

    Runs ``cat`` on a small text file with PROV auditing enabled and
    verifies that the recorded ``AtLocation`` and ``digest`` match the
    input file. The test fails if no message carrying those fields is
    found, so it cannot pass vacuously.
    """
    # Local import, as in test_audit_shellcommandtask_version: this test
    # needs the ``glob`` module, whatever ``glob`` is bound to at module level.
    import glob

    # create test.txt file with "This is a test." in it in the tmpdir
    file_in = tmpdir / "test.txt"
    with open(file_in, "w") as f:
        f.write("This is a test.")

    cmd = "cat"
    test_file_hash = hash_file(file_in)
    my_input_spec = SpecInfo(
        name="Input",
        fields=[
            (
                "in_file",
                attr.ib(
                    type=File,
                    metadata={
                        "position": 1,
                        "argstr": "",
                        "help_string": "text",
                        "mandatory": True,
                    },
                ),
            )
        ],
        bases=(ShellSpec,),
    )
    shelly = ShellCommandTask(
        name="shelly",
        in_file=file_in,
        input_spec=my_input_spec,
        executable=cmd,
        audit_flags=AuditFlag.PROV,
        # FileMessenger, as in test_audit_shellcommandtask_version, since
        # the loop below reads the *.jsonld message files from disk.
        messengers=FileMessenger(),
    )
    shelly.cache_dir = tmpdir
    shelly()
    message_path = tmpdir / shelly.checksum / "messages"

    location_content = []
    digest_content = []
    for file in glob.glob(str(message_path) + "/*.jsonld"):
        with open(file, "r") as f:
            data = json.load(f)
        if "AtLocation" in data:
            location_content.append(True)
            assert data["AtLocation"] == str(file_in)
        if "digest" in data:
            digest_content.append(True)
            assert test_file_hash == data["digest"]

    # make sure the assertions above were actually exercised
    assert any(location_content)
    assert any(digest_content)
1119+
1120+
1121+
def test_audit_shellcommandtask_version(tmpdir):
    """Check that the audit trail records the version of the executable.

    The first line printed by ``less --version`` must appear in the
    ``AssociatedWith`` field of at least one audit message written for a
    ``less`` shell task. The test is skipped when ``less`` is not
    installed, instead of failing with an unrelated error.
    """
    import glob
    import shutil
    import subprocess as sp

    cmd = "less"
    if shutil.which(cmd) is None:
        pytest.skip("the `less` executable is not available")

    # Argument list rather than a shell string: no shell is needed here.
    version_output = sp.run([cmd, "--version"], stdout=sp.PIPE).stdout.decode(
        "utf-8"
    )
    version_cmd = version_output.splitlines()[0]

    shelly = ShellCommandTask(
        name="shelly",
        executable=cmd,
        args="test_task.py",
        audit_flags=AuditFlag.PROV,
        messengers=FileMessenger(),
    )

    shelly.cache_dir = tmpdir
    shelly()
    message_path = tmpdir / shelly.checksum / "messages"
    # go through each jsonld file in message_path and check if the
    # AssociatedWith field contains the expected version string
    version_content = []
    for file in glob.glob(str(message_path) + "/*.jsonld"):
        with open(file, "r") as f:
            data = json.load(f)
        associated_with = data.get("AssociatedWith")
        # the field may be absent or null; only a string can match
        if associated_with and version_cmd in associated_with:
            version_content.append(True)

    assert any(version_content)
10421152

10431153

10441154
def test_audit_prov_messdir_1(tmpdir, use_validator):
@@ -1137,7 +1247,7 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
11371247
from glob import glob
11381248

11391249
assert len(glob(str(tmpdir / funky.checksum / "proc*.log"))) == 1
1140-
assert len(glob(str(message_path / "*.jsonld"))) == 7
1250+
assert len(glob(str(message_path / "*.jsonld"))) == 8
11411251

11421252
# commented out to speed up testing
11431253
collect_messages(tmpdir / funky.checksum, message_path, ld_op="compact")

0 commit comments

Comments
 (0)