Skip to content

Commit e2a2f91

Browse files
committed
feat: Partially implement FutureClient/JobStateUpdat
1 parent 3d76331 commit e2a2f91

File tree

3 files changed

+200
-37
lines changed

3 files changed

+200
-37
lines changed
Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,132 @@
1+
import functools
2+
from datetime import datetime, timezone
3+
4+
15
from DIRAC.Core.Security.DiracX import DiracXClient
26
from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue
7+
from DIRAC.Core.Utilities.TimeUtilities import fromString
8+
9+
10+
def stripValueIfOK(func):
11+
"""Decorator to remove S_OK["Value"] from the return value of a function if it is OK.
12+
13+
This is done as some update functions return the number of modified rows in
14+
the database. This likely not actually useful so it isn't supported in
15+
DiracX. Stripping the "Value" key of the dictionary means that we should
16+
get a fairly straight forward error if the assumption is incorrect.
17+
"""
18+
19+
@functools.wraps(func)
20+
def wrapper(*args, **kwargs):
21+
result = func(*args, **kwargs)
22+
if result.get("OK"):
23+
assert result.pop("Value") is None, "Value should be None if OK"
24+
return result
25+
26+
return wrapper
327

428

529
class JobStateUpdateClient:
30+
@stripValueIfOK
31+
@convertToReturnValue
632
def sendHeartBeat(self, jobID: str | int, dynamicData: dict, staticData: dict):
7-
raise NotImplementedError("TODO")
33+
print("HACK: This is a no-op until we decide what to do")
834

35+
@stripValueIfOK
36+
@convertToReturnValue
937
def setJobApplicationStatus(self, jobID: str | int, appStatus: str, source: str = "Unknown"):
10-
raise NotImplementedError("TODO")
38+
statusDict = {
39+
"application_status": appStatus,
40+
}
41+
if source:
42+
statusDict["Source"] = source
43+
with DiracXClient() as api:
44+
api.jobs.set_single_job_status(
45+
jobID,
46+
{datetime.now(tz=timezone.utc): statusDict},
47+
)
1148

49+
@stripValueIfOK
50+
@convertToReturnValue
1251
def setJobAttribute(self, jobID: str | int, attribute: str, value: str):
1352
with DiracXClient() as api:
14-
api.jobs.set_single_job_properties(jobID, "need to [patch the client to have a nice summer body ?")
15-
raise NotImplementedError("TODO")
53+
if attribute == "Status":
54+
api.jobs.set_single_job_status(
55+
jobID,
56+
{datetime.now(tz=timezone.utc): {"status": value}},
57+
)
58+
else:
59+
api.jobs.set_single_job_properties(jobID, {attribute: value})
1660

61+
@stripValueIfOK
62+
@convertToReturnValue
1763
def setJobFlag(self, jobID: str | int, flag: str):
18-
raise NotImplementedError("TODO")
64+
with DiracXClient() as api:
65+
api.jobs.set_single_job_properties(jobID, {flag: True})
1966

67+
@stripValueIfOK
68+
@convertToReturnValue
2069
def setJobParameter(self, jobID: str | int, name: str, value: str):
21-
raise NotImplementedError("TODO")
70+
print("HACK: This is a no-op until we decide what to do")
2271

72+
@stripValueIfOK
73+
@convertToReturnValue
2374
def setJobParameters(self, jobID: str | int, parameters: list):
24-
raise NotImplementedError("TODO")
75+
print("HACK: This is a no-op until we decide what to do")
2576

77+
@stripValueIfOK
78+
@convertToReturnValue
2679
def setJobSite(self, jobID: str | int, site: str):
27-
raise NotImplementedError("TODO")
80+
with DiracXClient() as api:
81+
api.jobs.set_single_job_properties(jobID, {"Site": site})
2882

83+
@stripValueIfOK
84+
@convertToReturnValue
2985
def setJobStatus(
3086
self,
3187
jobID: str | int,
3288
status: str = "",
3389
minorStatus: str = "",
3490
source: str = "Unknown",
35-
datetime=None,
91+
datetime_=None,
3692
force=False,
3793
):
38-
raise NotImplementedError("TODO")
94+
statusDict = {}
95+
if status:
96+
statusDict["Status"] = status
97+
if minorStatus:
98+
statusDict["MinorStatus"] = minorStatus
99+
if source:
100+
statusDict["Source"] = source
101+
if datetime_ is None:
102+
datetime_ = datetime.utcnow()
103+
with DiracXClient() as api:
104+
api.jobs.set_single_job_status(
105+
jobID,
106+
{fromString(datetime_).replace(tzinfo=timezone.utc): statusDict},
107+
force=force,
108+
)
39109

110+
@stripValueIfOK
111+
@convertToReturnValue
40112
def setJobStatusBulk(self, jobID: str | int, statusDict: dict, force=False):
41-
raise NotImplementedError("TODO")
113+
statusDict = {fromString(k).replace(tzinfo=timezone.utc): v for k, v in statusDict.items()}
114+
with DiracXClient() as api:
115+
api.jobs.set_job_status_bulk(
116+
{jobID: statusDict},
117+
force=force,
118+
)
42119

120+
@stripValueIfOK
121+
@convertToReturnValue
43122
def setJobsParameter(self, jobsParameterDict: dict):
44-
raise NotImplementedError("TODO")
123+
print("HACK: This is a no-op until we decide what to do")
45124

125+
@stripValueIfOK
126+
@convertToReturnValue
46127
def unsetJobFlag(self, jobID: str | int, flag: str):
47-
raise NotImplementedError("TODO")
128+
with DiracXClient() as api:
129+
api.jobs.set_single_job_properties(jobID, {flag: False})
48130

49131
def updateJobFromStager(self, jobID: str | int, status: str):
50132
raise NotImplementedError("TODO")

tests/Integration/FutureClient/WorkloadManagement/Test_JobStateUpdate.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,45 @@
1+
from datetime import datetime
12
from functools import partial
3+
from textwrap import dedent
24

35
import pytest
46

57
import DIRAC
68

79
DIRAC.initialize()
10+
from DIRAC.Core.Security.DiracX import DiracXClient
811
from DIRAC.WorkloadManagementSystem.Client.JobStateUpdateClient import JobStateUpdateClient
9-
from ..utils import compare_results
12+
from ..utils import compare_results2
13+
14+
test_jdl = """
15+
Arguments = "Hello world from DiracX";
16+
Executable = "echo";
17+
JobGroup = jobGroup;
18+
JobName = jobName;
19+
JobType = User;
20+
LogLevel = INFO;
21+
MinNumberOfProcessors = 1000;
22+
OutputSandbox =
23+
{
24+
std.err,
25+
std.out
26+
};
27+
Priority = 1;
28+
Sites = ANY;
29+
StdError = std.err;
30+
StdOutput = std.out;
31+
"""
32+
33+
34+
@pytest.fixture()
35+
def example_jobids():
36+
from DIRAC.Interfaces.API.Dirac import Dirac
37+
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
38+
39+
d = Dirac()
40+
job_id_1 = returnValueOrRaise(d.submitJob(test_jdl))
41+
job_id_2 = returnValueOrRaise(d.submitJob(test_jdl))
42+
return job_id_1, job_id_2
1043

1144

1245
def test_sendHeartBeat(monkeypatch):
@@ -15,16 +48,22 @@ def test_sendHeartBeat(monkeypatch):
1548
pytest.skip()
1649

1750

18-
def test_setJobApplicationStatus(monkeypatch):
51+
def test_setJobApplicationStatus(monkeypatch, example_jobids):
1952
# JobStateUpdateClient().setJobApplicationStatus(jobID: str | int, appStatus: str, source: str = Unknown)
2053
method = JobStateUpdateClient().setJobApplicationStatus
21-
pytest.skip()
54+
args = ["MyApplicationStatus"]
55+
test_func1 = partial(method, example_jobids[0], *args)
56+
test_func2 = partial(method, example_jobids[1], *args)
57+
compare_results2(monkeypatch, test_func1, test_func2)
2258

2359

24-
def test_setJobAttribute(monkeypatch):
60+
@pytest.mark.parametrize("args", [["Status", "Killed"], ["JobGroup", "newJobGroup"]])
61+
def test_setJobAttribute(monkeypatch, example_jobids, args):
2562
# JobStateUpdateClient().setJobAttribute(jobID: str | int, attribute: str, value: str)
2663
method = JobStateUpdateClient().setJobAttribute
27-
pytest.skip()
64+
test_func1 = partial(method, example_jobids[0], *args)
65+
test_func2 = partial(method, example_jobids[1], *args)
66+
compare_results2(monkeypatch, test_func1, test_func2)
2867

2968

3069
def test_setJobFlag(monkeypatch):
@@ -45,22 +84,37 @@ def test_setJobParameters(monkeypatch):
4584
pytest.skip()
4685

4786

48-
def test_setJobSite(monkeypatch):
87+
@pytest.mark.parametrize("jobid_type", [int, str])
88+
def test_setJobSite(monkeypatch, example_jobids, jobid_type):
4989
# JobStateUpdateClient().setJobSite(jobID: str | int, site: str)
5090
method = JobStateUpdateClient().setJobSite
51-
pytest.skip()
91+
args = ["LCG.CERN.ch"]
92+
test_func1 = partial(method, jobid_type(example_jobids[0]), *args)
93+
test_func2 = partial(method, jobid_type(example_jobids[1]), *args)
94+
compare_results2(monkeypatch, test_func1, test_func2)
5295

5396

54-
def test_setJobStatus(monkeypatch):
97+
def test_setJobStatus(monkeypatch, example_jobids):
5598
# JobStateUpdateClient().setJobStatus(jobID: str | int, status: str = , minorStatus: str = , source: str = Unknown, datetime = None, force = False)
5699
method = JobStateUpdateClient().setJobStatus
57-
pytest.skip()
100+
args = ["", "My Minor"]
101+
test_func1 = partial(method, example_jobids[0], *args)
102+
test_func2 = partial(method, example_jobids[1], *args)
103+
compare_results2(monkeypatch, test_func1, test_func2)
58104

59105

60-
def test_setJobStatusBulk(monkeypatch):
106+
def test_setJobStatusBulk(monkeypatch, example_jobids):
61107
# JobStateUpdateClient().setJobStatusBulk(jobID: str | int, statusDict: dict, force = False)
62108
method = JobStateUpdateClient().setJobStatusBulk
63-
pytest.skip()
109+
args = [
110+
{
111+
datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"): {"ApplicationStatus": "SomethingElse"},
112+
datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f"): {"ApplicationStatus": "Something"},
113+
}
114+
]
115+
test_func1 = partial(method, example_jobids[0], *args)
116+
test_func2 = partial(method, example_jobids[1], *args)
117+
compare_results2(monkeypatch, test_func1, test_func2)
64118

65119

66120
def test_setJobsParameter(monkeypatch):
Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,47 @@
1-
def compare_results(test_func):
1+
import time
2+
3+
4+
def compare_results(monkeypatch, test_func):
25
"""Compare the results from DIRAC and DiracX based services for a reentrant function."""
3-
ClientClass = test_func.func.__self__
4-
assert ClientClass.diracxClient, "FutureClient is not set up!"
6+
compare_results2(monkeypatch, test_func, test_func)
7+
58

9+
def compare_results2(monkeypatch, test_func1, test_func2):
10+
"""Compare the results from DIRAC and DiracX based services for two functions which should behave identically."""
611
# Get the result from the diracx-based handler
7-
future_result = test_func()
12+
start = time.monotonic()
13+
with monkeypatch.context() as m:
14+
m.setattr("DIRAC.Core.Tornado.Client.ClientSelector.useLegacyAdapter", lambda *_: True)
15+
try:
16+
future_result = test_func1()
17+
except Exception as e:
18+
future_result = e
19+
else:
20+
assert "rpcStub" not in future_result, "rpcStub should never be present when using DiracX!"
21+
diracx_duration = time.monotonic() - start
822

923
# Get the result from the DIRAC-based handler
10-
diracxClient = ClientClass.diracxClient
11-
ClientClass.diracxClient = None
12-
try:
13-
old_result = test_func()
14-
finally:
15-
ClientClass.diracxClient = diracxClient
16-
# We don't care about the rpcStub
24+
start = time.monotonic()
25+
with monkeypatch.context() as m:
26+
m.setattr("DIRAC.Core.Tornado.Client.ClientSelector.useLegacyAdapter", lambda *_: False)
27+
old_result = test_func2()
28+
assert "rpcStub" in old_result, "rpcStub should always be present when using legacy DIRAC!"
29+
legacy_duration = time.monotonic() - start
30+
31+
# We don't care about the rpcStub or Errno
1732
old_result.pop("rpcStub")
33+
old_result.pop("Errno", None)
34+
35+
if not old_result["OK"]:
36+
assert not future_result["OK"], "FutureClient should have failed too!"
37+
elif "Value" in future_result:
38+
# Ensure the results match exactly
39+
assert old_result == future_result
40+
else:
41+
# See the "stripValueIfOK" decorator for explanation
42+
assert old_result["OK"] == future_result["OK"]
43+
# assert isinstance(old_result["Value"], int)
1844

19-
# Ensure the results match
20-
assert old_result == future_result
45+
# if 3 * legacy_duration < diracx_duration:
46+
# print(f"Legacy DIRAC took {legacy_duration:.3f}s, FutureClient took {diracx_duration:.3f}s")
47+
# assert False, "FutureClient should be faster than legacy DIRAC!"

0 commit comments

Comments
 (0)