diff --git a/promo_code/business/managers.py b/promo_code/business/managers.py index e8851b2..0ddc547 100644 --- a/promo_code/business/managers.py +++ b/promo_code/business/managers.py @@ -22,6 +22,33 @@ def create_company(self, email, name, password=None, **extra_fields): class PromoManager(django.db.models.Manager): + def get_queryset(self): + return super().get_queryset() + + def with_related(self): + return ( + self.select_related('company') + .prefetch_related('unique_codes') + .only( + 'id', + 'company', + 'description', + 'image_url', + 'target', + 'max_count', + 'active_from', + 'active_until', + 'mode', + 'promo_common', + 'created_at', + 'company__id', + 'company__name', + ) + ) + + def for_company(self, user): + return self.with_related().filter(company=user) + @django.db.transaction.atomic def create_promo( self, diff --git a/promo_code/business/pagination.py b/promo_code/business/pagination.py index 58604fb..68eb65a 100644 --- a/promo_code/business/pagination.py +++ b/promo_code/business/pagination.py @@ -1,4 +1,3 @@ -import rest_framework.exceptions import rest_framework.pagination import rest_framework.response diff --git a/promo_code/business/permissions.py b/promo_code/business/permissions.py index 99da2d3..ba4d8f4 100644 --- a/promo_code/business/permissions.py +++ b/promo_code/business/permissions.py @@ -5,4 +5,14 @@ class IsCompanyUser(rest_framework.permissions.BasePermission): def has_permission(self, request, view): + if not request.user or not request.user.is_authenticated: + return False + return isinstance(request.user, business.models.Company) + + +class IsPromoOwner(rest_framework.permissions.BasePermission): + message = 'The promo code does not belong to this company.' + + def has_object_permission(self, request, view, obj): + return getattr(obj, 'company_id', None) == request.user.id diff --git a/promo_code/business/views.py b/promo_code/business/views.py index 8c592df..e4cedd7 100644 --- a/promo_code/business/views.py +++ b/promo_code/business/views.py @@ -1,7 +1,6 @@ import re import django.db.models -import django.shortcuts import pycountry import rest_framework.exceptions import rest_framework.generics @@ -150,29 +149,7 @@ class CompanyPromoListView(rest_framework.generics.ListAPIView): pagination_class = business.pagination.CustomLimitOffsetPagination def get_queryset(self): - queryset = ( - business.models.Promo.objects.filter( - company=self.request.user, - ) - .select_related('company') - .prefetch_related('unique_codes') - .only( - 'id', - 'company', - 'description', - 'image_url', - 'target', - 'max_count', - 'active_from', - 'active_until', - 'mode', - 'promo_common', - 'created_at', - 'company__id', - 'company__name', - ) - ) - + queryset = business.models.Promo.objects.for_company(self.request.user) countries = [ country.strip() for group in self.request.query_params.getlist('country', []) @@ -298,113 +275,24 @@ def _validate_limit(self): ) -class CompanyPromoDetailView(rest_framework.views.APIView): +class CompanyPromoDetailView(rest_framework.generics.RetrieveUpdateAPIView): + """ + Retrieve (GET) and partially update (PATCH) detailed information + about a company’s promo. + """ + + http_method_names = ['get', 'patch', 'options', 'head'] + + serializer_class = business.serializers.PromoDetailSerializer + permission_classes = [ rest_framework.permissions.IsAuthenticated, business.permissions.IsCompanyUser, + business.permissions.IsPromoOwner, ] lookup_field = 'id' - def get_queryset(self): - user = self.request.user - return ( - business.models.Promo.objects.filter(company=user) - .select_related('company') - .prefetch_related('unique_codes') - .select_related('company') - .prefetch_related('unique_codes') - .only( - 'id', - 'company', - 'description', - 'image_url', - 'target', - 'max_count', - 'active_from', - 'active_until', - 'mode', - 'promo_common', - 'created_at', - 'company__id', - 'company__name', - ) - ) - - def get(self, request, id): - try: - promo = business.models.Promo.objects.get( - id=id, - ) - except business.models.Promo.DoesNotExist: - raise rest_framework.exceptions.NotFound( - 'Promo not found,', - ) - - if promo.company != request.user: - return rest_framework.response.Response( - { - 'status': 'error', - 'message': ( - 'The promo code does not belong to this company.' - ), - }, - status=rest_framework.status.HTTP_403_FORBIDDEN, - ) - - serializer = business.serializers.PromoDetailSerializer( - promo, - ) - - return rest_framework.response.Response( - serializer.data, - status=rest_framework.status.HTTP_200_OK, - ) - - def patch(self, request, id, *args, **kwargs): - try: - promo = business.models.Promo.objects.get( - id=id, - ) - except business.models.Promo.DoesNotExist: - return rest_framework.response.Response( - { - 'status': 'error', - 'message': 'Promo code not found.', - }, - status=rest_framework.status.HTTP_404_NOT_FOUND, - ) - - if promo.company != request.user: - return rest_framework.response.Response( - { - 'status': 'error', - 'message': ('Promo code does not belong to this company.'), - }, - status=rest_framework.status.HTTP_403_FORBIDDEN, - ) - - serializer = business.serializers.PromoDetailSerializer( - promo, - data=request.data, - partial=True, - context={ - 'request': request, - }, - ) - - if not serializer.is_valid(): - return rest_framework.response.Response( - { - 'status': 'error', - 'message': 'Request data error.', - }, - status=rest_framework.status.HTTP_400_BAD_REQUEST, - ) - - serializer.save() - - return rest_framework.response.Response( - serializer.data, - status=rest_framework.status.HTTP_200_OK, - ) + # Use an enriched base queryset without pre-filtering by company, + # so that ownership mismatches raise 403 Forbidden (not 404 Not Found). + queryset = business.models.Promo.objects.with_related()