diff --git a/pyalma/fileReader.py b/pyalma/fileReader.py index 4e5a522..27b305a 100644 --- a/pyalma/fileReader.py +++ b/pyalma/fileReader.py @@ -2,6 +2,7 @@ import pandas as pd import os from io import StringIO +import yaml from .pdfreader import read_pdf_to_dataframe from .anndatareader import read_adata class FileReader: @@ -82,3 +83,10 @@ def isfile(self, path): def get_file_size(self, path): pass + + def _load_yaml_file(custom_yaml): + if not custom_yaml: + print("No yaml was provided") + return {}#empty dict + with open(custom_yaml, "r") as file: + return yaml.safe_load(file) \ No newline at end of file diff --git a/pyalma/ssh.py b/pyalma/ssh.py index 2fa8fa4..b27c3d4 100644 --- a/pyalma/ssh.py +++ b/pyalma/ssh.py @@ -12,7 +12,7 @@ logging.basicConfig(level=logging.DEBUG) class SshClient(FileReader): - def __init__(self, server="alma.icr.ac.uk", username=None, password=None, sftp="alma-app.icr.ac.uk"): + def __init__(self, server="alma.icr.ac.uk", username=None, password=None, sftp="alma-app.icr.ac.uk", custom_yaml=None): """Initializes the SSH connection instance. :param username: SSH username for alma :param password: SSH password for alma @@ -26,7 +26,37 @@ def __init__(self, server="alma.icr.ac.uk", username=None, password=None, sftp=" self.password = password.strip() if password else None self.filter_file = os.path.join(os.path.dirname(__file__), "config", "messages.yaml") self.filtered_patterns = self._load_filtered_patterns() + self.config = self._load_yaml_file(custom_yaml) self._connect() + self.groups = self.retrieve_user_groups() + + def get_user_paths(self, group, locations=["scratch", "rds"]): + if not self.config or "group_paths" not in self.config: + raise ValueError("Missing config: self.config must be set before calling get_user_paths.") + + group_config = self.config["group_paths"].get(group, self.config["group_paths"].get("default")) + + paths = {} + for loc in locations: + if loc in group_config: + formatted_path = group_config[loc].format(username=self.username, group=group) + paths[loc] = formatted_path + print(f"{loc}: {formatted_path}") + + return paths + + # to be used by lib users + def get_user_groups(self): + return self.groups + + def retrieve_user_groups(self): + cmd_usr = "sacctmgr list association user=$USER format=Account -P | tail -n +2" + results = self.run_cmd(cmd_usr) + + if results["err"] is not None: + raise Exception(f"❌ [retrieve_user_groups]: Couldn't retrieve user groups") + + return results["output"].strip().split("\n") def _connect(self): # should be called only at initialisation @@ -38,6 +68,7 @@ def _connect(self): #one single connection self.ssh_client.connect(self.sftp, username=self.username, password=self.password, timeout=30) self.sftp_client = self.ssh_client.open_sftp() + except paramiko.AuthenticationException: raise ConnectionError(f"❌ [_connect]: Authentication failed for {self.username}@{self.server}. Please check your credentials.")