Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions src/slurmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,33 @@ def _get_tres_map(self):
if self._tres_map is None:
self.connect()
self._tres_map = {}
with self._conn.cursor() as cur:
cur.execute("SELECT id, type, name FROM tres")
for row in cur.fetchall():
t_type = row.get("type")
t_name = row.get("name")
if t_type == "gres":
name = f"{t_type}/{t_name}" if t_name else t_type
elif t_name:
name = f"{t_type}/{t_name}"
else:
name = t_type
self._tres_map[row["id"]] = name
for table in ("tres", "tres_table"):
try:
with self._conn.cursor() as cur:
cur.execute(f"SELECT id, type, name FROM {table}")
for row in cur.fetchall():
t_type = row.get("type")
t_name = row.get("name")
if t_type == "gres":
name = f"{t_type}/{t_name}" if t_name else t_type
elif t_name:
name = f"{t_type}/{t_name}"
else:
name = t_type
self._tres_map[row["id"]] = name
break
except Exception as e:
# Older Slurm databases might use a different table name
# or not support TRES at all. If a table is missing,
# fall back to the next option or return an empty map.
if not (
pymysql
and isinstance(e, pymysql.err.ProgrammingError)
and e.args
and e.args[0] == 1146
):
raise
# If both queries failed we leave the map empty.
return self._tres_map

def _tres_to_str(self, tres_str):
Expand Down
55 changes: 55 additions & 0 deletions test/unit/slurmdb_validation.test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import json
import pymysql
from slurmdb import SlurmDB
from slurm_schema import extract_schema, extract_schema_from_dump

Expand Down Expand Up @@ -132,6 +133,60 @@ def fake_connect():
self.assertTrue(conn.closed)
self.assertIsNone(db._conn)

def test_get_tres_map_handles_missing_table(self):
db = SlurmDB()
db.connect = lambda: None

class FakeCursor:
def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
pass

def execute(self, query):
raise pymysql.err.ProgrammingError(1146, "Table 'slurm_acct_db.tres' doesn't exist")

def fetchall(self):
return []

class FakeConn:
def cursor(self):
return FakeCursor()

db._conn = FakeConn()
tmap = db._get_tres_map()
self.assertEqual(tmap, {})

def test_get_tres_map_uses_tres_table_when_missing_tres(self):
db = SlurmDB()
db.connect = lambda: None

class FakeCursor:
def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
pass

def execute(self, query):
self.query = query
if "tres_table" not in query:
raise pymysql.err.ProgrammingError(1146, "Table 'slurm_acct_db.tres' doesn't exist")

def fetchall(self):
if "tres_table" in self.query:
return [{"id": 1, "type": "cpu", "name": ""}]
return []

class FakeConn:
def cursor(self):
return FakeCursor()

db._conn = FakeConn()
tmap = db._get_tres_map()
self.assertEqual(tmap, {1: "cpu"})

def test_extract_schema_with_context_manager_closes_connection(self):
class FakeCursor:
def __init__(self, dbname):
Expand Down
Loading