Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions promo_code/business/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,33 @@ def create_company(self, email, name, password=None, **extra_fields):


class PromoManager(django.db.models.Manager):
def get_queryset(self):
return super().get_queryset()

def with_related(self):
return (
self.select_related('company')
.prefetch_related('unique_codes')
.only(
'id',
'company',
'description',
'image_url',
'target',
'max_count',
'active_from',
'active_until',
'mode',
'promo_common',
'created_at',
'company__id',
'company__name',
)
)

def for_company(self, user):
return self.with_related().filter(company=user)

@django.db.transaction.atomic
def create_promo(
self,
Expand Down
1 change: 0 additions & 1 deletion promo_code/business/pagination.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import rest_framework.exceptions
import rest_framework.pagination
import rest_framework.response

Expand Down
10 changes: 10 additions & 0 deletions promo_code/business/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,14 @@

class IsCompanyUser(rest_framework.permissions.BasePermission):
def has_permission(self, request, view):
if not request.user or not request.user.is_authenticated:
return False

return isinstance(request.user, business.models.Company)


class IsPromoOwner(rest_framework.permissions.BasePermission):
message = 'The promo code does not belong to this company.'

def has_object_permission(self, request, view, obj):
return getattr(obj, 'company_id', None) == request.user.id
142 changes: 15 additions & 127 deletions promo_code/business/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re

import django.db.models
import django.shortcuts
import pycountry
import rest_framework.exceptions
import rest_framework.generics
Expand Down Expand Up @@ -150,29 +149,7 @@ class CompanyPromoListView(rest_framework.generics.ListAPIView):
pagination_class = business.pagination.CustomLimitOffsetPagination

def get_queryset(self):
queryset = (
business.models.Promo.objects.filter(
company=self.request.user,
)
.select_related('company')
.prefetch_related('unique_codes')
.only(
'id',
'company',
'description',
'image_url',
'target',
'max_count',
'active_from',
'active_until',
'mode',
'promo_common',
'created_at',
'company__id',
'company__name',
)
)

queryset = business.models.Promo.objects.for_company(self.request.user)
countries = [
country.strip()
for group in self.request.query_params.getlist('country', [])
Expand Down Expand Up @@ -298,113 +275,24 @@ def _validate_limit(self):
)


class CompanyPromoDetailView(rest_framework.views.APIView):
class CompanyPromoDetailView(rest_framework.generics.RetrieveUpdateAPIView):
"""
Retrieve (GET) and partially update (PATCH) detailed information
about a company’s promo.
"""

http_method_names = ['get', 'patch', 'options', 'head']

serializer_class = business.serializers.PromoDetailSerializer

permission_classes = [
rest_framework.permissions.IsAuthenticated,
business.permissions.IsCompanyUser,
business.permissions.IsPromoOwner,
]

lookup_field = 'id'

def get_queryset(self):
user = self.request.user
return (
business.models.Promo.objects.filter(company=user)
.select_related('company')
.prefetch_related('unique_codes')
.select_related('company')
.prefetch_related('unique_codes')
.only(
'id',
'company',
'description',
'image_url',
'target',
'max_count',
'active_from',
'active_until',
'mode',
'promo_common',
'created_at',
'company__id',
'company__name',
)
)

def get(self, request, id):
try:
promo = business.models.Promo.objects.get(
id=id,
)
except business.models.Promo.DoesNotExist:
raise rest_framework.exceptions.NotFound(
'Promo not found,',
)

if promo.company != request.user:
return rest_framework.response.Response(
{
'status': 'error',
'message': (
'The promo code does not belong to this company.'
),
},
status=rest_framework.status.HTTP_403_FORBIDDEN,
)

serializer = business.serializers.PromoDetailSerializer(
promo,
)

return rest_framework.response.Response(
serializer.data,
status=rest_framework.status.HTTP_200_OK,
)

def patch(self, request, id, *args, **kwargs):
try:
promo = business.models.Promo.objects.get(
id=id,
)
except business.models.Promo.DoesNotExist:
return rest_framework.response.Response(
{
'status': 'error',
'message': 'Promo code not found.',
},
status=rest_framework.status.HTTP_404_NOT_FOUND,
)

if promo.company != request.user:
return rest_framework.response.Response(
{
'status': 'error',
'message': ('Promo code does not belong to this company.'),
},
status=rest_framework.status.HTTP_403_FORBIDDEN,
)

serializer = business.serializers.PromoDetailSerializer(
promo,
data=request.data,
partial=True,
context={
'request': request,
},
)

if not serializer.is_valid():
return rest_framework.response.Response(
{
'status': 'error',
'message': 'Request data error.',
},
status=rest_framework.status.HTTP_400_BAD_REQUEST,
)

serializer.save()

return rest_framework.response.Response(
serializer.data,
status=rest_framework.status.HTTP_200_OK,
)
# Use an enriched base queryset without pre-filtering by company,
# so that ownership mismatches raise 403 Forbidden (not 404 Not Found).
queryset = business.models.Promo.objects.with_related()