Skip to content

Commit 2bc6401

Browse files
committed
feat(Resources): introduce fabric in SSHCE
1 parent cbb4b91 commit 2bc6401

File tree

4 files changed

+431
-452
lines changed

4 files changed

+431
-452
lines changed

environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ dependencies:
1818
- cwltool
1919
- db12
2020
- opensearch-py
21+
- fabric
2122
- fts3
2223
- gitpython >=2.1.0
24+
- invoke
2325
- m2crypto >=0.38.0
2426
- matplotlib
2527
- numpy
28+
- paramiko
2629
- pexpect >=4.0.1
2730
- pillow
2831
- prompt-toolkit >=3,<4

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@ install_requires =
3535
diracx-core >=v0.0.1
3636
diracx-cli >=v0.0.1
3737
db12
38+
fabric
3839
fts3
3940
gfal2-python
4041
importlib_metadata >=4.4
4142
importlib_resources
43+
invoke
4244
M2Crypto >=0.36
4345
packaging
46+
paramiko
4447
pexpect
4548
prompt-toolkit >=3
4649
psutil
Lines changed: 79 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs
1+
""" SSH (Virtual) Batch Computing Element: For a given list of ip/cores pair it will send jobs
22
directly through ssh
33
"""
44

@@ -12,64 +12,77 @@
1212

1313

1414
class SSHBatchComputingElement(SSHComputingElement):
15-
#############################################################################
1615
def __init__(self, ceUniqueID):
1716
"""Standard constructor."""
1817
super().__init__(ceUniqueID)
1918

20-
self.ceType = "SSHBatch"
21-
self.sshHost = []
19+
self.connections = {}
2220
self.execution = "SSHBATCH"
2321

2422
def _reset(self):
2523
"""Process CE parameters and make necessary adjustments"""
24+
# Get the Batch System instance
2625
result = self._getBatchSystem()
2726
if not result["OK"]:
2827
return result
28+
29+
# Get the location of the remote directories
2930
self._getBatchSystemDirectoryLocations()
3031

31-
self.user = self.ceParameters["SSHUser"]
32+
# Get the SSH parameters
33+
self.timeout = self.ceParameters.get("Timeout", self.timeout)
34+
self.user = self.ceParameters.get("SSHUser", self.user)
35+
port = self.ceParameters.get("SSHPort", None)
36+
password = self.ceParameters.get("SSHPassword", None)
37+
key = self.ceParameters.get("SSHKey", None)
38+
39+
# Get submission parameters
40+
self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions)
41+
self.preamble = self.ceParameters.get("Preamble", self.preamble)
42+
self.account = self.ceParameters.get("Account", self.account)
3243
self.queue = self.ceParameters["Queue"]
3344
self.log.info("Using queue: ", self.queue)
3445

35-
self.submitOptions = self.ceParameters.get("SubmitOptions", "")
36-
self.preamble = self.ceParameters.get("Preamble", "")
37-
self.account = self.ceParameters.get("Account", "")
38-
39-
# Prepare all the hosts
40-
for hPar in self.ceParameters["SSHHost"].strip().split(","):
41-
host = hPar.strip().split("/")[0]
42-
result = self._prepareRemoteHost(host=host)
43-
if result["OK"]:
44-
self.log.info(f"Host {host} registered for usage")
45-
self.sshHost.append(hPar.strip())
46+
# Get output and error templates
47+
self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate)
48+
self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate)
49+
50+
# Prepare the remote hosts
51+
for host in self.ceParameters.get("SSHHost", "").strip().split(","):
52+
hostDetails = host.strip().split("/")
53+
if len(hostDetails) > 1:
54+
hostname = hostDetails[0]
55+
maxJobs = int(hostDetails[1])
4656
else:
47-
self.log.error("Failed to initialize host", host)
57+
hostname = hostDetails[0]
58+
maxJobs = self.ceParameters.get("MaxTotalJobs", 0)
59+
60+
connection = self._getConnection(hostname, self.user, port, password, key)
61+
62+
result = self._prepareRemoteHost(connection)
63+
if not result["OK"]:
4864
return result
4965

66+
self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs}
67+
self.log.info(f"Host {hostname} registered for usage")
68+
5069
return S_OK()
5170

5271
#############################################################################
72+
5373
def submitJob(self, executableFile, proxy, numberOfJobs=1):
5474
"""Method to submit job"""
55-
5675
# Choose eligible hosts, rank them by the number of available slots
5776
rankHosts = {}
5877
maxSlots = 0
59-
for host in self.sshHost:
60-
thost = host.split("/")
61-
hostName = thost[0]
62-
maxHostJobs = 1
63-
if len(thost) > 1:
64-
maxHostJobs = int(thost[1])
65-
66-
result = self._getHostStatus(hostName)
78+
for _, details in self.connections.items():
79+
result = self._getHostStatus(details["connection"])
6780
if not result["OK"]:
6881
continue
69-
slots = maxHostJobs - result["Value"]["Running"]
82+
slots = details["maxJobs"] - result["Value"]["Running"]
7083
if slots > 0:
7184
rankHosts.setdefault(slots, [])
72-
rankHosts[slots].append(hostName)
85+
rankHosts[slots].append(details["connection"])
7386
if slots > maxSlots:
7487
maxSlots = slots
7588

@@ -83,18 +96,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
8396
restJobs = numberOfJobs
8497
submittedJobs = []
8598
stampDict = {}
99+
batchSystemName = self.batchSystem.__class__.__name__.lower()
100+
86101
for slots in range(maxSlots, 0, -1):
87102
if slots not in rankHosts:
88103
continue
89-
for host in rankHosts[slots]:
90-
result = self._submitJobToHost(executableFile, min(slots, restJobs), host)
104+
for connection in rankHosts[slots]:
105+
result = self._submitJobToHost(connection, executableFile, min(slots, restJobs))
91106
if not result["OK"]:
92107
continue
93108

94-
nJobs = len(result["Value"])
109+
batchIDs, jobStamps = result["Value"]
110+
111+
nJobs = len(batchIDs)
95112
if nJobs > 0:
96-
submittedJobs.extend(result["Value"])
97-
stampDict.update(result.get("PilotStampDict", {}))
113+
jobIDs = [
114+
f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}"
115+
for _id in batchIDs
116+
]
117+
submittedJobs.extend(jobIDs)
118+
for iJob, jobID in enumerate(jobIDs):
119+
stampDict[jobID] = jobStamps[iJob]
120+
98121
restJobs = restJobs - nJobs
99122
if restJobs <= 0:
100123
break
@@ -105,6 +128,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
105128
result["PilotStampDict"] = stampDict
106129
return result
107130

131+
#############################################################################
132+
108133
def killJob(self, jobIDs):
109134
"""Kill specified jobs"""
110135
jobIDList = list(jobIDs)
@@ -120,7 +145,7 @@ def killJob(self, jobIDs):
120145

121146
failed = []
122147
for host, jobIDList in hostDict.items():
123-
result = self._killJobOnHost(jobIDList, host)
148+
result = self._killJobOnHost(self.connections[host]["connection"], jobIDList)
124149
if not result["OK"]:
125150
failed.extend(jobIDList)
126151
message = result["Message"]
@@ -133,16 +158,17 @@ def killJob(self, jobIDs):
133158

134159
return result
135160

161+
#############################################################################
162+
136163
def getCEStatus(self):
137164
"""Method to return information on running and pending jobs."""
138165
result = S_OK()
139166
result["SubmittedJobs"] = self.submittedJobs
140167
result["RunningJobs"] = 0
141168
result["WaitingJobs"] = 0
142169

143-
for host in self.sshHost:
144-
thost = host.split("/")
145-
resultHost = self._getHostStatus(thost[0])
170+
for _, details in self.connections:
171+
resultHost = self._getHostStatus(details["connection"])
146172
if resultHost["OK"]:
147173
result["RunningJobs"] += resultHost["Value"]["Running"]
148174

@@ -151,6 +177,8 @@ def getCEStatus(self):
151177

152178
return result
153179

180+
#############################################################################
181+
154182
def getJobStatus(self, jobIDList):
155183
"""Get status of the jobs in the given list"""
156184
hostDict = {}
@@ -162,7 +190,7 @@ def getJobStatus(self, jobIDList):
162190
resultDict = {}
163191
failed = []
164192
for host, jobIDList in hostDict.items():
165-
result = self._getJobStatusOnHost(jobIDList, host)
193+
result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList)
166194
if not result["OK"]:
167195
failed.extend(jobIDList)
168196
continue
@@ -173,3 +201,16 @@ def getJobStatus(self, jobIDList):
173201
resultDict[job] = PilotStatus.UNKNOWN
174202

175203
return S_OK(resultDict)
204+
205+
#############################################################################
206+
207+
def getJobOutput(self, jobID, localDir=None):
208+
"""Get the specified job standard output and error files. If the localDir is provided,
209+
the output is returned as file in this directory. Otherwise, the output is returned
210+
as strings.
211+
"""
212+
self.log.verbose("Getting output for jobID", jobID)
213+
214+
# host can be retrieved from the path of the jobID
215+
host = os.path.dirname(urlparse(jobID).path).lstrip("/")
216+
return self._getJobOutputFilesOnHost(self.connections[host]["connection"], jobID, localDir)

0 commit comments

Comments
 (0)