|
1 | 1 | from flask.views import MethodView |
2 | 2 | from flask_jwt_extended import jwt_required |
3 | 3 | from flask_smorest import Blueprint, abort |
4 | | -from sqlalchemy import exists |
| 4 | +from psycopg2.errors import UniqueViolation |
| 5 | +from sqlalchemy import UniqueConstraint, exists |
| 6 | +from sqlalchemy.exc import IntegrityError |
5 | 7 |
|
6 | 8 | from app import db |
7 | 9 | from app.models import ( |
|
27 | 29 | class CategoryCollection(MethodView): |
28 | 30 | init_every_request = False |
29 | 31 |
|
| 32 | + @staticmethod |
| 33 | + def _get_name_unique_constraint(): |
| 34 | + name_col = Category.__table__.c.name |
| 35 | + return next( |
| 36 | + con |
| 37 | + for con in Category.__table__.constraints |
| 38 | + if isinstance(con, UniqueConstraint) |
| 39 | + and len(con.columns) == 1 |
| 40 | + and con.columns.contains_column(name_col) |
| 41 | + ) |
| 42 | + |
| 43 | + _NAME_UNIQUE_CONSTRAINT = _get_name_unique_constraint() |
| 44 | + |
30 | 45 | @bp.response(200, CategoriesOut) |
31 | 46 | def get(self): |
32 | 47 | """ |
@@ -88,8 +103,18 @@ def post(self, data): |
88 | 103 |
|
89 | 104 | category.subcategories = subcategories |
90 | 105 |
|
91 | | - db.session.add(category) |
92 | | - db.session.commit() |
| 106 | + try: |
| 107 | + db.session.add(category) |
| 108 | + db.session.commit() |
| 109 | + except IntegrityError as ie: |
| 110 | + db.session.rollback() |
| 111 | + if ( |
| 112 | + isinstance(ie.orig, UniqueViolation) |
| 113 | + and ie.orig.diag.constraint_name |
| 114 | + == CategoryCollection._NAME_UNIQUE_CONSTRAINT.name |
| 115 | + ): |
| 116 | + abort(409, message="Category with this name already exists") |
| 117 | + raise ie |
93 | 118 |
|
94 | 119 | return category |
95 | 120 |
|
@@ -176,7 +201,18 @@ def put(self, data, id): |
176 | 201 |
|
177 | 202 | category.subcategories.extend(subcategories) |
178 | 203 |
|
179 | | - db.session.commit() |
| 204 | + try: |
| 205 | + db.session.commit() |
| 206 | + except IntegrityError as ie: |
| 207 | + db.session.rollback() |
| 208 | + if ( |
| 209 | + isinstance(ie.orig, UniqueViolation) |
| 210 | + and ie.orig.diag.constraint_name |
| 211 | + == category_subcategory.primary_key.name |
| 212 | + ): |
| 213 | + abort(409, message="Category and subcategory already linked") |
| 214 | + raise ie |
| 215 | + |
180 | 216 | return category |
181 | 217 |
|
182 | 218 | @jwt_required() |
|
0 commit comments