diff --git a/promo_code/business/managers.py b/promo_code/business/managers.py new file mode 100644 index 0000000..a21b86b --- /dev/null +++ b/promo_code/business/managers.py @@ -0,0 +1,53 @@ +import django.contrib.auth.models +import django.db.models + +import business.models + + +class CompanyManager(django.contrib.auth.models.BaseUserManager): + def create_company(self, email, name, password=None, **extra_fields): + if not email: + raise ValueError('The Email must be set') + + email = self.normalize_email(email) + company = self.model( + email=email, + name=name, + **extra_fields, + ) + company.set_password(password) + company.save(using=self._db) + return company + + +class PromoManager(django.db.models.Manager): + @django.db.transaction.atomic + def create_promo( + self, + user, + target_data, + promo_common, + promo_unique, + **kwargs, + ): + promo = self.create( + company=user, + target=target_data, + **kwargs, + ) + + if promo.mode == business.models.Promo.MODE_COMMON: + promo.promo_common = promo_common + promo.save(update_fields=['promo_common']) + elif promo.mode == business.models.Promo.MODE_UNIQUE and promo_unique: + self._create_unique_codes(promo, promo_unique) + + return promo + + def _create_unique_codes(self, promo, codes): + business.models.PromoCode.objects.bulk_create( + [ + business.models.PromoCode(promo=promo, code=code) + for code in codes + ], + ) diff --git a/promo_code/business/models.py b/promo_code/business/models.py index 26750d2..31181b5 100644 --- a/promo_code/business/models.py +++ b/promo_code/business/models.py @@ -3,21 +3,7 @@ import django.contrib.auth.models import django.db.models - -class CompanyManager(django.contrib.auth.models.BaseUserManager): - def create_company(self, email, name, password=None, **extra_fields): - if not email: - raise ValueError('The Email must be set') - - email = self.normalize_email(email) - company = self.model( - email=email, - name=name, - **extra_fields, - ) - company.set_password(password) - company.save(using=self._db) - return company +import business.managers class Company(django.contrib.auth.models.AbstractBaseUser): @@ -37,7 +23,7 @@ class Company(django.contrib.auth.models.AbstractBaseUser): created_at = django.db.models.DateTimeField(auto_now_add=True) is_active = django.db.models.BooleanField(default=True) - objects = CompanyManager() + objects = business.managers.CompanyManager() USERNAME_FIELD = 'email' REQUIRED_FIELDS = ['name'] @@ -87,6 +73,8 @@ class Promo(django.db.models.Model): created_at = django.db.models.DateTimeField(auto_now_add=True) + objects = business.managers.PromoManager() + def __str__(self): return f'Promo {self.id} ({self.mode})' diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 89e0818..454f240 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -246,104 +246,23 @@ class Meta: ) def validate(self, data): - mode = data.get('mode') - promo_common = data.get('promo_common') - promo_unique = data.get('promo_unique') - max_count = data.get('max_count') - - if mode == business_models.Promo.MODE_COMMON: - if not promo_common: - raise rest_framework.serializers.ValidationError( - { - 'promo_common': ( - 'This field is required for COMMON mode.' - ), - }, - ) - - if promo_unique is not None: - raise rest_framework.serializers.ValidationError( - { - 'promo_unique': ( - 'This field is not allowed for COMMON mode.' - ), - }, - ) - - if max_count < 0 or max_count > 100000000: - raise rest_framework.serializers.ValidationError( - { - 'max_count': ( - 'Must be between 0 and 100,000,000 ' - 'for COMMON mode.' - ), - }, - ) - - elif mode == business_models.Promo.MODE_UNIQUE: - if not promo_unique: - raise rest_framework.serializers.ValidationError( - { - 'promo_unique': ( - 'This field is required for UNIQUE mode.' - ), - }, - ) - - if promo_common is not None: - raise rest_framework.serializers.ValidationError( - { - 'promo_common': ( - 'This field is not allowed for UNIQUE mode.' - ), - }, - ) - - if max_count != 1: - raise rest_framework.serializers.ValidationError( - {'max_count': 'Must be 1 for UNIQUE mode.'}, - ) - - else: - raise rest_framework.serializers.ValidationError( - {'mode': 'Invalid mode.'}, - ) - - active_from = data.get('active_from') - active_until = data.get('active_until') - if active_from and active_until and active_from > active_until: - raise rest_framework.serializers.ValidationError( - {'active_until': 'Must be after or equal to active_from.'}, - ) - - return data + data = super().validate(data) + validator = business.validators.PromoValidator(data=data) + return validator.validate() def create(self, validated_data): target_data = validated_data.pop('target') promo_common = validated_data.pop('promo_common', None) promo_unique = validated_data.pop('promo_unique', None) - mode = validated_data['mode'] - - user = self.context['request'].user - validated_data['company'] = user - promo = business_models.Promo.objects.create( + return business_models.Promo.objects.create_promo( + user=self.context['request'].user, + target_data=target_data, + promo_common=promo_common, + promo_unique=promo_unique, **validated_data, - target=target_data, ) - if mode == business_models.Promo.MODE_COMMON: - promo.promo_common = promo_common - promo.save() - elif mode == business_models.Promo.MODE_UNIQUE and promo_unique: - promo_codes = [ - business_models.PromoCode(promo=promo, code=code) - for code in promo_unique - ] - business_models.PromoCode.objects.bulk_create(promo_codes) - - return promo - def to_representation(self, instance): data = super().to_representation(instance) data['target'] = instance.target @@ -505,74 +424,12 @@ def update(self, instance, validated_data): return instance def validate(self, data): - instance = self.instance - full_data = { - 'mode': instance.mode, - 'promo_common': instance.promo_common, - 'promo_unique': None, - 'max_count': instance.max_count, - 'active_from': instance.active_from, - 'active_until': instance.active_until, - 'target': instance.target if instance.target is not None else {}, - } - full_data.update(data) - mode = full_data.get('mode') - promo_common = full_data.get('promo_common') - promo_unique = full_data.get('promo_unique') - max_count = full_data.get('max_count') - - if mode == business_models.Promo.MODE_COMMON: - if not promo_common: - raise rest_framework.serializers.ValidationError( - { - 'promo_common': ( - 'This field is required for COMMON mode.' - ), - }, - ) - - if promo_unique is not None: - raise rest_framework.serializers.ValidationError( - { - 'promo_unique': ( - 'This field is not allowed for COMMON mode.' - ), - }, - ) - - if max_count < 0 or max_count > 100000000: - raise rest_framework.serializers.ValidationError( - {'max_count': 'Must be between 0 and 100,000,000.'}, - ) - - elif mode == business_models.Promo.MODE_UNIQUE: - if promo_common is not None: - raise rest_framework.serializers.ValidationError( - { - 'promo_common': ( - 'This field is not allowed for UNIQUE mode.' - ), - }, - ) - - if max_count != 1: - raise rest_framework.serializers.ValidationError( - {'max_count': 'Must be 1 for UNIQUE mode.'}, - ) - else: - raise rest_framework.serializers.ValidationError( - {'mode': 'Invalid mode.'}, - ) - - active_from = full_data.get('active_from') - active_until = full_data.get('active_until') - - if active_from and active_until and active_from > active_until: - raise rest_framework.serializers.ValidationError( - {'active_until': 'Must be after or equal to active_from.'}, - ) - - return data + data = super().validate(data) + validator = business.validators.PromoValidator( + data=data, + instance=self.instance, + ) + return validator.validate() def get_like_count(self, obj): return 0 diff --git a/promo_code/business/tests/promocodes/operations/test_detail.py b/promo_code/business/tests/promocodes/operations/test_detail.py index 615aebd..7ebd652 100644 --- a/promo_code/business/tests/promocodes/operations/test_detail.py +++ b/promo_code/business/tests/promocodes/operations/test_detail.py @@ -45,7 +45,6 @@ def setUpTestData(cls): cls.promo2_id = response2.data['id'] def test_get_promo_company1(self): - promo_detail_url = django.urls.reverse( 'api-business:promo-detail', kwargs={'id': self.__class__.promo1_id}, diff --git a/promo_code/business/tests/promocodes/validations/test_create_validation.py b/promo_code/business/tests/promocodes/validations/test_create_validation.py index 4e47934..7c498f5 100644 --- a/promo_code/business/tests/promocodes/validations/test_create_validation.py +++ b/promo_code/business/tests/promocodes/validations/test_create_validation.py @@ -7,7 +7,6 @@ class TestPromoCreate( business.tests.promocodes.base.BasePromoTestCase, ): - def setUp(self): super().setUp() self.client.credentials( diff --git a/promo_code/business/tests/promocodes/validations/test_detail_validation.py b/promo_code/business/tests/promocodes/validations/test_detail_validation.py index ba3bfba..aabafce 100644 --- a/promo_code/business/tests/promocodes/validations/test_detail_validation.py +++ b/promo_code/business/tests/promocodes/validations/test_detail_validation.py @@ -5,7 +5,6 @@ class TestPromoDetail(business.tests.promocodes.base.BasePromoTestCase): - @classmethod def setUpClass(cls): super().setUpClass() diff --git a/promo_code/business/validators.py b/promo_code/business/validators.py index 755ac01..c00cb4d 100644 --- a/promo_code/business/validators.py +++ b/promo_code/business/validators.py @@ -1,4 +1,5 @@ import rest_framework.exceptions +import rest_framework.permissions import business.models @@ -22,3 +23,94 @@ def __call__(self, value): ) exc.status_code = self.status_code raise exc + + +class PromoValidator: + def __init__(self, data, instance=None): + self.data = data + self.instance = instance + self.full_data = self._get_full_data() + + def _get_full_data(self): + full_data = {} + if self.instance is not None: + full_data.update( + { + 'mode': self.instance.mode, + 'promo_common': self.instance.promo_common, + 'promo_unique': None, + 'max_count': self.instance.max_count, + 'active_from': self.instance.active_from, + 'active_until': self.instance.active_until, + 'target': self.instance.target + if self.instance.target + else {}, + }, + ) + + full_data.update(self.data) + return full_data + + def validate(self): + mode = self.full_data.get('mode') + promo_common = self.full_data.get('promo_common') + promo_unique = self.full_data.get('promo_unique') + max_count = self.full_data.get('max_count') + active_from = self.full_data.get('active_from') + active_until = self.full_data.get('active_until') + + if mode not in [ + business.models.Promo.MODE_COMMON, + business.models.Promo.MODE_UNIQUE, + ]: + raise rest_framework.exceptions.ValidationError( + {'mode': 'Invalid mode.'}, + ) + + if mode == business.models.Promo.MODE_COMMON: + if not promo_common: + raise rest_framework.exceptions.ValidationError( + { + 'promo_common': ( + 'This field is required for COMMON mode.' + ), + }, + ) + if promo_unique is not None: + raise rest_framework.exceptions.ValidationError( + { + 'promo_unique': ( + 'This field is not allowed for COMMON mode.' + ), + }, + ) + if max_count is None or not (0 <= max_count <= 100_000_000): + raise rest_framework.exceptions.ValidationError( + { + 'max_count': ( + 'Must be between 0 and 100,000,000 ' + 'for COMMON mode.' + ), + }, + ) + + elif mode == business.models.Promo.MODE_UNIQUE: + if promo_common is not None: + raise rest_framework.exceptions.ValidationError( + { + 'promo_common': ( + 'This field is not allowed for UNIQUE mode.' + ), + }, + ) + if max_count != 1: + raise rest_framework.exceptions.ValidationError( + {'max_count': 'Must be 1 for UNIQUE mode.'}, + ) + + if active_from and active_until and active_from > active_until: + raise rest_framework.exceptions.ValidationError( + {'active_until': 'Must be after or equal to active_from.'}, + ) + + return self.full_data