diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 1d9d960..6b2ae0f 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -6,6 +6,10 @@ import rest_framework.exceptions import rest_framework.serializers import rest_framework.status +import rest_framework_simplejwt.exceptions +import rest_framework_simplejwt.serializers +import rest_framework_simplejwt.tokens +import rest_framework_simplejwt.views class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer): @@ -90,3 +94,39 @@ def validate(self, attrs): ) return attrs + + +class CompanyTokenRefreshSerializer( + rest_framework_simplejwt.serializers.TokenRefreshSerializer, +): + def validate(self, attrs): + refresh = rest_framework_simplejwt.tokens.RefreshToken( + attrs['refresh'], + ) + user_type = refresh.payload.get('user_type', 'user') + + if user_type != 'company': + raise rest_framework_simplejwt.exceptions.InvalidToken( + 'This refresh endpoint is for company tokens only', + ) + + company_id = refresh.payload.get('company_id') + if not company_id: + raise rest_framework_simplejwt.exceptions.InvalidToken( + 'Company ID missing in token', + ) + + try: + company = business_models.Company.objects.get(id=company_id) + except business_models.Company.DoesNotExist: + raise rest_framework_simplejwt.exceptions.InvalidToken( + 'Company not found', + ) + + token_version = refresh.payload.get('token_version', 0) + if company.token_version != token_version: + raise rest_framework_simplejwt.exceptions.InvalidToken( + 'Token is blacklisted', + ) + + return super().validate(attrs) diff --git a/promo_code/business/tests/auth/base.py b/promo_code/business/tests/auth/base.py index 3f32203..bd1ced8 100644 --- a/promo_code/business/tests/auth/base.py +++ b/promo_code/business/tests/auth/base.py @@ -10,9 +10,12 @@ class BaseBusinessAuthTestCase(rest_framework.test.APITestCase): def setUpTestData(cls): super().setUpTestData() cls.client = rest_framework.test.APIClient() + cls.company_refresh_url = django.urls.reverse( + 'api-business:company-token-refresh', + ) + cls.protected_url = django.urls.reverse('api-core:protected') cls.signup_url = django.urls.reverse('api-business:company-sign-up') cls.signin_url = django.urls.reverse('api-business:company-sign-in') - cls.protected_url = django.urls.reverse('api-core:protected') cls.valid_data = { 'name': 'Digital Marketing Solutions Inc.', 'email': 'testcompany@example.com', diff --git a/promo_code/business/tests/auth/test_tokens.py b/promo_code/business/tests/auth/test_tokens.py index b2ab66a..c500c25 100644 --- a/promo_code/business/tests/auth/test_tokens.py +++ b/promo_code/business/tests/auth/test_tokens.py @@ -2,6 +2,9 @@ import business.tests.auth.base import rest_framework.status import rest_framework.test +import rest_framework_simplejwt.tokens + +import user.models class JWTTests(business.tests.auth.base.BaseBusinessAuthTestCase): @@ -82,3 +85,193 @@ def test_registration_token_invalid_after_login(self): response.status_code, rest_framework.status.HTTP_200_OK, ) + + +class TestCompanyTokenRefresh( + business.tests.auth.base.BaseBusinessAuthTestCase, +): + def setUp(self): + super().setUp() + + self.company = business.models.Company.objects.create_company( + name='Digital Marketing Solutions Inc.', + email='testcompany@example.com', + password='SuperStrongPassword2000!', + token_version=1, + ) + + self.company_data = { + 'email': 'testcompany@example.com', + 'password': 'SuperStrongPassword2000!', + } + + self.company_refresh = rest_framework_simplejwt.tokens.RefreshToken() + self.company_refresh.payload.update( + { + 'user_type': 'company', + 'company_id': self.company.id, + 'token_version': self.company.token_version, + }, + ) + + self.user = user.models.User.objects.create_user( + email='minecraft.digger@gmail.com', + name='Steve', + surname='Jobs', + password='SuperStrongPassword2000!', + other={'age': 23, 'country': 'gb'}, + ) + self.user_refresh = ( + rest_framework_simplejwt.tokens.RefreshToken.for_user(self.user) + ) + self.user_refresh.payload['user_type'] = 'user' + + def test_successful_company_token_refresh(self): + response = self.client.post( + self.company_refresh_url, + {'refresh': str(self.company_refresh)}, + ) + + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_200_OK, + ) + self.assertIn('access', response.data) + self.assertIn('refresh', response.data) + + self.assertNotEqual(self.company_refresh, response.data['refresh']) + + def test_reject_user_tokens(self): + response = self.client.post( + self.company_refresh_url, + {'refresh': str(self.user_refresh)}, + ) + + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertIn( + 'This refresh endpoint is for company tokens only', + str(response.content), + ) + + def test_token_version_mismatch(self): + self.company.token_version = 2 + self.company.save() + + response = self.client.post( + self.company_refresh_url, + {'refresh': str(self.company_refresh)}, + ) + + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertIn('Token is blacklisted', str(response.content)) + + def test_missing_company_id(self): + invalid_refresh = rest_framework_simplejwt.tokens.RefreshToken() + invalid_refresh.payload.update( + {'user_type': 'company', 'token_version': 1}, + ) + + response = self.client.post( + self.company_refresh_url, + {'refresh': str(invalid_refresh)}, + ) + + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertIn( + 'Company ID missing in token', + str(response.content.decode()), + ) + + def test_company_not_found(self): + invalid_refresh = rest_framework_simplejwt.tokens.RefreshToken() + invalid_refresh.payload.update( + {'user_type': 'company', 'company_id': 999, 'token_version': 1}, + ) + + response = self.client.post( + self.company_refresh_url, + {'refresh': str(invalid_refresh)}, + ) + + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertIn('Company not found', str(response.content)) + + def test_refresh_token_invalidation_after_new_login(self): + first_login_response = self.client.post( + self.signin_url, + self.company_data, + format='json', + ) + refresh_token_v1 = first_login_response.data['refresh'] + + second_login_response = self.client.post( + self.signin_url, + self.company_data, + format='json', + ) + refresh_token_v2 = second_login_response.data['refresh'] + + refresh_response_v1 = self.client.post( + self.company_refresh_url, + {'refresh': refresh_token_v1}, + format='json', + ) + self.assertEqual( + refresh_response_v1.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertEqual(refresh_response_v1.data['code'], 'token_not_valid') + self.assertEqual( + str(refresh_response_v1.data['detail']), + 'Token is blacklisted', + ) + + refresh_response_v2 = self.client.post( + self.company_refresh_url, + {'refresh': refresh_token_v2}, + format='json', + ) + self.assertEqual( + refresh_response_v2.status_code, + rest_framework.status.HTTP_200_OK, + ) + self.assertIn('access', refresh_response_v2.data) + + self.client.credentials( + HTTP_AUTHORIZATION='Bearer ' + first_login_response.data['access'], + ) + protected_response = self.client.get(self.protected_url) + self.assertEqual( + protected_response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + + def test_default_user_type_handling(self): + refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user( + self.user, + ) + response = self.client.post( + self.company_refresh_url, + {'refresh': str(refresh)}, + ) + + self.assertEqual( + response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertIn( + 'This refresh endpoint is for company tokens only', + str(response.content), + ) diff --git a/promo_code/business/urls.py b/promo_code/business/urls.py index 4d40321..1f7f693 100644 --- a/promo_code/business/urls.py +++ b/promo_code/business/urls.py @@ -15,4 +15,9 @@ business.views.CompanySignInView.as_view(), name='company-sign-in', ), + django.urls.path( + 'token/refresh', + business.views.CompanyTokenRefreshView.as_view(), + name='company-token-refresh', + ), ] diff --git a/promo_code/business/views.py b/promo_code/business/views.py index 0c9856b..0b6bbb1 100644 --- a/promo_code/business/views.py +++ b/promo_code/business/views.py @@ -97,3 +97,7 @@ def post(self, request): response_data, status=rest_framework.status.HTTP_200_OK, ) + + +class CompanyTokenRefreshView(rest_framework_simplejwt.views.TokenRefreshView): + serializer_class = business.serializers.CompanyTokenRefreshSerializer diff --git a/promo_code/user/tests/auth/base.py b/promo_code/user/tests/auth/base.py index fc13f54..66c49ed 100644 --- a/promo_code/user/tests/auth/base.py +++ b/promo_code/user/tests/auth/base.py @@ -11,7 +11,7 @@ def setUpTestData(cls): super().setUpTestData() cls.client = rest_framework.test.APIClient() cls.protected_url = django.urls.reverse('api-core:protected') - cls.refresh_url = django.urls.reverse('api-user:token_refresh') + cls.refresh_url = django.urls.reverse('api-user:user-token-refresh') cls.signup_url = django.urls.reverse('api-user:sign-up') cls.signin_url = django.urls.reverse('api-user:sign-in') diff --git a/promo_code/user/urls.py b/promo_code/user/urls.py index 824788d..c2bc733 100644 --- a/promo_code/user/urls.py +++ b/promo_code/user/urls.py @@ -20,6 +20,6 @@ django.urls.path( 'token/refresh/', rest_framework_simplejwt.views.TokenRefreshView.as_view(), - name='token_refresh', + name='user-token-refresh', ), ]