diff --git a/app/migrated_routes/category.py b/app/migrated_routes/category.py index 0ee3024..b001f9d 100644 --- a/app/migrated_routes/category.py +++ b/app/migrated_routes/category.py @@ -1,7 +1,9 @@ from flask.views import MethodView from flask_jwt_extended import jwt_required from flask_smorest import Blueprint, abort -from sqlalchemy import exists +from psycopg2.errors import UniqueViolation +from sqlalchemy import UniqueConstraint, exists +from sqlalchemy.exc import IntegrityError from app import db from app.models import ( @@ -9,7 +11,6 @@ Product, Subcategory, category_subcategory, - subcategory_product, ) from app.schemas import ( CategoriesOut, @@ -27,6 +28,19 @@ class CategoryCollection(MethodView): init_every_request = False + @staticmethod + def _get_name_unique_constraint(): + name_col = Category.__table__.c.name + return next( + con + for con in Category.__table__.constraints + if isinstance(con, UniqueConstraint) + and len(con.columns) == 1 + and con.columns.contains_column(name_col) + ) + + _NAME_UNIQUE_CONSTRAINT = _get_name_unique_constraint() + @bp.response(200, CategoriesOut) def get(self): """ @@ -88,8 +102,18 @@ def post(self, data): category.subcategories = subcategories - db.session.add(category) - db.session.commit() + try: + db.session.add(category) + db.session.commit() + except IntegrityError as ie: + db.session.rollback() + if ( + isinstance(ie.orig, UniqueViolation) + and ie.orig.diag.constraint_name + == CategoryCollection._NAME_UNIQUE_CONSTRAINT.name + ): + abort(409, message="Category with this name already exists") + raise return category @@ -176,7 +200,18 @@ def put(self, data, id): category.subcategories.extend(subcategories) - db.session.commit() + try: + db.session.commit() + except IntegrityError as ie: + db.session.rollback() + if ( + isinstance(ie.orig, UniqueViolation) + and ie.orig.diag.constraint_name + == category_subcategory.primary_key.name + ): + abort(409, message="Category and subcategory already linked") + raise + return category @jwt_required() @@ -277,14 +312,9 @@ def get(self, id, page): abort(404) products = ( - Product.query.join(subcategory_product) - .join( - category_subcategory, - onclause=subcategory_product.c.subcategory_id - == category_subcategory.c.subcategory_id, + Product.query.filter( + Product.subcategories.any(Subcategory.categories.any(id=id)) ) - .filter(category_subcategory.c.category_id == id) - .distinct() .order_by(Product.id.asc()) .paginate(page=page, per_page=CategoryProducts._PER_PAGE, error_out=False) ) diff --git a/tests/test_relationships.py b/tests/test_relationships.py index 6b9bc02..610b3e4 100644 --- a/tests/test_relationships.py +++ b/tests/test_relationships.py @@ -1,4 +1,7 @@ +import sqlite3 + import pytest +from sqlalchemy.exc import IntegrityError from app.models import Category, Product, Subcategory @@ -96,6 +99,18 @@ def test_update_category_adds_subcategories(self, create_authenticated_headers, assert self._category_subcategory_ids(category["id"]) == sorted([subcategory1["id"], subcategory2["id"]]) + def test_update_category_adds_linked_subcategories(self, create_authenticated_headers, create_category, create_subcategory): + headers = create_authenticated_headers() + subcategory = create_subcategory("U_SC1", headers=headers).get_json() + category = create_category("U_Cat", subcategories=[subcategory["id"]], headers=headers).get_json() + + with pytest.raises(IntegrityError) as ie: + self.client.put(f"/categories/{category['id']}", json={"subcategories": [subcategory["id"]]}, headers=headers) + + assert isinstance(ie.value.orig, sqlite3.IntegrityError) + assert "UNIQUE constraint failed" in str(ie.value.orig) + assert self._category_subcategory_ids(category["id"]) == [subcategory["id"]] + def test_update_subcategory_adds_categories_and_products(self, create_authenticated_headers, create_category, create_product, create_subcategory): category1 = create_category("UC1").get_json() category2 = create_category("UC2").get_json()