Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions app/migrated_routes/category.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from flask.views import MethodView
from flask_jwt_extended import jwt_required
from flask_smorest import Blueprint, abort
from sqlalchemy import exists
from psycopg2.errors import UniqueViolation
from sqlalchemy import UniqueConstraint, exists
from sqlalchemy.exc import IntegrityError

from app import db
from app.models import (
Category,
Product,
Subcategory,
category_subcategory,
subcategory_product,
)
from app.schemas import (
CategoriesOut,
Expand All @@ -27,6 +28,19 @@
class CategoryCollection(MethodView):
init_every_request = False

@staticmethod
def _get_name_unique_constraint():
name_col = Category.__table__.c.name
return next(
con
for con in Category.__table__.constraints
if isinstance(con, UniqueConstraint)
and len(con.columns) == 1
and con.columns.contains_column(name_col)
)

_NAME_UNIQUE_CONSTRAINT = _get_name_unique_constraint()

@bp.response(200, CategoriesOut)
def get(self):
"""
Expand Down Expand Up @@ -88,8 +102,18 @@ def post(self, data):

category.subcategories = subcategories

db.session.add(category)
db.session.commit()
try:
db.session.add(category)
db.session.commit()
except IntegrityError as ie:
db.session.rollback()
if (
isinstance(ie.orig, UniqueViolation)
and ie.orig.diag.constraint_name
== CategoryCollection._NAME_UNIQUE_CONSTRAINT.name
):
abort(409, message="Category with this name already exists")
raise

return category

Expand Down Expand Up @@ -176,7 +200,18 @@ def put(self, data, id):

category.subcategories.extend(subcategories)

db.session.commit()
try:
db.session.commit()
except IntegrityError as ie:
db.session.rollback()
if (
isinstance(ie.orig, UniqueViolation)
and ie.orig.diag.constraint_name
== category_subcategory.primary_key.name
):
abort(409, message="Category and subcategory already linked")
raise

return category

@jwt_required()
Expand Down Expand Up @@ -277,14 +312,9 @@ def get(self, id, page):
abort(404)

products = (
Product.query.join(subcategory_product)
.join(
category_subcategory,
onclause=subcategory_product.c.subcategory_id
== category_subcategory.c.subcategory_id,
Product.query.filter(
Product.subcategories.any(Subcategory.categories.any(id=id))
)
.filter(category_subcategory.c.category_id == id)
.distinct()
.order_by(Product.id.asc())
.paginate(page=page, per_page=CategoryProducts._PER_PAGE, error_out=False)
)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_relationships.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import sqlite3

import pytest
from sqlalchemy.exc import IntegrityError

from app.models import Category, Product, Subcategory

Expand Down Expand Up @@ -96,6 +99,18 @@ def test_update_category_adds_subcategories(self, create_authenticated_headers,

assert self._category_subcategory_ids(category["id"]) == sorted([subcategory1["id"], subcategory2["id"]])

def test_update_category_adds_linked_subcategories(self, create_authenticated_headers, create_category, create_subcategory):
headers = create_authenticated_headers()
subcategory = create_subcategory("U_SC1", headers=headers).get_json()
category = create_category("U_Cat", subcategories=[subcategory["id"]], headers=headers).get_json()

with pytest.raises(IntegrityError) as ie:
self.client.put(f"/categories/{category['id']}", json={"subcategories": [subcategory["id"]]}, headers=headers)

assert isinstance(ie.value.orig, sqlite3.IntegrityError)
assert "UNIQUE constraint failed" in str(ie.value.orig)
assert self._category_subcategory_ids(category["id"]) == [subcategory["id"]]

def test_update_subcategory_adds_categories_and_products(self, create_authenticated_headers, create_category, create_product, create_subcategory):
category1 = create_category("UC1").get_json()
category2 = create_category("UC2").get_json()
Expand Down