diff --git a/promo_code/core/tests.py b/promo_code/core/tests.py index 3bdd0b7..bb9bd0b 100644 --- a/promo_code/core/tests.py +++ b/promo_code/core/tests.py @@ -6,5 +6,5 @@ class StaticURLTests(django.test.TestCase): def test_ping_endpoint(self): - response = self.client.get(django.urls.reverse('core:ping')) + response = self.client.get(django.urls.reverse('api-core:ping')) self.assertEqual(response.status_code, http.HTTPStatus.OK) diff --git a/promo_code/core/urls.py b/promo_code/core/urls.py index 864359e..fae4a95 100644 --- a/promo_code/core/urls.py +++ b/promo_code/core/urls.py @@ -1,7 +1,7 @@ import core.views import django.urls -app_name = 'core' +app_name = 'api-core' urlpatterns = [ @@ -10,4 +10,9 @@ core.views.PingView.as_view(), name='ping', ), + django.urls.path( + 'protected-endpoint/', + core.views.MyProtectedView.as_view(), + name='protected', + ), ] diff --git a/promo_code/core/views.py b/promo_code/core/views.py index ceaf1dc..e6f149b 100644 --- a/promo_code/core/views.py +++ b/promo_code/core/views.py @@ -1,7 +1,18 @@ import django.http import django.views +import rest_framework.permissions +import rest_framework.response +import rest_framework.views class PingView(django.views.View): def get(self, request, *args, **kwargs): return django.http.HttpResponse('PROOOOOOOOOOOOOOOOOD', status=200) + + +class MyProtectedView(rest_framework.views.APIView): + permission_classes = [rest_framework.permissions.IsAuthenticated] + + def get(self, request, format=None): + content = {'status': 'request was permitted'} + return rest_framework.response.Response(content) diff --git a/promo_code/promo_code/settings.py b/promo_code/promo_code/settings.py index aaaad76..40b4c34 100644 --- a/promo_code/promo_code/settings.py +++ b/promo_code/promo_code/settings.py @@ -48,47 +48,55 @@ def load_bool(name, default): AUTH_USER_MODEL = 'user.User' REST_FRAMEWORK = { - 'DEFAULT_RENDERER_CLASSES': ('rest_framework.renderers.JSONRenderer',), 'DEFAULT_AUTHENTICATION_CLASSES': [ - 'user.authentication.CustomJWTAuthentication', + 'rest_framework_simplejwt.authentication.JWTAuthentication', ], } SIMPLE_JWT = { - 'ACCESS_TOKEN_LIFETIME': datetime.timedelta(hours=1), + 'ACCESS_TOKEN_LIFETIME': datetime.timedelta(minutes=60), 'REFRESH_TOKEN_LIFETIME': datetime.timedelta(days=1), 'ROTATE_REFRESH_TOKENS': True, 'BLACKLIST_AFTER_ROTATION': True, - 'UPDATE_LAST_LOGIN': False, # ! - # + 'UPDATE_LAST_LOGIN': False, 'ALGORITHM': 'HS256', - 'SIGNING_KEY': SECRET_KEY, - 'VERIFYING_KEY': None, + 'VERIFYING_KEY': '', 'AUDIENCE': None, 'ISSUER': None, 'JSON_ENCODER': None, 'JWK_URL': None, 'LEEWAY': 0, - # 'AUTH_HEADER_TYPES': ('Bearer',), 'AUTH_HEADER_NAME': 'HTTP_AUTHORIZATION', 'USER_ID_FIELD': 'id', 'USER_ID_CLAIM': 'user_id', 'USER_AUTHENTICATION_RULE': ( 'rest_framework_simplejwt.authentication' - '.default_user_authentication_rule', + '.default_user_authentication_rule' ), - # + 'AUTH_TOKEN_CLASSES': ('rest_framework_simplejwt.tokens.AccessToken',), 'TOKEN_TYPE_CLAIM': 'token_type', 'TOKEN_USER_CLASS': 'rest_framework_simplejwt.models.TokenUser', - # 'JTI_CLAIM': 'jti', - # 'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp', 'SLIDING_TOKEN_LIFETIME': datetime.timedelta(minutes=5), 'SLIDING_TOKEN_REFRESH_LIFETIME': datetime.timedelta(days=1), - # - 'ACCESS_TOKEN_CLASS': 'user.tokens.CustomAccessToken', + 'TOKEN_OBTAIN_SERIALIZER': 'user.serializers.SignInSerializer', + 'TOKEN_REFRESH_SERIALIZER': ( + 'rest_framework_simplejwt.serializers.TokenRefreshSerializer' + ), + 'TOKEN_VERIFY_SERIALIZER': ( + 'rest_framework_simplejwt.serializers.TokenVerifySerializer' + ), + 'TOKEN_BLACKLIST_SERIALIZER': ( + 'rest_framework_simplejwt.serializers.TokenBlacklistSerializer' + ), + 'SLIDING_TOKEN_OBTAIN_SERIALIZER': ( + 'rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer' + ), + 'SLIDING_TOKEN_REFRESH_SERIALIZER': ( + 'rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer' + ), } MIDDLEWARE = [ @@ -99,6 +107,7 @@ def load_bool(name, default): 'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'user.middleware.TokenVersionMiddleware', ] ROOT_URLCONF = 'promo_code.urls' diff --git a/promo_code/user/authentication.py b/promo_code/user/authentication.py deleted file mode 100644 index 5294ea6..0000000 --- a/promo_code/user/authentication.py +++ /dev/null @@ -1,20 +0,0 @@ -import datetime - -import rest_framework.exceptions -import rest_framework_simplejwt.authentication - - -class CustomJWTAuthentication( - rest_framework_simplejwt.authentication.JWTAuthentication, -): - def get_user(self, validated_token): - user = super().get_user(validated_token) - last_login_str = validated_token.get('last_login') - if last_login_str: - last_login = datetime.datetime.fromisoformat(last_login_str) - if user.last_login and user.last_login > last_login: - raise rest_framework.exceptions.AuthenticationFailed( - 'Token has been invalidated', - ) - - return user diff --git a/promo_code/user/middleware.py b/promo_code/user/middleware.py new file mode 100644 index 0000000..d4b1ba3 --- /dev/null +++ b/promo_code/user/middleware.py @@ -0,0 +1,25 @@ +import django.http +import rest_framework_simplejwt.authentication + + +class TokenVersionMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + auth = rest_framework_simplejwt.authentication.JWTAuthentication() + auth_result = auth.authenticate(request) + + if auth_result is None: + return self.get_response(request) + + user, token = auth_result + if user: + token_version = token.payload.get('token_version', 0) + if token_version != user.token_version: + return django.http.JsonResponse( + {'error': 'Token invalid'}, + status=401, + ) + + return self.get_response(request) diff --git a/promo_code/user/migrations/0002_user_token_version.py b/promo_code/user/migrations/0002_user_token_version.py new file mode 100644 index 0000000..b58577f --- /dev/null +++ b/promo_code/user/migrations/0002_user_token_version.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2b1 on 2025-03-14 19:46 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('user', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='token_version', + field=models.IntegerField(default=0), + ), + ] diff --git a/promo_code/user/models.py b/promo_code/user/models.py index 225fb3b..dacda78 100644 --- a/promo_code/user/models.py +++ b/promo_code/user/models.py @@ -46,6 +46,8 @@ class User( ) other = django.db.models.JSONField(default=dict) + token_version = django.db.models.IntegerField(default=0) + is_active = django.db.models.BooleanField(default=True) is_staff = django.db.models.BooleanField(default=False) last_login = django.db.models.DateTimeField(null=True, blank=True) diff --git a/promo_code/user/serializers.py b/promo_code/user/serializers.py index 6debfff..d4eba84 100644 --- a/promo_code/user/serializers.py +++ b/promo_code/user/serializers.py @@ -4,6 +4,8 @@ import rest_framework.exceptions import rest_framework.serializers import rest_framework.status +import rest_framework_simplejwt.serializers +import rest_framework_simplejwt.token_blacklist.models as tb_models import rest_framework_simplejwt.tokens import user.models as user_models @@ -69,7 +71,9 @@ def create(self, validated_data): raise rest_framework.serializers.ValidationError(e.messages) -class SignInSerializer(rest_framework.serializers.Serializer): +class SignInSerializer( + rest_framework_simplejwt.serializers.TokenObtainPairSerializer, +): email = rest_framework.serializers.EmailField(required=True) password = rest_framework.serializers.CharField( required=True, @@ -97,10 +101,51 @@ def validate(self, data): code='authorization', ) - data['user'] = user + authenticate_kwargs = { + self.username_field: data[self.username_field], + 'password': data['password'], + } + try: + authenticate_kwargs['request'] = self.context['request'] + except KeyError: + pass + + self.user = django.contrib.auth.authenticate(**authenticate_kwargs) + + if not getattr(self.user, 'is_active', None): + raise rest_framework.exceptions.AuthenticationFailed( + self.error_messages['no_active_account'], + 'no_active_account', + ) + + self.user.token_version += 1 + self.user.save() + + refresh = self.get_token(self.user) + data = { + 'refresh': str(refresh), + 'access': str(refresh.access_token), + } + + current_jti = refresh['jti'] + + tokens_qs = tb_models.OutstandingToken.objects.filter( + user=self.user, + ) + + outstanding_tokens = tokens_qs.exclude(jti=current_jti) + + for token in outstanding_tokens: + ( + tb_models.BlacklistedToken.objects.get_or_create( + token=token, + ) + ) + + data['token_version'] = self.user.token_version return data - def get_token(self): - user = self.validated_data['user'] - refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(user) - return {'token': str(refresh.access_token)} + def get_token(self, user): + token = super().get_token(user) + token['token_version'] = user.token_version + return token diff --git a/promo_code/user/tests.py b/promo_code/user/tests.py index 93800b9..60778e9 100644 --- a/promo_code/user/tests.py +++ b/promo_code/user/tests.py @@ -3,9 +3,9 @@ import parameterized import rest_framework.status import rest_framework.test +import rest_framework_simplejwt.token_blacklist.models as tb_models import user.models -import user.tokens class AuthTestCase(rest_framework.test.APITestCase): @@ -18,7 +18,6 @@ def tearDown(self): super(AuthTestCase, self).tearDown() def test_valid_registration_and_email_duplication(self): - # Successful registration valid_data = { 'name': 'Emma', 'surname': 'Thompson', @@ -36,7 +35,6 @@ def test_valid_registration_and_email_duplication(self): rest_framework.status.HTTP_200_OK, ) - # Duplicate email registration attempt duplicate_data = { 'name': 'Lui', 'surname': 'Jomalone', @@ -389,3 +387,127 @@ def test_signin_success(self): response.status_code, rest_framework.status.HTTP_200_OK, ) + + +class JWTTests(rest_framework.test.APITestCase): + def setUp(self): + + self.signin_url = django.urls.reverse('api-user:sign-in') + self.protected_url = django.urls.reverse('api-core:protected') + self.refresh_url = django.urls.reverse('api-user:token_refresh') + user.models.User.objects.create_user( + name='John', + surname='Doe', + email='example@example.com', + password='SuperStrongPassword2000!', + other={'age': 25, 'country': 'us'}, + ) + self.user_data = { + 'email': 'example@example.com', + 'password': 'SuperStrongPassword2000!', + } + + super(JWTTests, self).setUp() + + def tearDown(self): + user.models.User.objects.all().delete() + + super(JWTTests, self).tearDown() + + def test_access_protected_view_with_valid_token(self): + response = self.client.post( + self.signin_url, + self.user_data, + format='json', + ) + + token = response.data['access'] + + self.client.credentials(HTTP_AUTHORIZATION='Bearer ' + token) + response = self.client.get(self.protected_url) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['status'], 'request was permitted') + + def test_refresh_token_invalidation_after_new_login(self): + + first_login_response = self.client.post( + self.signin_url, + self.user_data, + format='json', + ) + refresh_token_v1 = first_login_response.data['refresh'] + + second_login_response = self.client.post( + self.signin_url, + self.user_data, + format='json', + ) + refresh_token_v2 = second_login_response.data['refresh'] + + refresh_response_v1 = self.client.post( + self.refresh_url, + {'refresh': refresh_token_v1}, + format='json', + ) + self.assertEqual( + refresh_response_v1.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + self.assertEqual(refresh_response_v1.data['code'], 'token_not_valid') + self.assertEqual( + str(refresh_response_v1.data['detail']), + 'Token is blacklisted', + ) + + refresh_response_v2 = self.client.post( + self.refresh_url, + {'refresh': refresh_token_v2}, + format='json', + ) + self.assertEqual( + refresh_response_v2.status_code, + rest_framework.status.HTTP_200_OK, + ) + self.assertIn('access', refresh_response_v2.data) + + self.client.credentials( + HTTP_AUTHORIZATION='Bearer ' + first_login_response.data['access'], + ) + protected_response = self.client.get(self.protected_url) + self.assertEqual( + protected_response.status_code, + rest_framework.status.HTTP_401_UNAUTHORIZED, + ) + + def test_blacklist_storage(self): + + self.client.post(self.signin_url, self.user_data, format='json') + + self.client.post(self.signin_url, self.user_data, format='json') + + self.assertEqual( + (tb_models.BlacklistedToken.objects.count()), + 1, + ) + self.assertEqual( + (tb_models.OutstandingToken.objects.count()), + 2, + ) + + def test_token_version_increment(self): + response1 = self.client.post( + self.signin_url, + self.user_data, + format='json', + ) + self.assertEqual(response1.data['token_version'], 1) + + response2 = self.client.post( + self.signin_url, + self.user_data, + format='json', + ) + self.assertEqual(response2.data['token_version'], 2) + + user_ = user.models.User.objects.get(email=self.user_data['email']) + self.assertEqual(user_.token_version, 2) diff --git a/promo_code/user/tokens.py b/promo_code/user/tokens.py deleted file mode 100644 index fda698e..0000000 --- a/promo_code/user/tokens.py +++ /dev/null @@ -1,9 +0,0 @@ -import rest_framework_simplejwt.tokens - - -class CustomAccessToken(rest_framework_simplejwt.tokens.AccessToken): - @classmethod - def for_user(cls, user): - token = super().for_user(user) - token['last_login'] = user.last_login.isoformat() - return token diff --git a/promo_code/user/urls.py b/promo_code/user/urls.py index 7511bda..824788d 100644 --- a/promo_code/user/urls.py +++ b/promo_code/user/urls.py @@ -1,4 +1,5 @@ import django.urls +import rest_framework_simplejwt.views import user.views @@ -13,7 +14,12 @@ ), django.urls.path( 'auth/sign-in', - user.views.SignInView.as_view(), + rest_framework_simplejwt.views.TokenObtainPairView.as_view(), name='sign-in', ), + django.urls.path( + 'token/refresh/', + rest_framework_simplejwt.views.TokenRefreshView.as_view(), + name='token_refresh', + ), ] diff --git a/promo_code/user/views.py b/promo_code/user/views.py index b8d65ec..24d661a 100644 --- a/promo_code/user/views.py +++ b/promo_code/user/views.py @@ -3,6 +3,7 @@ import rest_framework.response import rest_framework.serializers import rest_framework.status +import rest_framework_simplejwt.exceptions import rest_framework_simplejwt.tokens import rest_framework_simplejwt.views @@ -43,7 +44,7 @@ def create(self, request, *args, **kwargs): class SignInView( BaseCustomResponseMixin, - rest_framework_simplejwt.views.TokenViewBase, + rest_framework_simplejwt.views.TokenObtainPairView, ): serializer_class = user.serializers.SignInSerializer @@ -52,10 +53,17 @@ def post(self, request, *args, **kwargs): try: serializer.is_valid(raise_exception=True) - except rest_framework.serializers.ValidationError: - return self.handle_validation_error() + response = super().post(request, *args, **kwargs) + except ( + rest_framework.serializers.ValidationError, + rest_framework_simplejwt.exceptions.TokenError, + ) as e: + if isinstance(e, rest_framework.serializers.ValidationError): + return self.handle_validation_error() + + raise rest_framework_simplejwt.exceptions.InvalidToken(str(e)) return rest_framework.response.Response( - serializer.get_token(), + response, status=rest_framework.status.HTTP_200_OK, )