diff --git a/promo_code/user/serializers.py b/promo_code/user/serializers.py index d4eba84..7de89d7 100644 --- a/promo_code/user/serializers.py +++ b/promo_code/user/serializers.py @@ -66,6 +66,8 @@ def create(self, validated_data): other=validated_data['other'], password=validated_data['password'], ) + user.token_version += 1 + user.save() return user except django.core.exceptions.ValidationError as e: raise rest_framework.serializers.ValidationError(e.messages) @@ -80,13 +82,26 @@ class SignInSerializer( write_only=True, ) - def validate(self, data): - email = data.get('email') - password = data.get('password') + def validate(self, attrs): + user = self.authenticate_user(attrs) + + self.update_token_version(user) + + data = super().validate(attrs) + + refresh = rest_framework_simplejwt.tokens.RefreshToken(data['refresh']) + + self.invalidate_previous_tokens(user, refresh['jti']) + + return data + + def authenticate_user(self, attrs): + email = attrs.get('email') + password = attrs.get('password') if not email or not password: - raise rest_framework.serializers.ValidationError( - {'status': 'error', 'message': 'Both fields are required.'}, + raise rest_framework.exceptions.ValidationError( + {'detail': 'Both email and password are required'}, code='required', ) @@ -95,55 +110,26 @@ def validate(self, data): email=email, password=password, ) - if not user: - raise rest_framework.exceptions.AuthenticationFailed( - {'status': 'error', 'message': 'Invalid email or password.'}, - code='authorization', - ) - authenticate_kwargs = { - self.username_field: data[self.username_field], - 'password': data['password'], - } - try: - authenticate_kwargs['request'] = self.context['request'] - except KeyError: - pass - - self.user = django.contrib.auth.authenticate(**authenticate_kwargs) - - if not getattr(self.user, 'is_active', None): + if not user or not user.is_active: raise rest_framework.exceptions.AuthenticationFailed( - self.error_messages['no_active_account'], - 'no_active_account', + {'detail': 'Invalid credentials or inactive account'}, + code='authentication_failed', ) - self.user.token_version += 1 - self.user.save() + return user - refresh = self.get_token(self.user) - data = { - 'refresh': str(refresh), - 'access': str(refresh.access_token), - } - - current_jti = refresh['jti'] - - tokens_qs = tb_models.OutstandingToken.objects.filter( - user=self.user, - ) - - outstanding_tokens = tokens_qs.exclude(jti=current_jti) + def invalidate_previous_tokens(self, user, current_jti): + outstanding_tokens = tb_models.OutstandingToken.objects.filter( + user=user, + ).exclude(jti=current_jti) for token in outstanding_tokens: - ( - tb_models.BlacklistedToken.objects.get_or_create( - token=token, - ) - ) + tb_models.BlacklistedToken.objects.get_or_create(token=token) - data['token_version'] = self.user.token_version - return data + def update_token_version(self, user): + user.token_version += 1 + user.save() def get_token(self, user): token = super().get_token(user) diff --git a/promo_code/user/tests.py b/promo_code/user/tests.py index 60778e9..fc39bda 100644 --- a/promo_code/user/tests.py +++ b/promo_code/user/tests.py @@ -324,7 +324,7 @@ def test_valid_registration(self): response.status_code, rest_framework.status.HTTP_200_OK, ) - self.assertIn('token', response.data) + self.assertIn('access', response.data) self.assertTrue( user.models.User.objects.filter( email='minecraft.digger@gmail.com', @@ -391,7 +391,7 @@ def test_signin_success(self): class JWTTests(rest_framework.test.APITestCase): def setUp(self): - + self.signup_url = django.urls.reverse('api-user:sign-up') self.signin_url = django.urls.reverse('api-user:sign-in') self.protected_url = django.urls.reverse('api-core:protected') self.refresh_url = django.urls.reverse('api-user:token_refresh') @@ -428,13 +428,54 @@ def test_access_protected_view_with_valid_token(self): self.assertEqual(response.status_code, 200) self.assertEqual(response.data['status'], 'request was permitted') - def test_refresh_token_invalidation_after_new_login(self): + def test_registration_token_invalid_after_login(self): + data = { + 'email': 'test@example.com', + 'password': 'StrongPass123!cd', + 'name': 'John', + 'surname': 'Doe', + 'other': {'age': 22, 'country': 'us'}, + } + response = self.client.post( + self.signup_url, + data, + format='json', + ) + reg_access_token = response.data['access'] + + self.client.credentials( + HTTP_AUTHORIZATION=f'Bearer {reg_access_token}', + ) + response = self.client.get(self.protected_url) + self.assertEqual(response.status_code, 200) + + login_data = {'email': data['email'], 'password': data['password']} + response = self.client.post( + self.signin_url, + login_data, + format='json', + ) + login_access_token = response.data['access'] + + self.client.credentials( + HTTP_AUTHORIZATION=f'Bearer {reg_access_token}', + ) + response = self.client.get(self.protected_url) + self.assertEqual(response.status_code, 401) + + self.client.credentials( + HTTP_AUTHORIZATION=f'Bearer {login_access_token}', + ) + response = self.client.get(self.protected_url) + self.assertEqual(response.status_code, 200) + def test_refresh_token_invalidation_after_new_login(self): first_login_response = self.client.post( self.signin_url, self.user_data, format='json', ) + refresh_token_v1 = first_login_response.data['refresh'] second_login_response = self.client.post( @@ -493,21 +534,3 @@ def test_blacklist_storage(self): (tb_models.OutstandingToken.objects.count()), 2, ) - - def test_token_version_increment(self): - response1 = self.client.post( - self.signin_url, - self.user_data, - format='json', - ) - self.assertEqual(response1.data['token_version'], 1) - - response2 = self.client.post( - self.signin_url, - self.user_data, - format='json', - ) - self.assertEqual(response2.data['token_version'], 2) - - user_ = user.models.User.objects.get(email=self.user_data['email']) - self.assertEqual(user_.token_version, 2) diff --git a/promo_code/user/views.py b/promo_code/user/views.py index 24d661a..cad415a 100644 --- a/promo_code/user/views.py +++ b/promo_code/user/views.py @@ -35,9 +35,13 @@ def create(self, request, *args, **kwargs): return self.handle_validation_error() user = serializer.save() + refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(user) + refresh['token_version'] = user.token_version + access_token = refresh.access_token + return rest_framework.response.Response( - {'token': str(refresh.access_token)}, + {'access': str(access_token), 'refresh': str(refresh)}, status=rest_framework.status.HTTP_200_OK, ) @@ -49,11 +53,10 @@ class SignInView( serializer_class = user.serializers.SignInSerializer def post(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) try: + serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - response = super().post(request, *args, **kwargs) except ( rest_framework.serializers.ValidationError, rest_framework_simplejwt.exceptions.TokenError, @@ -63,7 +66,12 @@ def post(self, request, *args, **kwargs): raise rest_framework_simplejwt.exceptions.InvalidToken(str(e)) + response_data = { + 'access': serializer.validated_data['access'], + 'refresh': serializer.validated_data['refresh'], + } + return rest_framework.response.Response( - response, + response_data, status=rest_framework.status.HTTP_200_OK, )