Skip to content

Commit 64916f4

Browse files
authored
Merge pull request #556 from rcali21/PROV-test
WIP: adding task input info
2 parents 86671ad + 40200b3 commit 64916f4

File tree

3 files changed

+89
-3
lines changed

3 files changed

+89
-3
lines changed

pydra/engine/audit.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,15 @@ def start_audit(self, odir):
4949
self.odir = odir
5050
if self.audit_check(AuditFlag.PROV):
5151
self.aid = f"uid:{gen_uuid()}"
52-
start_message = {"@id": self.aid, "@type": "task", "startedAtTime": now()}
52+
53+
user_id = f"uid:{gen_uuid()}"
54+
start_message = {
55+
"@id": self.aid,
56+
"@type": "task",
57+
"startedAtTime": now(),
58+
"executedBy": user_id,
59+
}
60+
5361
os.chdir(self.odir)
5462
if self.audit_check(AuditFlag.PROV):
5563
self.audit_message(start_message, AuditFlag.PROV)
@@ -160,3 +168,24 @@ def audit_check(self, flag):
160168
Boolean AND for self.audit_flags and flag
161169
"""
162170
return self.audit_flags & flag
171+
172+
def audit_task(self, task):
    """Send a provenance message describing ``task``.

    The message reuses ``self.aid`` as its ``@id`` and records the task's
    name as ``label``, its command as ``command`` and the value of
    ``now()`` as ``startedAtTime``.

    Parameters
    ----------
    task
        Object exposing ``name`` and ``inputs``; ``cmdline`` is read only
        when ``task.inputs`` has an ``executable`` attribute.
    """
    task_label = task.name
    # Only tasks whose inputs carry an ``executable`` report a command line;
    # anything else is assumed to be a function task.
    # TODO: work on changing the fallback to the function name
    task_command = task.cmdline if hasattr(task.inputs, "executable") else None

    task_message = {
        "@id": self.aid,
        "@type": "task",
        "label": task_label,
        "command": task_command,
        "startedAtTime": now(),
    }
    self.audit_message(task_message, AuditFlag.PROV)
189+
190+
# add more fields according to BEP208 doc
191+
# with every field, check in tests

pydra/engine/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,8 @@ def _run(self, rerun=False, **kwargs):
503503
result = Result(output=None, runtime=None, errored=False)
504504
self.hooks.pre_run_task(self)
505505
self.audit.start_audit(odir=output_dir)
506+
if self.audit.audit_check(AuditFlag.PROV):
507+
self.audit.audit_task(task=self)
506508
try:
507509
self.audit.monitor()
508510
self._run_task()

pydra/engine/tests/test_task.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import cloudpickle as cp
66
from pathlib import Path
77
import re
8-
8+
import json
9+
import glob as glob
910
from ... import mark
1011
from ..core import Workflow
1112
from ..task import AuditFlag, ShellCommandTask, DockerTask, SingularityTask
@@ -986,6 +987,60 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
986987
assert (tmpdir / funky.checksum / "messages.jsonld").exists()
987988

988989

990+
def test_audit_task(tmpdir):
    """Run a PROV-audited function task and check the ``label`` it records.

    Every ``*.jsonld`` message that has a ``label`` field must carry the
    task name ``"testfunc"``, and at least one message must have the field.
    """
    from glob import glob

    @mark.task
    def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]):
        return a + b

    funky = testfunc(a=2, audit_flags=AuditFlag.PROV, messengers=FileMessenger())
    funky.cache_dir = tmpdir
    funky()

    message_path = tmpdir / funky.checksum / "messages"
    label_seen = False
    # inspect every jsonld message written under message_path
    for jsonld_file in glob(str(message_path) + "/*.jsonld"):
        with open(jsonld_file, "r") as fp:
            message = json.load(fp)
        if "label" not in message:
            continue
        label_seen = True
        assert "testfunc" == message["label"]
    assert label_seen
1010+
1011+
1012+
def test_audit_shellcommandtask(tmpdir):
    """Run a PROV-audited shell task and check its recorded label and command.

    Every ``*.jsonld`` message that has a ``label`` field must carry the
    task name ``"shelly"``, every message that has a ``command`` field must
    carry ``"ls -l"``, and each field must appear in at least one message.
    """
    from glob import glob

    args = "-l"
    shelly = ShellCommandTask(
        name="shelly",
        executable="ls",
        args=args,
        audit_flags=AuditFlag.PROV,
        messengers=FileMessenger(),
    )
    shelly.cache_dir = tmpdir
    shelly()
    message_path = tmpdir / shelly.checksum / "messages"

    label_found = False
    command_found = False
    # go through each jsonld file in message_path and check the task fields
    for file in glob(str(message_path) + "/*.jsonld"):
        with open(file, "r") as f:
            data = json.load(f)
        if "label" in data:
            label_found = True
            # mirrors test_audit_task: the label is the task's name
            assert "shelly" == data["label"]
        if "command" in data:
            command_found = True
            assert "ls -l" == data["command"]

    assert label_found
    # without this check the test would pass even if no message
    # recorded a command at all
    assert command_found
1042+
1043+
9891044
def test_audit_prov_messdir_1(tmpdir, use_validator):
9901045
"""customized messenger dir"""
9911046

@@ -1082,7 +1137,7 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out", float)]
10821137
from glob import glob
10831138

10841139
assert len(glob(str(tmpdir / funky.checksum / "proc*.log"))) == 1
1085-
assert len(glob(str(message_path / "*.jsonld"))) == 6
1140+
assert len(glob(str(message_path / "*.jsonld"))) == 7
10861141

10871142
# commented out to speed up testing
10881143
collect_messages(tmpdir / funky.checksum, message_path, ld_op="compact")

0 commit comments

Comments
 (0)