Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyalma/fileReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import os
from io import StringIO
import yaml
from .pdfreader import read_pdf_to_dataframe
from .anndatareader import read_adata
class FileReader:
Expand Down Expand Up @@ -82,3 +83,10 @@ def isfile(self, path):

def get_file_size(self, path):
pass

def _load_yaml_file(custom_yaml):
if not custom_yaml:
print("No yaml was provided")
return {}#empty dict
with open(custom_yaml, "r") as file:
return yaml.safe_load(file)
33 changes: 32 additions & 1 deletion pyalma/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logging.basicConfig(level=logging.DEBUG)

class SshClient(FileReader):
def __init__(self, server="alma.icr.ac.uk", username=None, password=None, sftp="alma-app.icr.ac.uk"):
def __init__(self, server="alma.icr.ac.uk", username=None, password=None, sftp="alma-app.icr.ac.uk", custom_yaml=None):
"""Initializes the SSH connection instance.
:param username: SSH username for alma
:param password: SSH password for alma
Expand All @@ -26,7 +26,37 @@ def __init__(self, server="alma.icr.ac.uk", username=None, password=None, sftp="
self.password = password.strip() if password else None
self.filter_file = os.path.join(os.path.dirname(__file__), "config", "messages.yaml")
self.filtered_patterns = self._load_filtered_patterns()
self.config = self._load_yaml_file(custom_yaml)
self._connect()
self.groups = self.retrieve_user_groups()

def get_user_paths(self, group, locations=["scratch", "rds"]):
if not self.config or "group_paths" not in self.config:
raise ValueError("Missing config: self.config must be set before calling get_user_paths.")

group_config = self.config["group_paths"].get(group, self.config["group_paths"].get("default"))

paths = {}
for loc in locations:
if loc in group_config:
formatted_path = group_config[loc].format(username=self.username, group=group)
paths[loc] = formatted_path
print(f"{loc}: {formatted_path}")

return paths

# to be used by lib users
def get_user_groups(self):
return self.groups

def retrieve_user_groups(self):
cmd_usr = "sacctmgr list association user=$USER format=Account -P | tail -n +2"
results = self.run_cmd(cmd_usr)

if results["err"] is not None:
raise Exception(f"❌ [retrieve_user_groups]: Couldn't retrieve user groups")

return results["output"].strip().split("\n")

def _connect(self):
# should be called only at initialisation
Expand All @@ -38,6 +68,7 @@ def _connect(self):
#one single connection
self.ssh_client.connect(self.sftp, username=self.username, password=self.password, timeout=30)
self.sftp_client = self.ssh_client.open_sftp()

except paramiko.AuthenticationException:
raise ConnectionError(f"❌ [_connect]: Authentication failed for {self.username}@{self.server}. Please check your credentials.")

Expand Down