|
14 | 14 |
|
# Path of the JSON file that records the last run.
# NOTE(review): os.path.expanduser() only rewrites a leading "~"; this literal
# has none, so the value stays the relative path "last_run.json" and will be
# resolved against the current working directory when opened. Confirm whether
# "~/last_run.json" (or an absolute path) was intended.
STATE_FILE = os.path.expanduser("last_run.json")
16 | 16 |
|
# Numeric job-state code -> symbolic state name, used to make the "state"
# column of fetched job rows human readable. A name's code is its position in
# the tuple below (0 = PENDING ... 11 = OUT_OF_MEMORY), so the order matters.
# Presumably mirrors the base values of Slurm's job_states enum -- verify
# against the Slurm version in use before adding or reordering entries.
JOB_STATE_MAP = dict(
    enumerate(
        (
            "PENDING",
            "RUNNING",
            "SUSPENDED",
            "COMPLETED",
            "CANCELLED",
            "FAILED",
            "TIMEOUT",
            "NODE_FAIL",
            "PREEMPTED",
            "BOOT_FAIL",
            "DEADLINE",
            "OUT_OF_MEMORY",
        )
    )
)
| 31 | + |
17 | 32 |
|
18 | 33 | def _read_last_run(): |
19 | 34 | """Return the last processed end date from the state file.""" |
@@ -58,6 +73,7 @@ def __init__( |
58 | 73 | self.password = password or os.environ.get("SLURMDB_PASS") or cfg.get("password", "") |
59 | 74 | self.database = database or os.environ.get("SLURMDB_DB") or cfg.get("db", "slurm_acct_db") |
60 | 75 | self._conn = None |
| 76 | + self._tres_map = None |
61 | 77 | self._config_file = conf_path |
62 | 78 | self._slurm_conf = slurm_conf or os.environ.get("SLURM_CONF", "/etc/slurm/slurm.conf") |
63 | 79 | self.cluster = ( |
@@ -223,6 +239,46 @@ def _parse_tres(self, tres_str, key): |
223 | 239 | return 0 |
224 | 240 | return 0 |
225 | 241 |
|
| 242 | + def _get_tres_map(self): |
| 243 | + if self._tres_map is None: |
| 244 | + self.connect() |
| 245 | + self._tres_map = {} |
| 246 | + with self._conn.cursor() as cur: |
| 247 | + cur.execute("SELECT id, type, name FROM tres") |
| 248 | + for row in cur.fetchall(): |
| 249 | + t_type = row.get("type") |
| 250 | + t_name = row.get("name") |
| 251 | + if t_type == "gres": |
| 252 | + name = f"{t_type}/{t_name}" if t_name else t_type |
| 253 | + elif t_name: |
| 254 | + name = f"{t_type}/{t_name}" |
| 255 | + else: |
| 256 | + name = t_type |
| 257 | + self._tres_map[row["id"]] = name |
| 258 | + return self._tres_map |
| 259 | + |
| 260 | + def _tres_to_str(self, tres_str): |
| 261 | + if not tres_str: |
| 262 | + return "" |
| 263 | + tmap = self._get_tres_map() |
| 264 | + parts = [] |
| 265 | + for part in str(tres_str).split(','): |
| 266 | + if '=' not in part: |
| 267 | + continue |
| 268 | + key, val = part.split('=', 1) |
| 269 | + try: |
| 270 | + name = tmap.get(int(key), key) |
| 271 | + except ValueError: |
| 272 | + name = key |
| 273 | + parts.append(f"{name}={val}") |
| 274 | + return ','.join(parts) |
| 275 | + |
def _state_to_str(self, state):
    """Translate a numeric job state into its name from ``JOB_STATE_MAP``.

    Args:
        state: Anything ``int()`` accepts.

    Returns:
        The mapped name when ``int(state)`` is a key of ``JOB_STATE_MAP``;
        otherwise ``state`` itself, unchanged. That fallback covers both
        values that cannot be converted and codes the map does not list.
    """
    try:
        code = int(state)
    except (TypeError, ValueError):
        # Not a number at all (e.g. None or an already-symbolic string).
        return state
    return JOB_STATE_MAP.get(code, state)
| 281 | + |
226 | 282 | def fetch_usage_records(self, start_time, end_time): |
227 | 283 | """Fetch raw job records from SlurmDBD.""" |
228 | 284 | start_time = self._validate_time(start_time, "start_time") |
@@ -254,7 +310,12 @@ def fetch_usage_records(self, start_time, end_time): |
254 | 310 | f"WHERE j.time_start >= %s AND j.time_end <= %s" |
255 | 311 | ) |
256 | 312 | cur.execute(query, (start_time, end_time)) |
257 | | - return cur.fetchall() |
| 313 | + rows = cur.fetchall() |
| 314 | + for row in rows: |
| 315 | + row["tres_req"] = self._tres_to_str(row.get("tres_req")) |
| 316 | + row["tres_alloc"] = self._tres_to_str(row.get("tres_alloc")) |
| 317 | + row["state"] = self._state_to_str(row.get("state")) |
| 318 | + return rows |
258 | 319 |
|
259 | 320 | def aggregate_usage(self, start_time, end_time): |
260 | 321 | """Aggregate usage metrics by account and time period.""" |
|
0 commit comments