diff --git a/promo_code/business/migrations/0001_initial.py b/promo_code/business/migrations/0001_initial.py new file mode 100644 index 0000000..b2657d3 --- /dev/null +++ b/promo_code/business/migrations/0001_initial.py @@ -0,0 +1,45 @@ +# Generated by Django 5.2b1 on 2025-03-25 14:18 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name='Company', + fields=[ + ( + 'id', + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name='ID', + ), + ), + ( + 'password', + models.CharField(max_length=128, verbose_name='password'), + ), + ( + 'last_login', + models.DateTimeField( + blank=True, null=True, verbose_name='last login' + ), + ), + ('email', models.EmailField(max_length=120, unique=True)), + ('name', models.CharField(max_length=50)), + ('token_version', models.IntegerField(default=0)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('is_active', models.BooleanField(default=True)), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/promo_code/business/models.py b/promo_code/business/models.py new file mode 100644 index 0000000..8e2219b --- /dev/null +++ b/promo_code/business/models.py @@ -0,0 +1,38 @@ +import django.contrib.auth.models +import django.db.models + + +class CompanyManager(django.contrib.auth.models.BaseUserManager): + def create_company(self, email, name, password=None, **extra_fields): + if not email: + raise ValueError('The Email must be set') + + email = self.normalize_email(email) + company = self.model( + email=email, + name=name, + **extra_fields, + ) + company.set_password(password) + company.save(using=self._db) + return company + + +class Company(django.contrib.auth.models.AbstractBaseUser): + email = django.db.models.EmailField( + unique=True, + max_length=120, + ) + name = django.db.models.CharField(max_length=50) + + token_version = django.db.models.IntegerField(default=0) + created_at = django.db.models.DateTimeField(auto_now_add=True) + is_active = django.db.models.BooleanField(default=True) + + objects = CompanyManager() + + USERNAME_FIELD = 'email' + REQUIRED_FIELDS = ['name'] + + def __str__(self): + return self.name diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py new file mode 100644 index 0000000..1d9d960 --- /dev/null +++ b/promo_code/business/serializers.py @@ -0,0 +1,92 @@ +import business.models as business_models +import business.validators +import django.contrib.auth.password_validation +import django.core.exceptions +import django.core.validators +import rest_framework.exceptions +import rest_framework.serializers +import rest_framework.status + + +class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer): + password = rest_framework.serializers.CharField( + write_only=True, + required=True, + validators=[django.contrib.auth.password_validation.validate_password], + min_length=8, + max_length=60, + style={'input_type': 'password'}, + ) + name = rest_framework.serializers.CharField( + required=True, + min_length=5, + max_length=50, + ) + email = rest_framework.serializers.EmailField( + required=True, + min_length=8, + max_length=120, + validators=[ + business.validators.UniqueEmailValidator( + 'This email address is already registered.', + 'email_conflict', + ), + ], + ) + + class Meta: + model = business_models.Company + fields = ( + 'name', + 'email', + 'password', + ) + + def create(self, validated_data): + try: + company = business_models.Company.objects.create_company( + email=validated_data['email'], + name=validated_data['name'], + password=validated_data['password'], + ) + company.token_version += 1 + company.save() + return company + except django.core.exceptions.ValidationError as e: + raise rest_framework.serializers.ValidationError(e.messages) + + +class CompanySignInSerializer( + rest_framework.serializers.Serializer, +): + email = rest_framework.serializers.EmailField(required=True) + password = rest_framework.serializers.CharField( + required=True, + write_only=True, + style={'input_type': 'password'}, + ) + + def validate(self, attrs): + email = attrs.get('email') + password = attrs.get('password') + + if not email or not password: + raise rest_framework.exceptions.ValidationError( + {'detail': 'Both email and password are required'}, + code='required', + ) + + try: + company = business_models.Company.objects.get(email=email) + except business_models.Company.DoesNotExist: + raise rest_framework.serializers.ValidationError( + 'Invalid credentials', + ) + + if not company.is_active or not company.check_password(password): + raise rest_framework.exceptions.AuthenticationFailed( + {'detail': 'Invalid credentials or inactive account'}, + code='authentication_failed', + ) + + return attrs diff --git a/promo_code/business/urls.py b/promo_code/business/urls.py new file mode 100644 index 0000000..4d40321 --- /dev/null +++ b/promo_code/business/urls.py @@ -0,0 +1,18 @@ +import business.views +import django.urls + +app_name = 'api-business' + + +urlpatterns = [ + django.urls.path( + 'auth/sign-up', + business.views.CompanySignUpView.as_view(), + name='company-sign-up', + ), + django.urls.path( + 'auth/sign-in', + business.views.CompanySignInView.as_view(), + name='company-sign-in', + ), +] diff --git a/promo_code/business/validators.py b/promo_code/business/validators.py new file mode 100644 index 0000000..819c64c --- /dev/null +++ b/promo_code/business/validators.py @@ -0,0 +1,23 @@ +import business.models +import rest_framework.exceptions + + +class UniqueEmailValidator: + def __init__(self, default_detail=None, default_code=None): + self.status_code = 409 + self.default_detail = ( + default_detail or 'This email address is already registered.' + ) + self.default_code = default_code or 'email_conflict' + + def __call__(self, value): + if business.models.Company.objects.filter(email=value).exists(): + exc = rest_framework.exceptions.APIException( + detail={ + 'status': 'error', + 'message': self.default_detail, + 'code': self.default_code, + }, + ) + exc.status_code = self.status_code + raise exc diff --git a/promo_code/business/views.py b/promo_code/business/views.py new file mode 100644 index 0000000..0c9856b --- /dev/null +++ b/promo_code/business/views.py @@ -0,0 +1,99 @@ +import business.models +import business.serializers +import rest_framework.exceptions +import rest_framework.generics +import rest_framework.response +import rest_framework.serializers +import rest_framework.status +import rest_framework_simplejwt.exceptions +import rest_framework_simplejwt.tokens +import rest_framework_simplejwt.views + +import core.views + + +class CompanySignUpView( + core.views.BaseCustomResponseMixin, + rest_framework.generics.CreateAPIView, +): + def post(self, request): + try: + serializer = business.serializers.CompanySignUpSerializer( + data=request.data, + ) + serializer.is_valid(raise_exception=True) + except ( + rest_framework.serializers.ValidationError, + rest_framework_simplejwt.exceptions.TokenError, + ) as e: + if isinstance(e, rest_framework.serializers.ValidationError): + return self.handle_validation_error() + + raise rest_framework_simplejwt.exceptions.InvalidToken(str(e)) + + company = serializer.save() + + refresh = rest_framework_simplejwt.tokens.RefreshToken() + refresh['user_type'] = 'company' + refresh['company_id'] = company.id + refresh['token_version'] = company.token_version + + access_token = refresh.access_token + access_token['user_type'] = 'company' + access_token['company_id'] = company.id + refresh['token_version'] = company.token_version + + response_data = { + 'access': str(access_token), + 'refresh': str(refresh), + } + + return rest_framework.response.Response( + response_data, + status=rest_framework.status.HTTP_200_OK, + ) + + +class CompanySignInView( + core.views.BaseCustomResponseMixin, + rest_framework_simplejwt.views.TokenObtainPairView, +): + def post(self, request): + try: + serializer = business.serializers.CompanySignInSerializer( + data=request.data, + ) + serializer.is_valid(raise_exception=True) + except ( + rest_framework.serializers.ValidationError, + rest_framework_simplejwt.exceptions.TokenError, + ) as e: + if isinstance(e, rest_framework.serializers.ValidationError): + return self.handle_validation_error() + + raise rest_framework_simplejwt.exceptions.InvalidToken(str(e)) + + company = business.models.Company.objects.get( + email=serializer.validated_data['email'], + ) + company.token_version += 1 + company.save() + + refresh = rest_framework_simplejwt.tokens.RefreshToken() + refresh['user_type'] = 'company' + refresh['company_id'] = company.id + refresh['token_version'] = company.token_version + + access_token = refresh.access_token + access_token['user_type'] = 'company' + access_token['company_id'] = company.id + + response_data = { + 'access': str(access_token), + 'refresh': str(refresh), + } + + return rest_framework.response.Response( + response_data, + status=rest_framework.status.HTTP_200_OK, + ) diff --git a/promo_code/core/views.py b/promo_code/core/views.py index e6f149b..50fad55 100644 --- a/promo_code/core/views.py +++ b/promo_code/core/views.py @@ -2,9 +2,20 @@ import django.views import rest_framework.permissions import rest_framework.response +import rest_framework.status import rest_framework.views +class BaseCustomResponseMixin: + error_response = {'status': 'error', 'message': 'Error in request data.'} + + def handle_validation_error(self): + return rest_framework.response.Response( + self.error_response, + status=rest_framework.status.HTTP_400_BAD_REQUEST, + ) + + class PingView(django.views.View): def get(self, request, *args, **kwargs): return django.http.HttpResponse('PROOOOOOOOOOOOOOOOOD', status=200) diff --git a/promo_code/promo_code/urls.py b/promo_code/promo_code/urls.py index 8628dbf..76eb0fd 100644 --- a/promo_code/promo_code/urls.py +++ b/promo_code/promo_code/urls.py @@ -2,7 +2,8 @@ import django.urls urlpatterns = [ - django.urls.path('api/ping/', django.urls.include('core.urls')), + django.urls.path('api/business/', django.urls.include('business.urls')), django.urls.path('api/user/', django.urls.include('user.urls')), + django.urls.path('api/ping/', django.urls.include('core.urls')), django.urls.path('admin/', django.contrib.admin.site.urls), ] diff --git a/promo_code/user/authentication.py b/promo_code/user/authentication.py index 0b6d515..7140a3e 100644 --- a/promo_code/user/authentication.py +++ b/promo_code/user/authentication.py @@ -1,23 +1,58 @@ +import business.models import rest_framework_simplejwt.authentication import rest_framework_simplejwt.exceptions +import user.models as user_models + class CustomJWTAuthentication( rest_framework_simplejwt.authentication.JWTAuthentication, ): def authenticate(self, request): try: - user_token = super().authenticate(request) - except rest_framework_simplejwt.exceptions.InvalidToken: - raise rest_framework_simplejwt.exceptions.AuthenticationFailed( - 'Token is invalid or expired', - ) + header = self.get_header(request) + if header is None: + return None - if user_token: - user, token = user_token - if token.payload.get('token_version') != user.token_version: + raw_token = self.get_raw_token(header) + if raw_token is None: + return None + + validated_token = self.get_validated_token(raw_token) + user_type = validated_token.get('user_type', 'user') + + model_mapping = { + 'user': (user_models.User, 'user_id'), + 'company': (business.models.Company, 'company_id'), + } + + if user_type not in model_mapping: + raise rest_framework_simplejwt.exceptions.AuthenticationFailed( + 'Invalid user type', + ) + + model_class, id_field = model_mapping[user_type] + instance = model_class.objects.get( + id=validated_token.get(id_field), + ) + if instance.token_version != validated_token.get( + 'token_version', + 0, + ): raise rest_framework_simplejwt.exceptions.AuthenticationFailed( 'Token invalid', ) - return user_token + return (instance, validated_token) + + except ( + user_models.User.DoesNotExist, + business.models.Company.DoesNotExist, + ): + raise rest_framework_simplejwt.exceptions.AuthenticationFailed( + 'User or Company not found', + ) + except rest_framework_simplejwt.exceptions.InvalidToken: + raise rest_framework_simplejwt.exceptions.AuthenticationFailed( + 'Token is invalid or expired', + ) diff --git a/promo_code/user/serializers.py b/promo_code/user/serializers.py index 7de89d7..ee91085 100644 --- a/promo_code/user/serializers.py +++ b/promo_code/user/serializers.py @@ -21,11 +21,20 @@ class SignUpSerializer(rest_framework.serializers.ModelSerializer): min_length=8, style={'input_type': 'password'}, ) - name = rest_framework.serializers.CharField(required=True, min_length=1) - surname = rest_framework.serializers.CharField(required=True, min_length=1) + name = rest_framework.serializers.CharField( + required=True, + min_length=1, + max_length=100, + ) + surname = rest_framework.serializers.CharField( + required=True, + min_length=1, + max_length=120, + ) email = rest_framework.serializers.EmailField( required=True, min_length=8, + max_length=120, validators=[ user.validators.UniqueEmailValidator( 'This email address is already registered.', diff --git a/promo_code/user/tests/auth/test_authentication.py b/promo_code/user/tests/auth/test_authentication.py index adbce19..5e71bca 100644 --- a/promo_code/user/tests/auth/test_authentication.py +++ b/promo_code/user/tests/auth/test_authentication.py @@ -1,5 +1,3 @@ -import django.test -import django.urls import rest_framework.status import rest_framework.test @@ -22,7 +20,7 @@ def test_signin_success(self): 'password': 'SuperStrongPassword2000!', } response = self.client.post( - django.urls.reverse('api-user:sign-in'), + self.signin_url, data, format='json', ) diff --git a/promo_code/user/views.py b/promo_code/user/views.py index cad415a..a782599 100644 --- a/promo_code/user/views.py +++ b/promo_code/user/views.py @@ -7,21 +7,12 @@ import rest_framework_simplejwt.tokens import rest_framework_simplejwt.views +import core.views import user.serializers -class BaseCustomResponseMixin: - error_response = {'status': 'error', 'message': 'Error in request data.'} - - def handle_validation_error(self): - return rest_framework.response.Response( - self.error_response, - status=rest_framework.status.HTTP_400_BAD_REQUEST, - ) - - class SignUpView( - BaseCustomResponseMixin, + core.views.BaseCustomResponseMixin, rest_framework.generics.CreateAPIView, ): serializer_class = user.serializers.SignUpSerializer @@ -47,7 +38,7 @@ def create(self, request, *args, **kwargs): class SignInView( - BaseCustomResponseMixin, + core.views.BaseCustomResponseMixin, rest_framework_simplejwt.views.TokenObtainPairView, ): serializer_class = user.serializers.SignInSerializer