diff --git a/src/slurmdb.py b/src/slurmdb.py index a7dfa9f..cec95d7 100644 --- a/src/slurmdb.py +++ b/src/slurmdb.py @@ -243,18 +243,33 @@ def _get_tres_map(self): if self._tres_map is None: self.connect() self._tres_map = {} - with self._conn.cursor() as cur: - cur.execute("SELECT id, type, name FROM tres") - for row in cur.fetchall(): - t_type = row.get("type") - t_name = row.get("name") - if t_type == "gres": - name = f"{t_type}/{t_name}" if t_name else t_type - elif t_name: - name = f"{t_type}/{t_name}" - else: - name = t_type - self._tres_map[row["id"]] = name + for table in ("tres", "tres_table"): + try: + with self._conn.cursor() as cur: + cur.execute(f"SELECT id, type, name FROM {table}") + for row in cur.fetchall(): + t_type = row.get("type") + t_name = row.get("name") + if t_type == "gres": + name = f"{t_type}/{t_name}" if t_name else t_type + elif t_name: + name = f"{t_type}/{t_name}" + else: + name = t_type + self._tres_map[row["id"]] = name + break + except Exception as e: + # Older Slurm databases might use a different table name + # or not support TRES at all. If a table is missing, + # fall back to the next option or return an empty map. + if not ( + pymysql + and isinstance(e, pymysql.err.ProgrammingError) + and e.args + and e.args[0] == 1146 + ): + raise + # If both queries failed we leave the map empty. return self._tres_map def _tres_to_str(self, tres_str): diff --git a/test/unit/slurmdb_validation.test.py b/test/unit/slurmdb_validation.test.py index 36192d4..351d58e 100644 --- a/test/unit/slurmdb_validation.test.py +++ b/test/unit/slurmdb_validation.test.py @@ -1,5 +1,6 @@ import unittest import json +import pymysql from slurmdb import SlurmDB from slurm_schema import extract_schema, extract_schema_from_dump @@ -132,6 +133,60 @@ def fake_connect(): self.assertTrue(conn.closed) self.assertIsNone(db._conn) + def test_get_tres_map_handles_missing_table(self): + db = SlurmDB() + db.connect = lambda: None + + class FakeCursor: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + def execute(self, query): + raise pymysql.err.ProgrammingError(1146, "Table 'slurm_acct_db.tres' doesn't exist") + + def fetchall(self): + return [] + + class FakeConn: + def cursor(self): + return FakeCursor() + + db._conn = FakeConn() + tmap = db._get_tres_map() + self.assertEqual(tmap, {}) + + def test_get_tres_map_uses_tres_table_when_missing_tres(self): + db = SlurmDB() + db.connect = lambda: None + + class FakeCursor: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + def execute(self, query): + self.query = query + if "tres_table" not in query: + raise pymysql.err.ProgrammingError(1146, "Table 'slurm_acct_db.tres' doesn't exist") + + def fetchall(self): + if "tres_table" in self.query: + return [{"id": 1, "type": "cpu", "name": ""}] + return [] + + class FakeConn: + def cursor(self): + return FakeCursor() + + db._conn = FakeConn() + tmap = db._get_tres_map() + self.assertEqual(tmap, {1: "cpu"}) + def test_extract_schema_with_context_manager_closes_connection(self): class FakeCursor: def __init__(self, dbname):