Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/slurmcostmanager.css
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ nav button:hover {

/* Horizontally scrollable wrapper for wide tables.
   NOTE(review): `width` and `margin` each appear twice below. This looks like
   the before/after pair of a rendered diff (old: 350px / 1em, new: 100% /
   1em 0) - confirm against the real stylesheet. If both are kept, the later
   declaration of each property is the one that takes effect. */
.table-container {
overflow-x: auto;
width: 350px;
margin: 1em;
width: 100%;
margin: 1em 0;
}

.summary-table th {
Expand Down
63 changes: 62 additions & 1 deletion src/slurmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@

# NOTE(review): the path contains no "~", so expanduser() returns it unchanged
# and the file is resolved relative to the current working directory.
STATE_FILE = os.path.expanduser("last_run.json")

# Numeric job state code -> state name. The codes are consecutive integers
# starting at 0, so each name's position in the tuple is its code.
JOB_STATE_MAP = dict(
    enumerate(
        (
            "PENDING",
            "RUNNING",
            "SUSPENDED",
            "COMPLETED",
            "CANCELLED",
            "FAILED",
            "TIMEOUT",
            "NODE_FAIL",
            "PREEMPTED",
            "BOOT_FAIL",
            "DEADLINE",
            "OUT_OF_MEMORY",
        )
    )
)


def _read_last_run():
"""Return the last processed end date from the state file."""
Expand Down Expand Up @@ -58,6 +73,7 @@ def __init__(
self.password = password or os.environ.get("SLURMDB_PASS") or cfg.get("password", "")
self.database = database or os.environ.get("SLURMDB_DB") or cfg.get("db", "slurm_acct_db")
self._conn = None
self._tres_map = None
self._config_file = conf_path
self._slurm_conf = slurm_conf or os.environ.get("SLURM_CONF", "/etc/slurm/slurm.conf")
self.cluster = (
Expand Down Expand Up @@ -223,6 +239,46 @@ def _parse_tres(self, tres_str, key):
return 0
return 0

def _get_tres_map(self):
if self._tres_map is None:
self.connect()
self._tres_map = {}
with self._conn.cursor() as cur:
cur.execute("SELECT id, type, name FROM tres")
for row in cur.fetchall():
t_type = row.get("type")
t_name = row.get("name")
if t_type == "gres":
name = f"{t_type}/{t_name}" if t_name else t_type
elif t_name:
name = f"{t_type}/{t_name}"
else:
name = t_type
self._tres_map[row["id"]] = name
return self._tres_map

def _tres_to_str(self, tres_str):
if not tres_str:
return ""
tmap = self._get_tres_map()
parts = []
for part in str(tres_str).split(','):
if '=' not in part:
continue
key, val = part.split('=', 1)
try:
name = tmap.get(int(key), key)
except ValueError:
name = key
parts.append(f"{name}={val}")
return ','.join(parts)

def _state_to_str(self, state):
    """Translate a numeric job state into its name from ``JOB_STATE_MAP``.

    The input is returned unchanged when it cannot be converted to an
    integer or when the resulting code has no entry in the map.
    """
    try:
        code = int(state)
    except (TypeError, ValueError):
        return state
    return JOB_STATE_MAP.get(code, state)

def fetch_usage_records(self, start_time, end_time):
"""Fetch raw job records from SlurmDBD."""
start_time = self._validate_time(start_time, "start_time")
Expand Down Expand Up @@ -254,7 +310,12 @@ def fetch_usage_records(self, start_time, end_time):
f"WHERE j.time_start >= %s AND j.time_end <= %s"
)
cur.execute(query, (start_time, end_time))
return cur.fetchall()
rows = cur.fetchall()
for row in rows:
row["tres_req"] = self._tres_to_str(row.get("tres_req"))
row["tres_alloc"] = self._tres_to_str(row.get("tres_alloc"))
row["state"] = self._state_to_str(row.get("state"))
return rows

def aggregate_usage(self, start_time, end_time):
"""Aggregate usage metrics by account and time period."""
Expand Down
Loading