diff --git a/environment.yml b/environment.yml index 3c25b0c3c2e..5bfe0cd922d 100644 --- a/environment.yml +++ b/environment.yml @@ -18,11 +18,14 @@ dependencies: - cwltool - db12 - opensearch-py + - fabric - fts3 - gitpython >=2.1.0 + - invoke - m2crypto >=0.38.0 - matplotlib - numpy + - paramiko - pexpect >=4.0.1 - pillow - prompt-toolkit >=3,<4 diff --git a/setup.cfg b/setup.cfg index f03d63d3f29..4ea3ca4a4ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,12 +35,15 @@ install_requires = diracx-core >=v0.0.1 diracx-cli >=v0.0.1 db12 + fabric fts3 gfal2-python importlib_metadata >=4.4 importlib_resources + invoke M2Crypto >=0.36 packaging + paramiko pexpect prompt-toolkit >=3 psutil diff --git a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py index 38c70887d4e..18fd3da4db2 100644 --- a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py @@ -1,4 +1,4 @@ -""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs +""" SSH (Virtual) Batch Computing Element: For a given list of ip/cores pair it will send jobs directly through ssh """ @@ -12,64 +12,77 @@ class SSHBatchComputingElement(SSHComputingElement): - ############################################################################# def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) - self.ceType = "SSHBatch" - self.sshHost = [] + self.connections = {} self.execution = "SSHBATCH" def _reset(self): """Process CE parameters and make necessary adjustments""" + # Get the Batch System instance result = self._getBatchSystem() if not result["OK"]: return result + + # Get the location of the remote directories self._getBatchSystemDirectoryLocations() - self.user = self.ceParameters["SSHUser"] + # Get the SSH parameters + self.timeout = self.ceParameters.get("Timeout", self.timeout) + self.user = self.ceParameters.get("SSHUser", self.user) + port = self.ceParameters.get("SSHPort", None) + password = self.ceParameters.get("SSHPassword", None) + key = self.ceParameters.get("SSHKey", None) + + # Get submission parameters + self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions) + self.preamble = self.ceParameters.get("Preamble", self.preamble) + self.account = self.ceParameters.get("Account", self.account) self.queue = self.ceParameters["Queue"] self.log.info("Using queue: ", self.queue) - self.submitOptions = self.ceParameters.get("SubmitOptions", "") - self.preamble = self.ceParameters.get("Preamble", "") - self.account = self.ceParameters.get("Account", "") - - # Prepare all the hosts - for hPar in self.ceParameters["SSHHost"].strip().split(","): - host = hPar.strip().split("/")[0] - result = self._prepareRemoteHost(host=host) - if result["OK"]: - self.log.info(f"Host {host} registered for usage") - self.sshHost.append(hPar.strip()) + # Get output and error templates + self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate) + self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate) + + # Prepare the remote hosts + for host in self.ceParameters.get("SSHHost", "").strip().split(","): + hostDetails = host.strip().split("/") + if len(hostDetails) > 1: + hostname = hostDetails[0] + maxJobs = int(hostDetails[1]) else: - self.log.error("Failed to initialize host", host) + hostname = hostDetails[0] + maxJobs = self.ceParameters.get("MaxTotalJobs", 0) + + connection = self._getConnection(hostname, self.user, port, password, key) + + result = self._prepareRemoteHost(connection) + if not result["OK"]: return result + self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs} + self.log.info(f"Host {hostname} registered for usage") + return S_OK() ############################################################################# + def submitJob(self, executableFile, proxy, numberOfJobs=1): """Method to submit job""" - # Choose eligible hosts, rank them by the number of available slots rankHosts = {} maxSlots = 0 - for host in self.sshHost: - thost = host.split("/") - hostName = thost[0] - maxHostJobs = 1 - if len(thost) > 1: - maxHostJobs = int(thost[1]) - - result = self._getHostStatus(hostName) + for _, details in self.connections.items(): + result = self._getHostStatus(details["connection"]) if not result["OK"]: continue - slots = maxHostJobs - result["Value"]["Running"] + slots = details["maxJobs"] - result["Value"]["Running"] if slots > 0: rankHosts.setdefault(slots, []) - rankHosts[slots].append(hostName) + rankHosts[slots].append(details["connection"]) if slots > maxSlots: maxSlots = slots @@ -83,18 +96,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): restJobs = numberOfJobs submittedJobs = [] stampDict = {} + batchSystemName = self.batchSystem.__class__.__name__.lower() + for slots in range(maxSlots, 0, -1): if slots not in rankHosts: continue - for host in rankHosts[slots]: - result = self._submitJobToHost(executableFile, min(slots, restJobs), host) + for connection in rankHosts[slots]: + result = self._submitJobToHost(connection, executableFile, min(slots, restJobs)) if not result["OK"]: continue - nJobs = len(result["Value"]) + batchIDs, jobStamps = result["Value"] + + nJobs = len(batchIDs) if nJobs > 0: - submittedJobs.extend(result["Value"]) - stampDict.update(result.get("PilotStampDict", {})) + jobIDs = [ + f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}" + for _id in batchIDs + ] + submittedJobs.extend(jobIDs) + for iJob, jobID in enumerate(jobIDs): + stampDict[jobID] = jobStamps[iJob] + restJobs = restJobs - nJobs if restJobs <= 0: break @@ -105,6 +128,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): result["PilotStampDict"] = stampDict return result + ############################################################################# + def killJob(self, jobIDs): """Kill specified jobs""" jobIDList = list(jobIDs) @@ -120,7 +145,7 @@ def killJob(self, jobIDs): failed = [] for host, jobIDList in hostDict.items(): - result = self._killJobOnHost(jobIDList, host) + result = self._killJobOnHost(self.connections[host]["connection"], jobIDList) if not result["OK"]: failed.extend(jobIDList) message = result["Message"] @@ -133,6 +158,8 @@ def killJob(self, jobIDs): return result + ############################################################################# + def getCEStatus(self): """Method to return information on running and pending jobs.""" result = S_OK() @@ -140,9 +167,8 @@ def getCEStatus(self): result["RunningJobs"] = 0 result["WaitingJobs"] = 0 - for host in self.sshHost: - thost = host.split("/") - resultHost = self._getHostStatus(thost[0]) + for _, details in self.connections: + resultHost = self._getHostStatus(details["connection"]) if resultHost["OK"]: result["RunningJobs"] += resultHost["Value"]["Running"] @@ -151,6 +177,8 @@ def getCEStatus(self): return result + ############################################################################# + def getJobStatus(self, jobIDList): """Get status of the jobs in the given list""" hostDict = {} @@ -162,7 +190,7 @@ def getJobStatus(self, jobIDList): resultDict = {} failed = [] for host, jobIDList in hostDict.items(): - result = self._getJobStatusOnHost(jobIDList, host) + result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList) if not result["OK"]: failed.extend(jobIDList) continue @@ -173,3 +201,16 @@ def getJobStatus(self, jobIDList): resultDict[job] = PilotStatus.UNKNOWN return S_OK(resultDict) + + ############################################################################# + + def getJobOutput(self, jobID, localDir=None): + """Get the specified job standard output and error files. If the localDir is provided, + the output is returned as file in this directory. Otherwise, the output is returned + as strings. + """ + self.log.verbose("Getting output for jobID", jobID) + + # host can be retrieved from the path of the jobID + host = os.path.dirname(urlparse(jobID).path).lstrip("/") + return self._getJobOutputFilesOnHost(self.connections[host]["connection"], jobID, localDir) diff --git a/src/DIRAC/Resources/Computing/SSHComputingElement.py b/src/DIRAC/Resources/Computing/SSHComputingElement.py index 25668e62b57..3f35b1d2183 100644 --- a/src/DIRAC/Resources/Computing/SSHComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHComputingElement.py @@ -40,25 +40,19 @@ SSH password SSHPort: - Port number if not standard, e.g. for the gsissh access + Port number if not standard SSHKey: Location of the ssh private key for no-password connection -SSHOptions: - Any other SSH options to be used. Example:: - - SSHOptions = -o UserKnownHostsFile=/local/path/to/known_hosts - - Allows to have a local copy of the ``known_hosts`` file, independent of the HOME directory. - SSHTunnel: String defining the use of intermediate SSH host. Example:: ssh -i /private/key/location -l final_user final_host -SSHType: - SSH (default) or gsissh +Timeout: + Timeout for the SSH commands. Default is 120 seconds. + **Code Documentation** """ @@ -69,278 +63,99 @@ import stat import tempfile import uuid -from shlex import quote as shlex_quote from urllib.parse import urlparse -import pexpect +from fabric import Connection +from invoke.exceptions import CommandTimedOut +from paramiko.ssh_exception import SSHException import DIRAC -from DIRAC import S_ERROR, S_OK, gLogger +from DIRAC import S_ERROR, S_OK from DIRAC.Core.Utilities.List import breakListIntoChunks, uniqueElements from DIRAC.Resources.Computing.BatchSystems.executeBatch import executeBatchContent from DIRAC.Resources.Computing.ComputingElement import ComputingElement -class SSH: - """SSH class encapsulates passing commands and files through an SSH tunnel - to a remote host. It can use either ssh or gsissh access. The final host - where the commands will be executed and where the files will copied/retrieved - can be reached through an intermediate host if SSHTunnel parameters is defined. - - SSH constructor parameters are defined in a SSH accessible Computing Element - in the Configuration System: - - - SSHHost: SSH host name - - SSHUser: SSH user login - - SSHPassword: SSH password - - SSHPort: port number if not standard, e.g. for the gsissh access - - SSHKey: location of the ssh private key for no-password connection - - SSHOptions: any other SSH options to be used - - SSHTunnel: string defining the use of intermediate SSH host. Example: - 'ssh -i /private/key/location -l final_user final_host' - - SSHType: ssh ( default ) or gsissh - - The class public interface includes two methods: - - sshCall( timeout, command_sequence ) - scpCall( timeout, local_file, remote_file, upload = False/True ) - """ - - def __init__(self, host=None, parameters=None): - self.host = host - if parameters is None: - parameters = {} - if not host: - self.host = parameters.get("SSHHost", "") - - self.user = parameters.get("SSHUser", "") - self.password = parameters.get("SSHPassword", "") - self.port = parameters.get("SSHPort", "") - self.key = parameters.get("SSHKey", "") - self.options = parameters.get("SSHOptions", "") - self.sshTunnel = parameters.get("SSHTunnel", "") - self.sshType = parameters.get("SSHType", "ssh") - - if self.port: - self.options += f" -p {self.port}" - if self.key: - self.options += f" -i {self.key}" - self.options = self.options.strip() - - self.log = gLogger.getSubLogger("SSH") - - def __ssh_call(self, command, timeout): - if not timeout: - timeout = 999 - - ssh_newkey = "Are you sure you want to continue connecting" - try: - child = pexpect.spawn(command, timeout=timeout, encoding="utf-8") - i = child.expect([pexpect.TIMEOUT, ssh_newkey, pexpect.EOF, "assword: "]) - if i == 0: # Timeout - return S_OK((-1, child.before, "SSH login failed")) - - if i == 1: # SSH does not have the public key. Just accept it. - child.sendline("yes") - child.expect("assword: ") - i = child.expect([pexpect.TIMEOUT, "assword: "]) - if i == 0: # Timeout - return S_OK((-1, str(child.before) + str(child.after), "SSH login failed")) - if i == 1: - child.sendline(self.password) - child.expect(pexpect.EOF) - return S_OK((0, child.before, "")) - - if i == 2: - # Passwordless login, get the output - return S_OK((0, child.before, "")) - - if self.password: - child.sendline(self.password) - child.expect(pexpect.EOF) - return S_OK((0, child.before, "")) - - return S_ERROR(f"Unknown error: {child.before}") - except Exception as x: - return S_ERROR(f"Encountered exception: {str(x)}") - - def sshCall(self, timeout, cmdSeq): - """Execute remote command via a ssh remote call - - :param int timeout: timeout of the command - :param cmdSeq: list of command components - :type cmdSeq: python:list - """ - - command = cmdSeq - if isinstance(cmdSeq, list): - command = " ".join(cmdSeq) - - pattern = "__DIRAC__" - - if self.sshTunnel: - command = command.replace("'", '\\\\\\"') - command = command.replace("$", "\\\\\\$") - command = '/bin/sh -c \' {} -q {} -l {} {} "{} \\"echo {}; {}\\" " \' '.format( - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - pattern, - command, - ) - else: - # command = command.replace( '$', '\$' ) - command = '{} -q {} -l {} {} "echo {}; {}"'.format( - self.sshType, - self.options, - self.user, - self.host, - pattern, - command, - ) - self.log.debug(f"SSH command: {command}") - result = self.__ssh_call(command, timeout) - self.log.debug(f"SSH command result {str(result)}") - if not result["OK"]: - return result - - # Take the output only after the predefined pattern - ind = result["Value"][1].find("__DIRAC__") - if ind == -1: - return result - - status, output, error = result["Value"] - output = output[ind + 9 :] - if output.startswith("\r"): - output = output[1:] - if output.startswith("\n"): - output = output[1:] - - result["Value"] = (status, output, error) - return result - - def scpCall(self, timeout, localFile, remoteFile, postUploadCommand="", upload=True): - """Perform file copy through an SSH magic. - - :param int timeout: timeout of the command - :param str localFile: local file path, serves as source for uploading and destination for downloading. - Can take 'Memory' as value, in this case the downloaded contents is returned - as result['Value'] - :param str remoteFile: remote file full path - :param str postUploadCommand: command executed on the remote side after file upload - :param bool upload: upload if True, download otherwise - """ - # shlex_quote aims to prevent any security issue or problems with filepath containing spaces - # it returns a shell-escaped version of the filename - localFile = shlex_quote(localFile) - remoteFile = shlex_quote(remoteFile) - if upload: - if self.sshTunnel: - remoteFile = remoteFile.replace("$", r"\\\\\$") - postUploadCommand = postUploadCommand.replace("$", r"\\\\\$") - command = '/bin/sh -c \'cat {} | {} -q {} {}@{} "{} \\"cat > {}; {}\\""\' '.format( - localFile, - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - remoteFile, - postUploadCommand, - ) - else: - command = "/bin/sh -c \"cat {} | {} -q {} {}@{} 'cat > {}; {}'\" ".format( - localFile, - self.sshType, - self.options, - self.user, - self.host, - remoteFile, - postUploadCommand, - ) - else: - finalCat = f"| cat > {localFile}" - if localFile.lower() == "memory": - finalCat = "" - if self.sshTunnel: - remoteFile = remoteFile.replace("$", "\\\\\\$") - command = '/bin/sh -c \'{} -q {} -l {} {} "{} \\"cat {}\\"" {}\''.format( - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - remoteFile, - finalCat, - ) - else: - remoteFile = remoteFile.replace("$", r"\$") - command = "/bin/sh -c '{} -q {} -l {} {} \"cat {}\" {}'".format( - self.sshType, - self.options, - self.user, - self.host, - remoteFile, - finalCat, - ) - - self.log.debug(f"SSH copy command: {command}") - return self.__ssh_call(command, timeout) - - class SSHComputingElement(ComputingElement): ############################################################################# def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) - self.execution = "SSHCE" self.submittedJobs = 0 - self.outputTemplate = "" - self.errorTemplate = "" - - ############################################################################ - def setProxy(self, proxy): - """ - Set and prepare proxy to use - :param str proxy: proxy to use - :return: S_OK/S_ERROR - """ - ComputingElement.setProxy(self, proxy) - if self.ceParameters.get("SSHType", "ssh") == "gsissh": - result = self._prepareProxy() - if not result["OK"]: - gLogger.error("SSHComputingElement: failed to set up proxy", result["Message"]) - return result - return S_OK() + # SSH connection + self.connection = None + self.timeout = 120 + self.user = None + self.host = None + + # Submission parameters + self.queue = None + self.submitOptions = None + self.preamble = None + self.account = None + self.execution = "SSHCE" - ############################################################################# - def _addCEConfigDefaults(self): - """Method to make sure all necessary Configuration Parameters are defined""" - # First assure that any global parameters are loaded - ComputingElement._addCEConfigDefaults(self) - # Now batch system specific ones - if "SharedArea" not in self.ceParameters: - # . isn't a good location, move to $HOME - self.ceParameters["SharedArea"] = "$HOME" + # Directories + self.sharedArea = "$HOME" + self.batchOutput = "data" + self.batchError = "data" + self.infoArea = "data" + self.executableArea = "info" + self.workArea = "work" - if "BatchOutput" not in self.ceParameters: - self.ceParameters["BatchOutput"] = "data" + # Output and error templates + self.outputTemplate = "" + self.errorTemplate = "" - if "BatchError" not in self.ceParameters: - self.ceParameters["BatchError"] = "data" + ############################################################################# - if "ExecutableArea" not in self.ceParameters: - self.ceParameters["ExecutableArea"] = "data" + def _run(self, connection: Connection, command: str): + """Run the command on the remote host""" + try: + result = connection.run(command, warn=True, hide=True) + if result.failed: + return S_ERROR(f"[{connection.host}] Command returned an error: {result.stderr}") + return S_OK(result.stdout) + except CommandTimedOut as e: + return S_ERROR( + errno.ETIME, f"[{connection.host}] The command timed out. Consider increasing the timeout: {e}" + ) + except SSHException as e: + return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}") - if "InfoArea" not in self.ceParameters: - self.ceParameters["InfoArea"] = "info" + def _put(self, connection: Connection, local: str, remote: str, preserveMode: bool = True): + """Upload a file to the remote host""" + try: + connection.put(local, remote=remote, preserve_mode=preserveMode) + return S_OK() + except OSError as e: + return S_ERROR(f"[{connection.host}] Failed uploading file: {str(e)}") + except SSHException as e: + return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}") + + def _get(self, connection: Connection, remote: str, local: str, preserveMode: bool = True): + """Download a file from the remote host""" + try: + if local == "Memory": + # Download to memory: use BytesIO buffer + from io import BytesIO + + buffer = BytesIO() + connection.get(remote, local=buffer) + content = buffer.getvalue().decode("utf-8", errors="replace") + return S_OK((0, content)) # Return (returncode, stdout) tuple + else: + # Download to file + connection.get(remote, local=local, preserve_mode=preserveMode) + return S_OK() + except OSError as e: + return S_ERROR(f"[{connection.host}] Failed downloading file: {str(e)}") + except SSHException as e: + return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}") - if "WorkArea" not in self.ceParameters: - self.ceParameters["WorkArea"] = "work" + ############################################################################# def _getBatchSystem(self): """Load a Batch System instance from the CE Parameters""" @@ -354,90 +169,229 @@ def _getBatchSystem(self): def _getBatchSystemDirectoryLocations(self): """Get names of the locations to store outputs, errors, info and executables.""" - self.sharedArea = self.ceParameters["SharedArea"] - self.batchOutput = self.ceParameters["BatchOutput"] - if not self.batchOutput.startswith("/"): - self.batchOutput = os.path.join(self.sharedArea, self.batchOutput) - self.batchError = self.ceParameters["BatchError"] - if not self.batchError.startswith("/"): - self.batchError = os.path.join(self.sharedArea, self.batchError) - self.infoArea = self.ceParameters["InfoArea"] - if not self.infoArea.startswith("/"): - self.infoArea = os.path.join(self.sharedArea, self.infoArea) - self.executableArea = self.ceParameters["ExecutableArea"] - if not self.executableArea.startswith("/"): - self.executableArea = os.path.join(self.sharedArea, self.executableArea) - self.workArea = self.ceParameters["WorkArea"] - if not self.workArea.startswith("/"): - self.workArea = os.path.join(self.sharedArea, self.workArea) + self.sharedArea = self.ceParameters.get("SharedArea", self.sharedArea) + + def _get_dir(directory: str, defaultValue: str) -> str: + value = self.ceParameters.get(directory, defaultValue) + if value.startswith("/"): + return value + return os.path.join(self.sharedArea, value) + + self.batchOutput = _get_dir("BatchOutput", self.batchOutput) + self.batchError = _get_dir("BatchError", self.batchError) + self.infoArea = _get_dir("InfoArea", self.infoArea) + self.executableArea = _get_dir("ExecutableArea", self.executableArea) + self.workArea = _get_dir("WorkArea", self.workArea) + + def _parseTunnel(self, host_string: str) -> tuple[str, int | None]: + """Parse a host string to extract hostname and port. + + Supports multiple formats: + 1. Legacy SSH command: "ssh [options] [user@]host[:port]" - extracts from command + - Also extracts port from -p flag if present + 2. user@host:port format + 3. host:port format + 4. Plain hostname + + :param host_string: The host string to parse + :return: Tuple of (hostname, port) where port is None if not specified + """ + if not host_string: + return None, None + + hostname = None + port = None + + # Check if it's a legacy SSH command format + if host_string.strip().startswith("ssh "): + # Parse the SSH command to extract the hostname and port + # Format: ssh [options] [-l user] [-p port] [-i keyfile] [user@]hostname[:port] + parts = host_string.split() + + # The hostname is typically the last argument + # Skip option flags and their values + skip_next = False + + for i, part in enumerate(parts): + if i == 0: # Skip "ssh" + continue + + if skip_next: + skip_next = False + continue + + # Check for -p flag to capture port + if part == "-p" and i + 1 < len(parts): + try: + port = int(parts[i + 1]) + except ValueError: + self.log.warn(f"Invalid port number in -p flag: {parts[i + 1]}") + skip_next = True + continue + + # Options that take a value + if part in ["-i", "-l", "-p", "-o", "-F", "-J"]: + skip_next = True + continue + + # Skip flags without values + if part.startswith("-"): + continue + + # This should be the hostname (possibly with user@ prefix and :port suffix) + hostname = part + + if not hostname: + self.log.warn(f"Failed to parse hostname from legacy SSH command format: {host_string}") + return None, None + + self.log.verbose(f"Parsed hostname '{hostname}' from legacy SSH command format") + else: + # New format: just the hostname (possibly with user@ and :port) + hostname = host_string.strip() + + # Now parse hostname to extract user@host:port components + # Remove user@ prefix if present + if "@" in hostname: + hostname = hostname.split("@", 1)[1] + + # Extract port if present in hostname (unless already extracted from -p flag) + if ":" in hostname and port is None: + hostname, port_str = hostname.rsplit(":", 1) + try: + port = int(port_str) + except ValueError: + self.log.warn(f"Invalid port number '{port_str}' in '{host_string}', ignoring port") + port = None + + return hostname, port + + def _getConnection( + self, + host: str, + user: str, + port: int, + password: str, + key: str, + gateway_host: str | None = None, + gateway_port: int | None = None, + ) -> Connection: + """Get a Connection instance to the host. + + :param host: The final destination host + :param user: SSH username + :param port: SSH port + :param password: SSH password + :param key: SSH key file path + :param gateway_host: The gateway/jump host (None if no gateway) + """ + connectionParams = {} + if password: + connectionParams["password"] = password + if key: + connectionParams["key_filename"] = key + + gateway = None + if gateway_host: + gateway = Connection(gateway_host, user=user, port=gateway_port, connect_kwargs=connectionParams) + + return Connection( + host, + user=user, + port=port, + gateway=gateway, + connect_kwargs=connectionParams, + connect_timeout=self.timeout, + ) def _reset(self): """Process CE parameters and make necessary adjustments""" + # Get the Batch System instance result = self._getBatchSystem() if not result["OK"]: return result + + # Get the location of the remote directories self._getBatchSystemDirectoryLocations() - self.user = self.ceParameters["SSHUser"] + # Get the SSH parameters + self.host = self.ceParameters.get("SSHHost", self.host) + self.timeout = self.ceParameters.get("Timeout", self.timeout) + self.user = self.ceParameters.get("SSHUser", self.user) + port = self.ceParameters.get("SSHPort", None) + password = self.ceParameters.get("SSHPassword", None) + key = self.ceParameters.get("SSHKey", None) + tunnel = self.ceParameters.get("SSHTunnel", None) + + # IMPORTANT: The SSHTunnel/SSHHost naming is counterintuitive and backwards! + # When SSHTunnel is specified: + # - SSHHost = gateway/jump node + # - SSHTunnel = final destination host + # This should be fixed in DiracX to make the naming intuitive. + if tunnel: + # Swap them: SSHHost becomes gateway, SSHTunnel becomes final host + final_host, final_port = self._parseTunnel(tunnel) + gateway_host = self.host + self.host = final_host + # Same with the ports + gateway_port = port + port = final_port + else: + gateway_host = None + gateway_port = None + + # Configure the SSH connection + self.connection = self._getConnection(self.host, self.user, port, password, key, gateway_host, gateway_port) + + # Get submission parameters + self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions) + self.preamble = self.ceParameters.get("Preamble", self.preamble) + self.account = self.ceParameters.get("Account", self.account) self.queue = self.ceParameters["Queue"] self.log.info("Using queue: ", self.queue) - self.submitOptions = self.ceParameters.get("SubmitOptions", "") - self.preamble = self.ceParameters.get("Preamble", "") + # Get output and error templates + self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate) + self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate) - self.account = self.ceParameters.get("Account", "") - result = self._prepareRemoteHost() + # Prepare the remote host + result = self._prepareRemoteHost(self.connection) if not result["OK"]: return result return S_OK() - def _prepareRemoteHost(self, host=None): + def _prepareRemoteHost(self, connection: Connection): """Prepare remote directories and upload control script""" - - ssh = SSH(host=host, parameters=self.ceParameters) - # Make remote directories + self.log.verbose(f"Creating working directories on {self.host}") dirTuple = tuple( uniqueElements( [self.sharedArea, self.executableArea, self.infoArea, self.batchOutput, self.batchError, self.workArea] ) ) - nDirs = len(dirTuple) - cmd = "mkdir -p %s; " * nDirs % dirTuple - cmd = f"bash -c '{cmd}'" - self.log.verbose(f"Creating working directories on {self.ceParameters['SSHHost']}") - result = ssh.sshCall(30, cmd) + cmd = f"mkdir -p {' '.join(dirTuple)}" + result = self._run(connection, cmd) if not result["OK"]: - self.log.error("Failed creating working directories", f"({result['Message']})") + self.log.error("Failed creating working directories: ", result["Message"]) return result - status, output, _error = result["Value"] - if status == -1: - self.log.error("Timeout while creating directories") - return S_ERROR(errno.ETIME, "Timeout while creating directories") - if "cannot" in output: - self.log.error("Failed to create directories", f"({output})") - return S_ERROR(errno.EACCES, "Failed to create directories") # Upload the control script now + self.log.verbose("Generating control script") result = self._generateControlScript() if not result["OK"]: - self.log.warn("Failed generating control script") + self.log.error("Failed generating control script") return result localScript = result["Value"] - self.log.verbose(f"Uploading {self.batchSystem.__class__.__name__} script to {self.ceParameters['SSHHost']}") + os.chmod(localScript, 0o755) + + self.log.verbose(f"Uploading {self.batchSystem.__class__.__name__} script to {self.host}") remoteScript = f"{self.sharedArea}/execute_batch" - result = ssh.scpCall(30, localScript, remoteScript, postUploadCommand=f"chmod +x {remoteScript}") + + result = self._put(connection, localScript, remote=remoteScript) if not result["OK"]: - self.log.warn(f"Failed uploading control script: {result['Message']}") + self.log.error(f"Failed uploading control script: {result['Message']}") return result - status, output, _error = result["Value"] - if status != 0: - if status == -1: - self.log.warn("Timeout while uploading control script") - return S_ERROR("Timeout while uploading control script") - self.log.warn(f"Failed uploading control script: {output}") - return S_ERROR("Failed uploading control script") # Delete the generated control script locally try: @@ -470,10 +424,10 @@ def _generateControlScript(self): return S_OK(f"{controlScript}") - def __executeHostCommand(self, command, options, ssh=None, host=None): - if not ssh: - ssh = SSH(host=host, parameters=self.ceParameters) + ############################################################################# + def __executeHostCommand(self, connection: Connection, command: str, options: dict[str]): + """Execute a command on the remote host""" options["BatchSystem"] = self.batchSystem.__class__.__name__ options["Method"] = command options["SharedDir"] = self.sharedArea @@ -497,7 +451,7 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): # Upload the options file to the remote host remoteOptionsFile = f"{self.sharedArea}/batch_options_{uuid.uuid4().hex}.json" - result = ssh.scpCall(30, localOptionsFile, remoteOptionsFile) + result = self._put(connection, localOptionsFile, remote=remoteOptionsFile) if not result["OK"]: return result @@ -510,17 +464,10 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): self.log.verbose(f"CE submission command: {cmd}") - result = ssh.sshCall(120, cmd) + result = self._run(connection, cmd) if not result["OK"]: - self.log.error(f"{self.ceType} CE job submission failed", result["Message"]) return result - sshStatus = result["Value"][0] - if sshStatus != 0: - sshStdout = result["Value"][1] - sshStderr = result["Value"][2] - return S_ERROR(f"CE job submission command failed with status {sshStatus}: {sshStdout} {sshStderr}") - # The result should be written to a JSON file by execute_batch # Compute the expected result file path remoteResultFile = remoteOptionsFile.replace(".json", "_result.json") @@ -529,7 +476,7 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): with tempfile.NamedTemporaryFile(mode="r", suffix=".json", delete=False) as f: localResultFile = f.name - result = ssh.scpCall(30, localResultFile, remoteResultFile, upload=False) + result = self._get(connection, remoteResultFile, local=localResultFile) if not result["OK"]: return result @@ -545,28 +492,44 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): os.remove(localResultFile) # Clean up remote temporary files if remoteOptionsFile: - ssh.sshCall(30, f"rm -f {remoteOptionsFile}") + self._run(connection, f"rm -f {remoteOptionsFile}") if remoteResultFile: - ssh.sshCall(30, f"rm -f {remoteResultFile}") + self._run(connection, f"rm -f {remoteResultFile}") def submitJob(self, executableFile, proxy, numberOfJobs=1): - # self.log.verbose( "Executable file path: %s" % executableFile ) if not os.access(executableFile, 5): os.chmod(executableFile, stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH) - return self._submitJobToHost(executableFile, numberOfJobs) + result = self._submitJobToHost(self.connection, executableFile, numberOfJobs) + if not result["OK"]: + return result + + batchIDs, jobStamps = result["Value"] + batchSystemName = self.batchSystem.__class__.__name__.lower() + jobIDs = [f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{_id}" for _id in batchIDs] + + result = S_OK(jobIDs) + stampDict = {} + for iJob, jobID in enumerate(jobIDs): + stampDict[jobID] = jobStamps[iJob] + result["PilotStampDict"] = stampDict + self.submittedJobs += len(batchIDs) + + return result - def _submitJobToHost(self, executableFile, numberOfJobs, host=None): + def _submitJobToHost(self, connection: Connection, executableFile: str, numberOfJobs: int): """Submit prepared executable to the given host""" - ssh = SSH(host=host, parameters=self.ceParameters) # Copy the executable + self.log.verbose(f"Copying executable to {self.host}") submitFile = os.path.join(self.executableArea, os.path.basename(executableFile)) - result = ssh.scpCall(30, executableFile, submitFile, postUploadCommand=f"chmod +x {submitFile}") + os.chmod(executableFile, 0o755) + + result = self._put(connection, executableFile, submitFile) if not result["OK"]: return result jobStamps = [] - for _i in range(numberOfJobs): + for _ in range(numberOfJobs): jobStamps.append(uuid.uuid4().hex) numberOfProcessors = self.ceParameters.get("NumberOfProcessors", 1) @@ -589,10 +552,8 @@ def _submitJobToHost(self, executableFile, numberOfJobs, host=None): "NumberOfGPUs": self.numberOfGPUs, "Account": self.account, } - if host: - commandOptions["SSHNodeHost"] = host - resultCommand = self.__executeHostCommand("submitJob", commandOptions, ssh=ssh, host=host) + resultCommand = self.__executeHostCommand(connection, "submitJob", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -601,42 +562,29 @@ def _submitJobToHost(self, executableFile, numberOfJobs, host=None): return S_ERROR("Invalid result from job submission") if result["Status"] != 0: return S_ERROR(f"Failed job submission: {result['Message']}") - else: - batchIDs = result["Jobs"] - if batchIDs: - batchSystemName = self.batchSystem.__class__.__name__.lower() - if host is None: - jobIDs = [f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{_id}" for _id in batchIDs] - else: - jobIDs = [ - f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{host}/{_id}" for _id in batchIDs - ] - else: - return S_ERROR("No jobs IDs returned") - result = S_OK(jobIDs) - stampDict = {} - for iJob, jobID in enumerate(jobIDs): - stampDict[jobID] = jobStamps[iJob] - result["PilotStampDict"] = stampDict - self.submittedJobs += len(batchIDs) + batchIDs = result["Jobs"] + if not batchIDs: + return S_ERROR("No jobs IDs returned") - return result + return S_OK((batchIDs, jobStamps)) + + ############################################################################# def killJob(self, jobIDList): """Kill a bunch of jobs""" if isinstance(jobIDList, str): jobIDList = [jobIDList] - return self._killJobOnHost(jobIDList) + return self._killJobOnHost(self.connection, jobIDList) - def _killJobOnHost(self, jobIDList, host=None): + def _killJobOnHost(self, connection: Connection, jobIDList: list[str]): """Kill the jobs for the given list of job IDs""" batchSystemJobList = [] for jobID in jobIDList: batchSystemJobList.append(os.path.basename(urlparse(jobID.split(":::")[0]).path)) commandOptions = {"JobIDList": batchSystemJobList, "User": self.user} - resultCommand = self.__executeHostCommand("killJob", commandOptions, host=host) + resultCommand = self.__executeHostCommand(connection, "killJob", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -651,6 +599,8 @@ def _killJobOnHost(self, jobIDList, host=None): return S_OK(len(result["Successful"])) + ############################################################################# + def getCEStatus(self): """Method to return information on running and pending jobs.""" result = S_OK() @@ -658,7 +608,7 @@ def getCEStatus(self): result["RunningJobs"] = 0 result["WaitingJobs"] = 0 - resultHost = self._getHostStatus() + resultHost = self._getHostStatus(self.connection) if not resultHost["OK"]: return resultHost @@ -671,9 +621,9 @@ def getCEStatus(self): return result - def _getHostStatus(self, host=None): + def _getHostStatus(self, connection: Connection): """Get jobs running at a given host""" - resultCommand = self.__executeHostCommand("getCEStatus", {}, host=host) + resultCommand = self.__executeHostCommand(connection, "getCEStatus", {}) if not resultCommand["OK"]: return resultCommand @@ -685,11 +635,13 @@ def _getHostStatus(self, host=None): return S_OK(result) + ############################################################################# + def getJobStatus(self, jobIDList): """Get the status information for the given list of jobs""" - return self._getJobStatusOnHost(jobIDList) + return self._getJobStatusOnHost(self.connection, jobIDList) - def _getJobStatusOnHost(self, jobIDList, host=None): + def _getJobStatusOnHost(self, connection: Connection, jobIDList: list[str]): """Get the status information for the given list of jobs""" resultDict = {} batchSystemJobDict = {} @@ -698,7 +650,7 @@ def _getJobStatusOnHost(self, jobIDList, host=None): batchSystemJobDict[batchSystemJobID] = jobID for jobList in breakListIntoChunks(list(batchSystemJobDict), 100): - resultCommand = self.__executeHostCommand("getJobStatus", {"JobIDList": jobList}, host=host) + resultCommand = self.__executeHostCommand(connection, "getJobStatus", {"JobIDList": jobList}) if not resultCommand["OK"]: return resultCommand @@ -713,65 +665,23 @@ def _getJobStatusOnHost(self, jobIDList, host=None): return S_OK(resultDict) + ############################################################################# + def getJobOutput(self, jobID, localDir=None): """Get the specified job standard output and error files. If the localDir is provided, the output is returned as file in this directory. Otherwise, the output is returned as strings. """ self.log.verbose("Getting output for jobID", jobID) - result = self._getJobOutputFiles(jobID) - if not result["OK"]: - return result - - batchSystemJobID, host, outputFile, errorFile = result["Value"] - - if localDir: - localOutputFile = f"{localDir}/{batchSystemJobID}.out" - localErrorFile = f"{localDir}/{batchSystemJobID}.err" - else: - localOutputFile = "Memory" - localErrorFile = "Memory" - - # Take into account the SSHBatch possible SSHHost syntax - host = host.split("/")[0] - - ssh = SSH(host=host, parameters=self.ceParameters) - resultStdout = ssh.scpCall(30, localOutputFile, outputFile, upload=False) - if not resultStdout["OK"]: - return resultStdout + return self._getJobOutputFilesOnHost(self.connection, jobID, localDir) - resultStderr = ssh.scpCall(30, localErrorFile, errorFile, upload=False) - if not resultStderr["OK"]: - return resultStderr - - if localDir: - output = localOutputFile - error = localErrorFile - else: - output = resultStdout["Value"][1] - error = resultStderr["Value"][1] - - return S_OK((output, error)) - - def _getJobOutputFiles(self, jobID): + def _getJobOutputFilesOnHost(self, connection: Connection, jobID: str, localDir: str | None = None): """Get output file names for the specific CE""" batchSystemJobID = os.path.basename(urlparse(jobID.split(":::")[0]).path) - # host can be retrieved from the path of the jobID - # it might not be present, in this case host is an empty string and will be defined by the CE parameters later - host = os.path.dirname(urlparse(jobID).path).lstrip("/") - - if "OutputTemplate" in self.ceParameters: - self.outputTemplate = self.ceParameters["OutputTemplate"] - self.errorTemplate = self.ceParameters["ErrorTemplate"] if self.outputTemplate: - output = self.outputTemplate % batchSystemJobID - error = self.errorTemplate % batchSystemJobID - elif "OutputTemplate" in self.ceParameters: - self.outputTemplate = self.ceParameters["OutputTemplate"] - self.errorTemplate = self.ceParameters["ErrorTemplate"] - output = self.outputTemplate % batchSystemJobID - error = self.errorTemplate % batchSystemJobID + outputFile = self.outputTemplate % batchSystemJobID + errorFile = self.errorTemplate % batchSystemJobID elif hasattr(self.batchSystem, "getJobOutputFiles"): # numberOfNodes is treated as a string as it can contain values such as "2-4" # where 2 would represent the minimum number of nodes to allocate, and 4 the maximum @@ -782,7 +692,7 @@ def _getJobOutputFiles(self, jobID): "ErrorDir": self.batchError, "NumberOfNodes": numberOfNodes, } - resultCommand = self.__executeHostCommand("getJobOutputFiles", commandOptions, host=host) + resultCommand = self.__executeHostCommand(connection, "getJobOutputFiles", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -796,10 +706,32 @@ def _getJobOutputFiles(self, jobID): self.outputTemplate = result["OutputTemplate"] self.errorTemplate = result["ErrorTemplate"] - output = result["Jobs"][batchSystemJobID]["Output"] - error = result["Jobs"][batchSystemJobID]["Error"] + outputFile = result["Jobs"][batchSystemJobID]["Output"] + errorFile = result["Jobs"][batchSystemJobID]["Error"] + else: + outputFile = f"{self.batchOutput}/{batchSystemJobID}.out" + errorFile = f"{self.batchError}/{batchSystemJobID}.err" + + if localDir: + localOutputFile = f"{localDir}/{batchSystemJobID}.out" + localErrorFile = f"{localDir}/{batchSystemJobID}.err" + else: + localOutputFile = "Memory" + localErrorFile = "Memory" + + resultStdout = self._get(connection, outputFile, local=localOutputFile, preserveMode=False) + if not resultStdout["OK"]: + return resultStdout + + resultStderr = self._get(connection, errorFile, local=localErrorFile, preserveMode=False) + if not resultStderr["OK"]: + return resultStderr + + if localDir: + output = localOutputFile + error = localErrorFile else: - output = f"{self.batchOutput}/{batchSystemJobID}.out" - error = f"{self.batchError}/{batchSystemJobID}.err" + output = resultStdout["Value"][1] + error = resultStderr["Value"][1] - return S_OK((batchSystemJobID, host, output, error)) + return S_OK((output, error))