|
30 | 30 |
|
31 | 31 | logger = logging.getLogger(__name__) |
32 | 32 |
|
33 | | -_POSTGRES_ERROR_TO_ERROR_MAP: Final[ |
34 | | - dict[tuple[str, ...], tuple[type[DirectorError], tuple[str, ...]]] |
| 33 | +_POSTGRES_FK_COLUMN_TO_ERROR_MAP: Final[ |
| 34 | + dict[sa.Column, tuple[type[DirectorError], tuple[str, ...]]] |
35 | 35 | ] = { |
36 | | - ("users", "user_id"): (UserNotFoundError, ("users", "user_id")), |
37 | | - ("projects", "project_uuid"): ( |
| 36 | + comp_runs.c.user_id: (UserNotFoundError, ("users", "user_id")), |
| 37 | + comp_runs.c.project_uuid: ( |
38 | 38 | ProjectNotFoundError, |
39 | 39 | ("projects", "project_id"), |
40 | 40 | ), |
41 | | - ("clusters", "cluster_id"): ( |
| 41 | + comp_runs.c.cluster_id: ( |
42 | 42 | ClusterNotFoundError, |
43 | 43 | ("clusters", "cluster_id"), |
44 | 44 | ), |
45 | 45 | } |
| 46 | +_DEFAULT_FK_CONSTRAINT_TO_ERROR: Final[tuple[type[DirectorError], tuple]] = ( |
| 47 | + DirectorError, |
| 48 | + (), |
| 49 | +) |
46 | 50 |
|
47 | 51 |
|
48 | 52 | class CompRunsRepository(BaseRepository): |
@@ -186,10 +190,13 @@ async def create( |
186 | 190 | row = await result.first() |
187 | 191 | return CompRunsAtDB.model_validate(row) |
188 | 192 | except ForeignKeyViolation as exc: |
189 | | - message = exc.args[0] |
190 | | - |
191 | | - for pg_keys, (exc_type, exc_keys) in _POSTGRES_ERROR_TO_ERROR_MAP.items(): |
192 | | - if all(k in message for k in pg_keys): |
| 193 | + assert exc.diag.constraint_name # nosec # noqa: PT017 |
| 194 | + for foreign_key in comp_runs.foreign_keys: |
| 195 | + if exc.diag.constraint_name == foreign_key.name: |
| 196 | + assert foreign_key.parent is not None # nosec |
| 197 | + exc_type, exc_keys = _POSTGRES_FK_COLUMN_TO_ERROR_MAP[ |
| 198 | + foreign_key.parent |
| 199 | + ] |
193 | 200 | raise exc_type( |
194 | 201 | **{f"{k}": locals().get(k) for k in exc_keys} |
195 | 202 | ) from exc |
|
0 commit comments