diff --git a/promo_code/business/models.py b/promo_code/business/models.py index 9a86fe3..e64858c 100644 --- a/promo_code/business/models.py +++ b/promo_code/business/models.py @@ -119,8 +119,7 @@ def get_comment_count(self) -> int: def get_used_codes_count(self) -> int: if self.mode == business.constants.PROMO_MODE_UNIQUE: return self.unique_codes.filter(is_used=True).count() - # TODO: COMMON Promo - return 0 + return self.used_count @property def get_available_unique_codes(self) -> list[str] | None: diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 3c9b282..ffebcb6 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -1,7 +1,6 @@ import uuid import django.contrib.auth.password_validation -import django.core.validators import django.db.transaction import pycountry import rest_framework.exceptions @@ -13,7 +12,7 @@ import business.constants import business.models import business.utils.tokens -import business.validators +import core.serializers import core.utils.auth @@ -104,7 +103,7 @@ def validate(self, attrs): ) company = self.get_active_company_from_token(refresh) - company = business.utils.auth.bump_company_token_version(company) + company = core.utils.auth.bump_token_version(company) return business.utils.tokens.generate_company_tokens(company) @@ -141,6 +140,62 @@ def get_active_company_from_token(self, token): return company +class CountryField(rest_framework.serializers.CharField): + """ + Custom field for validating country codes according to ISO 3166-1 alpha-2. + """ + + def __init__(self, **kwargs): + kwargs['allow_blank'] = False + kwargs['min_length'] = business.constants.TARGET_COUNTRY_CODE_LENGTH + kwargs['max_length'] = business.constants.TARGET_COUNTRY_CODE_LENGTH + super().__init__(**kwargs) + + def to_internal_value(self, data): + code = super().to_internal_value(data) + try: + pycountry.countries.lookup(code.upper()) + except LookupError: + raise rest_framework.serializers.ValidationError( + 'Invalid ISO 3166-1 alpha-2 country code.', + ) + return code + + +class MultiCountryField(rest_framework.serializers.ListField): + """ + Custom field for handling multiple country codes, + passed either as a comma-separated list or as multiple parameters. + """ + + def __init__(self, **kwargs): + kwargs['child'] = CountryField() + kwargs['allow_empty'] = False + super().__init__(**kwargs) + + def to_internal_value(self, data): + if not data or not isinstance(data, list): + raise rest_framework.serializers.ValidationError( + 'At least one country must be specified.', + ) + + # (&country=us,fr) + if len(data) == 1 and ',' in data[0]: + countries_str = data[0] + if '' in [s.strip() for s in countries_str.split(',')]: + raise rest_framework.serializers.ValidationError( + 'Invalid country format.', + ) + data = [country.strip() for country in countries_str.split(',')] + + if any(not item for item in data): + raise rest_framework.serializers.ValidationError( + 'Empty value for country is not allowed.', + ) + + return super().to_internal_value(data) + + class TargetSerializer(rest_framework.serializers.Serializer): age_from = rest_framework.serializers.IntegerField( min_value=business.constants.TARGET_AGE_MIN, @@ -152,15 +207,13 @@ class TargetSerializer(rest_framework.serializers.Serializer): max_value=business.constants.TARGET_AGE_MAX, required=False, ) - country = rest_framework.serializers.CharField( - max_length=business.constants.TARGET_COUNTRY_CODE_LENGTH, - min_length=business.constants.TARGET_COUNTRY_CODE_LENGTH, - required=False, - ) + country = CountryField(required=False) + categories = rest_framework.serializers.ListField( child=rest_framework.serializers.CharField( min_length=business.constants.TARGET_CATEGORY_MIN_LENGTH, max_length=business.constants.TARGET_CATEGORY_MAX_LENGTH, + allow_blank=False, ), max_length=business.constants.TARGET_CATEGORY_MAX_ITEMS, required=False, @@ -170,6 +223,7 @@ class TargetSerializer(rest_framework.serializers.Serializer): def validate(self, data): age_from = data.get('age_from') age_until = data.get('age_until') + if ( age_from is not None and age_until is not None @@ -178,60 +232,47 @@ def validate(self, data): raise rest_framework.serializers.ValidationError( {'age_until': 'Must be greater than or equal to age_from.'}, ) - - country = data.get('country') - if country: - try: - pycountry.countries.lookup(country.strip().upper()) - data['country'] = country - except LookupError: - raise rest_framework.serializers.ValidationError( - {'country': 'Invalid ISO 3166-1 alpha-2 country code.'}, - ) - return data -class PromoCreateSerializer(rest_framework.serializers.ModelSerializer): +class BasePromoSerializer(rest_framework.serializers.ModelSerializer): + """ + Base serializer for promo, containing validation and representation logic. + """ + + image_url = rest_framework.serializers.URLField( + required=False, + allow_blank=False, + max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH, + ) description = rest_framework.serializers.CharField( min_length=business.constants.PROMO_DESC_MIN_LENGTH, max_length=business.constants.PROMO_DESC_MAX_LENGTH, required=True, ) - image_url = rest_framework.serializers.CharField( - required=False, - max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH, - validators=[ - django.core.validators.URLValidator(schemes=['http', 'https']), - ], - ) target = TargetSerializer(required=True, allow_null=True) promo_common = rest_framework.serializers.CharField( min_length=business.constants.PROMO_COMMON_CODE_MIN_LENGTH, max_length=business.constants.PROMO_COMMON_CODE_MAX_LENGTH, required=False, allow_null=True, + allow_blank=False, ) promo_unique = rest_framework.serializers.ListField( child=rest_framework.serializers.CharField( min_length=business.constants.PROMO_UNIQUE_CODE_MIN_LENGTH, max_length=business.constants.PROMO_UNIQUE_CODE_MAX_LENGTH, + allow_blank=False, ), min_length=business.constants.PROMO_UNIQUE_LIST_MIN_ITEMS, max_length=business.constants.PROMO_UNIQUE_LIST_MAX_ITEMS, required=False, allow_null=True, ) - # headers - url = rest_framework.serializers.HyperlinkedIdentityField( - view_name='api-business:promo-detail', - lookup_field='id', - ) class Meta: model = business.models.Promo fields = ( - 'url', 'description', 'image_url', 'target', @@ -244,241 +285,194 @@ class Meta: ) def validate(self, data): - data = super().validate(data) - validator = business.validators.PromoValidator(data=data) - return validator.validate() - - def create(self, validated_data): - target_data = validated_data.pop('target') - promo_common = validated_data.pop('promo_common', None) - promo_unique = validated_data.pop('promo_unique', None) - - return business.models.Promo.objects.create_promo( - user=self.context['request'].user, - target_data=target_data, - promo_common=promo_common, - promo_unique=promo_unique, - **validated_data, - ) - - def to_representation(self, instance): - data = super().to_representation(instance) - data['target'] = instance.target - - if instance.mode == business.constants.PROMO_MODE_UNIQUE: - data['promo_unique'] = [ - code.code for code in instance.unique_codes.all() - ] - data.pop('promo_common', None) + """ + Main validation method. + Determines the mode and calls the corresponding validation method. + """ + + mode = data.get('mode', getattr(self.instance, 'mode', None)) + + if mode == business.constants.PROMO_MODE_COMMON: + self._validate_common(data) + elif mode == business.constants.PROMO_MODE_UNIQUE: + self._validate_unique(data) + elif mode is None: + raise rest_framework.serializers.ValidationError( + {'mode': 'This field is required.'}, + ) else: - data.pop('promo_unique', None) + raise rest_framework.serializers.ValidationError( + {'mode': 'Invalid mode.'}, + ) return data + def _validate_common(self, data): + """ + Validations for COMMON promo mode. + """ -class PromoListQuerySerializer(rest_framework.serializers.Serializer): - """ - Serializer for validating query parameters of promo list requests. - """ - - limit = rest_framework.serializers.CharField( - required=False, - allow_blank=True, - ) - offset = rest_framework.serializers.CharField( - required=False, - allow_blank=True, - ) - sort_by = rest_framework.serializers.ChoiceField( - choices=['active_from', 'active_until'], - required=False, - ) - country = rest_framework.serializers.CharField( - required=False, - allow_blank=True, - ) - - _allowed_params = None - - def get_allowed_params(self): - if self._allowed_params is None: - self._allowed_params = set(self.fields.keys()) - return self._allowed_params - - def validate(self, attrs): - query_params = self.initial_data - allowed_params = self.get_allowed_params() - - unexpected_params = set(query_params.keys()) - allowed_params - if unexpected_params: - raise rest_framework.exceptions.ValidationError('Invalid params.') - - field_errors = {} + if 'promo_unique' in data and data['promo_unique'] is not None: + raise rest_framework.serializers.ValidationError( + {'promo_unique': 'This field is not allowed for COMMON mode.'}, + ) - attrs = self._validate_int_field('limit', attrs, field_errors) - attrs = self._validate_int_field('offset', attrs, field_errors) + if self.instance is None and not data.get('promo_common'): + raise rest_framework.serializers.ValidationError( + {'promo_common': 'This field is required for COMMON mode.'}, + ) - self._validate_country(query_params, attrs, field_errors) + new_max_count = data.get('max_count') + if self.instance and new_max_count is not None: + used_count = self.instance.get_used_codes_count + if used_count > new_max_count: + raise rest_framework.serializers.ValidationError( + { + 'max_count': ( + f'max_count ({new_max_count}) cannot be less than ' + f'used_count ({used_count}).' + ), + }, + ) - if field_errors: - raise rest_framework.exceptions.ValidationError(field_errors) + effective_max_count = ( + new_max_count + if new_max_count is not None + else getattr(self.instance, 'max_count', None) + ) - return attrs + min_c = business.constants.PROMO_COMMON_MIN_COUNT + max_c = business.constants.PROMO_COMMON_MAX_COUNT + if effective_max_count is not None and not ( + min_c <= effective_max_count <= max_c + ): + raise rest_framework.serializers.ValidationError( + { + 'max_count': ( + f'Must be between {min_c} and {max_c} for COMMON mode.' + ), + }, + ) - def _validate_int_field(self, field_name, attrs, field_errors): - value_str = self.initial_data.get(field_name) - if value_str is None: - return attrs + def _validate_unique(self, data): + """ + Validations for UNIQUE promo mode. + """ - if value_str == '': - raise rest_framework.exceptions.ValidationError( - f'Invalid {field_name} format.', + if 'promo_common' in data and data['promo_common'] is not None: + raise rest_framework.serializers.ValidationError( + {'promo_common': 'This field is not allowed for UNIQUE mode.'}, ) - try: - value_int = int(value_str) - if value_int < 0: - raise rest_framework.exceptions.ValidationError( - f'{field_name.capitalize()} cannot be negative.', - ) - attrs[field_name] = value_int - except (ValueError, TypeError): - raise rest_framework.exceptions.ValidationError( - f'Invalid {field_name} format.', + if self.instance is None and not data.get('promo_unique'): + raise rest_framework.serializers.ValidationError( + {'promo_unique': 'This field is required for UNIQUE mode.'}, ) - return attrs - - def _validate_country(self, query_params, attrs, field_errors): - countries_raw = query_params.getlist('country', []) + effective_max_count = data.get( + 'max_count', + getattr(self.instance, 'max_count', None), + ) - if '' in countries_raw: - raise rest_framework.exceptions.ValidationError( - 'Invalid country format.', + if ( + effective_max_count is not None + and effective_max_count + != business.constants.PROMO_UNIQUE_MAX_COUNT + ): + raise rest_framework.serializers.ValidationError( + { + 'max_count': ( + 'Must be equal to ' + f'{business.constants.PROMO_UNIQUE_MAX_COUNT} ' + 'for UNIQUE mode.' + ), + }, ) - country_codes = [] - invalid_codes = [] + def to_representation(self, instance): + """ + Controls the display of fields in the response. + """ - for country_group in countries_raw: - if not country_group.strip(): - continue + data = super().to_representation(instance) - parts = [part.strip() for part in country_group.split(',')] + if not instance.image_url: + data.pop('image_url', None) - if '' in parts: - raise rest_framework.exceptions.ValidationError( - 'Invalid country format.', - ) + if instance.mode == business.constants.PROMO_MODE_UNIQUE: + data.pop('promo_common', None) + if 'promo_unique' in self.fields and isinstance( + self.fields['promo_unique'], + rest_framework.serializers.SerializerMethodField, + ): + data['promo_unique'] = self.get_promo_unique(instance) + else: + data['promo_unique'] = [ + code.code for code in instance.unique_codes.all() + ] + else: + data.pop('promo_unique', None) - country_codes.extend(parts) + return data - country_codes_upper = [c.upper() for c in country_codes] - for code in country_codes_upper: - if len(code) != 2: - invalid_codes.append(code) - continue - try: - pycountry.countries.lookup(code) - except LookupError: - invalid_codes.append(code) +class PromoCreateSerializer(BasePromoSerializer): + url = rest_framework.serializers.HyperlinkedIdentityField( + view_name='api-business:promo-detail', + lookup_field='id', + ) - if invalid_codes: - field_errors['country'] = ( - f'Invalid country codes: {", ".join(invalid_codes)}' - ) + class Meta(BasePromoSerializer.Meta): + fields = ('url',) + BasePromoSerializer.Meta.fields - attrs['countries'] = country_codes - attrs.pop('country', None) + def create(self, validated_data): + target_data = validated_data.pop('target') + promo_common = validated_data.pop('promo_common', None) + promo_unique = validated_data.pop('promo_unique', None) + return business.models.Promo.objects.create_promo( + user=self.context['request'].user, + target_data=target_data, + promo_common=promo_common, + promo_unique=promo_unique, + **validated_data, + ) -class PromoReadOnlySerializer(rest_framework.serializers.ModelSerializer): - promo_id = rest_framework.serializers.UUIDField( - source='id', - read_only=True, - ) - company_id = rest_framework.serializers.UUIDField( - source='company.id', - read_only=True, - ) - company_name = rest_framework.serializers.CharField( - source='company.name', - read_only=True, - ) - target = TargetSerializer() - promo_unique = rest_framework.serializers.SerializerMethodField() - like_count = rest_framework.serializers.IntegerField( - source='get_like_count', - read_only=True, - ) - used_count = rest_framework.serializers.IntegerField( - source='get_used_codes_count', - read_only=True, - ) - comment_count = rest_framework.serializers.IntegerField( - source='get_comment_count', - read_only=True, - ) - active = rest_framework.serializers.BooleanField( - source='is_active', - read_only=True, +class PromoListQuerySerializer( + core.serializers.BaseLimitOffsetPaginationSerializer, +): + """ + Validates query parameters for the list of promotions. + """ + + sort_by = rest_framework.serializers.ChoiceField( + choices=['active_from', 'active_until'], + required=False, ) + country = MultiCountryField(required=False) - class Meta: - model = business.models.Promo - fields = ( - 'promo_id', - 'company_id', - 'company_name', - 'description', - 'image_url', - 'target', - 'max_count', - 'active_from', - 'active_until', - 'mode', - 'promo_common', - 'promo_unique', - 'like_count', - 'comment_count', - 'used_count', - 'active', - ) + def validate(self, attrs): + query_params = self.initial_data.keys() + allowed_params = self.fields.keys() + unexpected_params = set(query_params) - set(allowed_params) - def get_promo_unique(self, obj): - return obj.get_available_unique_codes + if unexpected_params: + raise rest_framework.exceptions.ValidationError( + f'Invalid parameters: {", ".join(unexpected_params)}', + ) - def to_representation(self, instance): - data = super().to_representation(instance) - if instance.mode == business.constants.PROMO_MODE_COMMON: - data.pop('promo_unique', None) - else: - data.pop('promo_common', None) + if 'country' in attrs: + attrs['countries'] = attrs.pop('country') - return data + return attrs -class PromoDetailSerializer(rest_framework.serializers.ModelSerializer): +class PromoDetailSerializer(BasePromoSerializer): promo_id = rest_framework.serializers.UUIDField( source='id', read_only=True, ) - description = rest_framework.serializers.CharField( - min_length=business.constants.PROMO_DESC_MIN_LENGTH, - max_length=business.constants.PROMO_DESC_MAX_LENGTH, - required=True, - ) - image_url = rest_framework.serializers.CharField( - required=False, - max_length=business.constants.PROMO_IMAGE_URL_MAX_LENGTH, - validators=[ - django.core.validators.URLValidator(schemes=['http', 'https']), - ], - ) - target = TargetSerializer(allow_null=True, required=False) - promo_unique = rest_framework.serializers.SerializerMethodField() company_name = rest_framework.serializers.CharField( source='company.name', read_only=True, @@ -500,31 +494,26 @@ class PromoDetailSerializer(rest_framework.serializers.ModelSerializer): read_only=True, ) - class Meta: - model = business.models.Promo - fields = ( + promo_unique = rest_framework.serializers.SerializerMethodField() + + class Meta(BasePromoSerializer.Meta): + fields = BasePromoSerializer.Meta.fields + ( 'promo_id', - 'description', - 'image_url', - 'target', - 'max_count', - 'active_from', - 'active_until', - 'mode', - 'promo_common', - 'promo_unique', 'company_name', - 'active', 'like_count', 'comment_count', 'used_count', + 'active', ) def get_promo_unique(self, obj): - return obj.get_available_unique_codes + if obj.mode == business.constants.PROMO_MODE_UNIQUE: + return obj.get_available_unique_codes + return None def update(self, instance, validated_data): target_data = validated_data.pop('target', None) + for attr, value in validated_data.items(): setattr(instance, attr, value) @@ -534,13 +523,18 @@ def update(self, instance, validated_data): instance.save() return instance - def validate(self, data): - data = super().validate(data) - validator = business.validators.PromoValidator( - data=data, - instance=self.instance, - ) - return validator.validate() + +class PromoReadOnlySerializer(PromoDetailSerializer): + """Read-only serializer for promo.""" + + company_id = rest_framework.serializers.UUIDField( + source='company.id', + read_only=True, + ) + + class Meta(PromoDetailSerializer.Meta): + fields = PromoDetailSerializer.Meta.fields + ('company_id',) + read_only_fields = fields class CountryStatSerializer(rest_framework.serializers.Serializer): diff --git a/promo_code/business/validators.py b/promo_code/business/validators.py deleted file mode 100644 index b66ce07..0000000 --- a/promo_code/business/validators.py +++ /dev/null @@ -1,98 +0,0 @@ -import rest_framework.exceptions - -import business.constants - - -class PromoValidator: - def __init__(self, data, instance=None): - self.data = data - self.instance = instance - self.full_data = self._get_full_data() - - def _get_full_data(self): - full_data = {} - if self.instance is not None: - full_data.update( - { - 'mode': self.instance.mode, - 'promo_common': self.instance.promo_common, - 'promo_unique': None, - 'max_count': self.instance.max_count, - 'active_from': self.instance.active_from, - 'active_until': self.instance.active_until, - 'used_count': self.instance.used_count, - 'target': self.instance.target - if self.instance.target - else {}, - }, - ) - - full_data.update(self.data) - return full_data - - def validate(self): - mode = self.full_data.get('mode') - promo_common = self.full_data.get('promo_common') - promo_unique = self.full_data.get('promo_unique') - max_count = self.full_data.get('max_count') - used_count = self.full_data.get('used_count') - - if mode not in [ - business.constants.PROMO_MODE_COMMON, - business.constants.PROMO_MODE_UNIQUE, - ]: - raise rest_framework.exceptions.ValidationError( - {'mode': 'Invalid mode.'}, - ) - - if used_count and used_count > max_count: - raise rest_framework.exceptions.ValidationError( - {'mode': 'Invalid max_count.'}, - ) - - if mode == business.constants.PROMO_MODE_COMMON: - if not promo_common: - raise rest_framework.exceptions.ValidationError( - { - 'promo_common': ( - 'This field is required for COMMON mode.' - ), - }, - ) - if promo_unique is not None: - raise rest_framework.exceptions.ValidationError( - { - 'promo_unique': ( - 'This field is not allowed for COMMON mode.' - ), - }, - ) - if max_count is None or not ( - business.constants.PROMO_COMMON_MIN_COUNT - <= max_count - <= business.constants.PROMO_COMMON_MAX_COUNT - ): - raise rest_framework.exceptions.ValidationError( - { - 'max_count': ( - 'Must be between 0 and 100,000,000 ' - 'for COMMON mode.' - ), - }, - ) - - elif mode == business.constants.PROMO_MODE_UNIQUE: - if promo_common is not None: - raise rest_framework.exceptions.ValidationError( - { - 'promo_common': ( - 'This field is not allowed for UNIQUE mode.' - ), - }, - ) - if max_count != business.constants.PROMO_UNIQUE_MAX_COUNT: - raise rest_framework.exceptions.ValidationError( - {'max_count': 'Must be 1 for UNIQUE mode.'}, - ) - - return self.full_data diff --git a/promo_code/core/pagination.py b/promo_code/core/pagination.py index 120f148..31a129d 100644 --- a/promo_code/core/pagination.py +++ b/promo_code/core/pagination.py @@ -1,6 +1,9 @@ +import rest_framework.exceptions import rest_framework.pagination import rest_framework.response +import core.serializers + class CustomLimitOffsetPagination( rest_framework.pagination.LimitOffsetPagination, @@ -9,12 +12,13 @@ class CustomLimitOffsetPagination( max_limit = 100 def get_limit(self, request): - raw_limit = request.query_params.get(self.limit_query_param) - - if raw_limit is None: - return self.default_limit + serializer = core.serializers.BaseLimitOffsetPaginationSerializer( + data=request.query_params, + ) + serializer.is_valid(raise_exception=True) - limit = int(raw_limit) + validated_data = serializer.validated_data + limit = validated_data.get('limit', self.default_limit) # Allow 0, otherwise cut by max_limit return 0 if limit == 0 else min(limit, self.max_limit) diff --git a/promo_code/core/serializers.py b/promo_code/core/serializers.py new file mode 100644 index 0000000..60cd85b --- /dev/null +++ b/promo_code/core/serializers.py @@ -0,0 +1,31 @@ +import rest_framework.exceptions +import rest_framework.serializers + + +class BaseLimitOffsetPaginationSerializer( + rest_framework.serializers.Serializer, +): + """ + Base serializer for common filtering and sorting parameters. + Pagination parameters (limit, offset) are handled by the pagination class. + """ + + limit = rest_framework.serializers.IntegerField( + min_value=0, + required=False, + ) + offset = rest_framework.serializers.IntegerField( + min_value=0, + required=False, + ) + + def validate(self, attrs): + errors = {} + for field in ('limit', 'offset'): + raw = self.initial_data.get(field, None) + if raw == '': + errors[field] = ['This field cannot be an empty string.'] + if errors: + raise rest_framework.exceptions.ValidationError(errors) + + return super().validate(attrs)