|
1 | 1 | import datetime |
2 | 2 | import logging |
3 | | -from typing import Any |
| 3 | +from typing import Any, Final |
4 | 4 |
|
5 | 5 | import arrow |
6 | 6 | import sqlalchemy as sa |
|
20 | 20 | ClusterNotFoundError, |
21 | 21 | ComputationalRunNotFoundError, |
22 | 22 | DirectorError, |
23 | | - ProjectNotFoundError, |
24 | 23 | UserNotFoundError, |
25 | 24 | ) |
26 | 25 | from ....models.comp_runs import CompRunsAtDB, RunMetadataDict |
|
30 | 29 |
|
31 | 30 | logger = logging.getLogger(__name__) |
32 | 31 |
|
| 32 | +_POSTGRES_ERROR_TO_ERROR_MAP: Final[ |
| 33 | + dict[tuple[str, ...], tuple[type[DirectorError], tuple[str, ...]]] |
| 34 | +] = { |
| 35 | + ("users", "user_id"): (UserNotFoundError, ("users", "user_id")), |
| 36 | + ("projects", "project_uuid"): ( |
| 37 | + UserNotFoundError, |
| 38 | + ("projects", "project_id"), |
| 39 | + ), |
| 40 | + ("clusters", "cluster_id"): ( |
| 41 | + ClusterNotFoundError, |
| 42 | + ("clusters", "cluster_id"), |
| 43 | + ), |
| 44 | +} |
| 45 | + |
33 | 46 |
|
34 | 47 | class CompRunsRepository(BaseRepository): |
35 | 48 | async def get( |
@@ -173,15 +186,13 @@ async def create( |
173 | 186 | return CompRunsAtDB.model_validate(row) |
174 | 187 | except ForeignKeyViolation as exc: |
175 | 188 | message = exc.args[0] |
176 | | - match message: |
177 | | - case s if "users" in s and "user_id" in s: |
178 | | - raise UserNotFoundError(user_id=user_id) from exc |
179 | | - case s if "projects" in s and "project_uuid" in s: |
180 | | - raise ProjectNotFoundError(project_id=project_id) from exc |
181 | | - case s if "clusters" in s and "cluster_id" in s: |
182 | | - raise ClusterNotFoundError(cluster_id=cluster_id) from exc |
183 | | - case _: |
184 | | - raise DirectorError from exc |
| 189 | + |
| 190 | + for pg_keys, (exc_type, exc_keys) in _POSTGRES_ERROR_TO_ERROR_MAP.items(): |
| 191 | + if all(k in message for k in pg_keys): |
| 192 | + raise exc_type( |
| 193 | + **{f"{k}": locals().get(k) for k in exc_keys} |
| 194 | + ) from exc |
| 195 | + raise DirectorError from exc |
185 | 196 |
|
186 | 197 | async def update( |
187 | 198 | self, user_id: UserID, project_id: ProjectID, iteration: PositiveInt, **values |
|
0 commit comments