Skip to content

Commit 9481f82

Browse files
committed
feat(Resources): introduce fabric in SSHCE
1 parent cbb4b91 commit 9481f82

File tree

4 files changed

+394
-452
lines changed

4 files changed

+394
-452
lines changed

environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ dependencies:
1818
- cwltool
1919
- db12
2020
- opensearch-py
21+
- fabric
2122
- fts3
2223
- gitpython >=2.1.0
24+
- invoke
2325
- m2crypto >=0.38.0
2426
- matplotlib
2527
- numpy
28+
- paramiko
2629
- pexpect >=4.0.1
2730
- pillow
2831
- prompt-toolkit >=3,<4

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ install_requires =
3939
gfal2-python
4040
importlib_metadata >=4.4
4141
importlib_resources
42+
invoke
4243
M2Crypto >=0.36
4344
packaging
45+
paramiko
4446
pexpect
4547
prompt-toolkit >=3
4648
psutil
Lines changed: 80 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs
1+
""" SSH (Virtual) Batch Computing Element: for a given list of ip/cores pairs, it sends jobs
22
directly through ssh
33
"""
44

@@ -12,64 +12,78 @@
1212

1313

1414
class SSHBatchComputingElement(SSHComputingElement):
15-
#############################################################################
1615
def __init__(self, ceUniqueID):
    """Build the computing element and start with an empty host registry.

    :param ceUniqueID: identifier handed over unchanged to the parent constructor
    """
    super().__init__(ceUniqueID)

    # Tag stored on the instance for this CE flavour
    self.execution = "SSHBATCH"
    # hostname -> {"connection": ..., "maxJobs": ...}; filled in by _reset()
    self.connections = {}
2321

2422
def _reset(self):
    """Process the CE parameters and register a connection for every configured host.

    Reads the SSH, submission and output-template settings from ``self.ceParameters``
    (keeping the current attribute value when a parameter is absent), then stores one
    entry per host of the comma-separated ``SSHHost`` parameter in ``self.connections``.
    A host entry is either ``hostname`` or ``hostname/maxJobs``; when no limit is given,
    the ``MaxTotalJobs`` parameter is used (0 if unset).

    :return: S_OK() on success, otherwise the failing result of the batch system
             lookup or of the remote host preparation
    :raises KeyError: if the mandatory ``Queue`` parameter is missing
    :raises ValueError: if a job limit cannot be converted to an integer
    """
    # Get the Batch System instance
    result = self._getBatchSystem()
    if not result["OK"]:
        return result

    # Get the location of the remote directories
    self._getBatchSystemDirectoryLocations()

    params = self.ceParameters

    # Get the SSH parameters
    self.timeout = params.get("Timeout", self.timeout)
    self.user = params.get("SSHUser", self.user)
    port = params.get("SSHPort", None)
    password = params.get("SSHPassword", None)
    key = params.get("SSHKey", None)
    tunnel = params.get("SSHTunnel", None)

    # Get submission parameters
    self.submitOptions = params.get("SubmitOptions", self.submitOptions)
    self.preamble = params.get("Preamble", self.preamble)
    self.account = params.get("Account", self.account)
    self.queue = params["Queue"]
    self.log.info("Using queue: ", self.queue)

    # Get output and error templates
    self.outputTemplate = params.get("OutputTemplate", self.outputTemplate)
    self.errorTemplate = params.get("ErrorTemplate", self.errorTemplate)

    # Prepare the remote hosts
    for entry in params.get("SSHHost", "").split(","):
        hostname, *limit = entry.strip().split("/")
        if not hostname:
            # An unset parameter, a trailing comma or a doubled comma produces a blank
            # entry: there is no host to connect to, so do not attempt a connection
            self.log.verbose("Ignoring empty SSHHost entry")
            continue

        # Always store an integer: submitJob subtracts the running job count from it
        maxJobs = int(limit[0] if limit else params.get("MaxTotalJobs", 0))

        connection = self._getConnection(hostname, self.user, port, password, key, tunnel)

        result = self._prepareRemoteHost(connection)
        if not result["OK"]:
            return result

        self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs}
        self.log.info(f"Host {hostname} registered for usage")

    return S_OK()
5171

5272
#############################################################################
73+
5374
def submitJob(self, executableFile, proxy, numberOfJobs=1):
5475
"""Method to submit job"""
55-
5676
# Choose eligible hosts, rank them by the number of available slots
5777
rankHosts = {}
5878
maxSlots = 0
59-
for host in self.sshHost:
60-
thost = host.split("/")
61-
hostName = thost[0]
62-
maxHostJobs = 1
63-
if len(thost) > 1:
64-
maxHostJobs = int(thost[1])
65-
66-
result = self._getHostStatus(hostName)
79+
for _, details in self.connections.items():
80+
result = self._getHostStatus(details["connection"])
6781
if not result["OK"]:
6882
continue
69-
slots = maxHostJobs - result["Value"]["Running"]
83+
slots = details["maxJobs"] - result["Value"]["Running"]
7084
if slots > 0:
7185
rankHosts.setdefault(slots, [])
72-
rankHosts[slots].append(hostName)
86+
rankHosts[slots].append(details["connection"])
7387
if slots > maxSlots:
7488
maxSlots = slots
7589

@@ -83,18 +97,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
8397
restJobs = numberOfJobs
8498
submittedJobs = []
8599
stampDict = {}
100+
batchSystemName = self.batchSystem.__class__.__name__.lower()
101+
86102
for slots in range(maxSlots, 0, -1):
87103
if slots not in rankHosts:
88104
continue
89-
for host in rankHosts[slots]:
90-
result = self._submitJobToHost(executableFile, min(slots, restJobs), host)
105+
for connection in rankHosts[slots]:
106+
result = self._submitJobToHost(connection, executableFile, min(slots, restJobs))
91107
if not result["OK"]:
92108
continue
93109

94-
nJobs = len(result["Value"])
110+
batchIDs, jobStamps = result["Value"]
111+
112+
nJobs = len(batchIDs)
95113
if nJobs > 0:
96-
submittedJobs.extend(result["Value"])
97-
stampDict.update(result.get("PilotStampDict", {}))
114+
jobIDs = [
115+
f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}"
116+
for _id in batchIDs
117+
]
118+
submittedJobs.extend(jobIDs)
119+
for iJob, jobID in enumerate(jobIDs):
120+
stampDict[jobID] = jobStamps[iJob]
121+
98122
restJobs = restJobs - nJobs
99123
if restJobs <= 0:
100124
break
@@ -105,6 +129,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1):
105129
result["PilotStampDict"] = stampDict
106130
return result
107131

132+
#############################################################################
133+
108134
def killJob(self, jobIDs):
109135
"""Kill specified jobs"""
110136
jobIDList = list(jobIDs)
@@ -120,7 +146,7 @@ def killJob(self, jobIDs):
120146

121147
failed = []
122148
for host, jobIDList in hostDict.items():
123-
result = self._killJobOnHost(jobIDList, host)
149+
result = self._killJobOnHost(self.connections[host]["connection"], jobIDList)
124150
if not result["OK"]:
125151
failed.extend(jobIDList)
126152
message = result["Message"]
@@ -133,16 +159,17 @@ def killJob(self, jobIDs):
133159

134160
return result
135161

162+
#############################################################################
163+
136164
def getCEStatus(self):
137165
"""Method to return information on running and pending jobs."""
138166
result = S_OK()
139167
result["SubmittedJobs"] = self.submittedJobs
140168
result["RunningJobs"] = 0
141169
result["WaitingJobs"] = 0
142170

143-
for host in self.sshHost:
144-
thost = host.split("/")
145-
resultHost = self._getHostStatus(thost[0])
171+
# self.connections is a dict keyed by hostname: iterate over its values, because
# iterating the dict itself yields hostname strings, which cannot be unpacked
# into a (key, details) pair
for details in self.connections.values():
    resultHost = self._getHostStatus(details["connection"])
    if resultHost["OK"]:
        result["RunningJobs"] += resultHost["Value"]["Running"]
148175

@@ -151,6 +178,8 @@ def getCEStatus(self):
151178

152179
return result
153180

181+
#############################################################################
182+
154183
def getJobStatus(self, jobIDList):
155184
"""Get status of the jobs in the given list"""
156185
hostDict = {}
@@ -162,7 +191,7 @@ def getJobStatus(self, jobIDList):
162191
resultDict = {}
163192
failed = []
164193
for host, jobIDList in hostDict.items():
165-
result = self._getJobStatusOnHost(jobIDList, host)
194+
result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList)
166195
if not result["OK"]:
167196
failed.extend(jobIDList)
168197
continue
@@ -173,3 +202,16 @@ def getJobStatus(self, jobIDList):
173202
resultDict[job] = PilotStatus.UNKNOWN
174203

175204
return S_OK(resultDict)
205+
206+
#############################################################################
207+
208+
def getJobOutput(self, jobID, localDir=None):
    """Retrieve the standard output and error files of the given job.

    When localDir is provided, the output is returned as files in that directory;
    otherwise it is returned as strings.

    :param jobID: job reference, a URL whose path has the form /<host>/<id>
    :param localDir: optional directory receiving the output files
    :return: the result of _getJobOutputFilesOnHost for the connection of the job's host
    """
    self.log.verbose("Getting output for jobID", jobID)

    # The parent of the URL path names the host the job was submitted to
    jobPath = urlparse(jobID).path
    hostname = os.path.dirname(jobPath).lstrip("/")
    hostEntry = self.connections[hostname]
    return self._getJobOutputFilesOnHost(hostEntry["connection"], jobID, localDir)

0 commit comments

Comments
 (0)