|
1 | 1 | import unittest |
2 | 2 | import json |
3 | 3 | from slurmdb import SlurmDB |
4 | | -from slurm_schema import extract_schema |
| 4 | +from slurm_schema import extract_schema, extract_schema_from_dump |
5 | 5 |
|
6 | 6 | class SlurmDBValidationTests(unittest.TestCase): |
7 | 7 | def test_invalid_cluster_rejected(self): |
@@ -181,7 +181,14 @@ def fake_connect(): |
181 | 181 | def test_fetch_usage_records_uses_cpus_req_if_alloc_missing(self): |
182 | 182 | with open('test/example_slurm_schema_for_testing.json') as fh: |
183 | 183 | schema = json.load(fh) |
| 184 | + schema_sql = extract_schema_from_dump('test/example_slurmdb_for_testing.sql') |
184 | 185 | job_cols = schema.get('localcluster_job_table', []) |
| 186 | + job_cols_sql = schema_sql.get('localcluster_job_table', []) |
| 187 | + # ensure schema JSON and SQL dump agree on CPU columns |
| 188 | + self.assertIn('cpus_req', job_cols) |
| 189 | + self.assertIn('cpus_req', job_cols_sql) |
| 190 | + self.assertNotIn('cpus_alloc', job_cols) |
| 191 | + self.assertNotIn('cpus_alloc', job_cols_sql) |
185 | 192 |
|
186 | 193 | class FakeCursor: |
187 | 194 | def __init__(self): |
@@ -222,7 +229,60 @@ def cursor(self): |
222 | 229 | db.connect = lambda: None |
223 | 230 | db.fetch_usage_records(0, 0) |
224 | 231 | queries = db._conn.cursor_obj.queries |
225 | | - self.assertIn("j.cpus_req AS cpus_alloc", queries[1]) |
| 232 | + self.assertIn("j.cpus_req AS cpus_alloc", queries[-1]) |
| 233 | + |
| 234 | + def test_fetch_usage_records_uses_job_name_column(self): |
| 235 | + with open('test/example_slurm_schema_for_testing.json') as fh: |
| 236 | + schema = json.load(fh) |
| 237 | + schema_sql = extract_schema_from_dump('test/example_slurmdb_for_testing.sql') |
| 238 | + job_cols = schema.get('localcluster_job_table', []) |
| 239 | + job_cols_sql = schema_sql.get('localcluster_job_table', []) |
| 240 | + # ensure schema JSON and SQL dump agree on job name column |
| 241 | + self.assertIn('job_name', job_cols) |
| 242 | + self.assertIn('job_name', job_cols_sql) |
| 243 | + self.assertNotIn('name', job_cols) |
| 244 | + self.assertNotIn('name', job_cols_sql) |
| 245 | + |
| 246 | + class FakeCursor: |
| 247 | + def __init__(self): |
| 248 | + self.queries = [] |
| 249 | + |
| 250 | + def execute(self, query, params=None): |
| 251 | + self.queries.append(query) |
| 252 | + if query.lower().startswith("show columns"): |
| 253 | + column = params[0] if params else None |
| 254 | + if column in job_cols: |
| 255 | + self._fetchone = {'Field': column} |
| 256 | + else: |
| 257 | + self._fetchone = None |
| 258 | + else: |
| 259 | + self._fetchall = [] |
| 260 | + |
| 261 | + def fetchone(self): |
| 262 | + return getattr(self, "_fetchone", None) |
| 263 | + |
| 264 | + def fetchall(self): |
| 265 | + return getattr(self, "_fetchall", []) |
| 266 | + |
| 267 | + def __enter__(self): |
| 268 | + return self |
| 269 | + |
| 270 | + def __exit__(self, exc_type, exc, tb): |
| 271 | + pass |
| 272 | + |
| 273 | + class FakeConn: |
| 274 | + def __init__(self): |
| 275 | + self.cursor_obj = FakeCursor() |
| 276 | + |
| 277 | + def cursor(self): |
| 278 | + return self.cursor_obj |
| 279 | + |
| 280 | + db = SlurmDB(cluster="localcluster") |
| 281 | + db._conn = FakeConn() |
| 282 | + db.connect = lambda: None |
| 283 | + db.fetch_usage_records(0, 0) |
| 284 | + queries = db._conn.cursor_obj.queries |
| 285 | + self.assertIn("j.job_name AS job_name", queries[-1]) |
226 | 286 |
|
227 | 287 | if __name__ == '__main__': |
228 | 288 | unittest.main() |
0 commit comments