diff --git a/Makefile b/Makefile index e391efab7..6864d368d 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ CURRENT_SIGN_SETTING := $(shell git config commit.gpgSign) help: @echo "clean-build - remove build artifacts" @echo "clean-pyc - remove Python file artifacts" + @echo "isortfix - fixes the imports order" @echo "lint - check style with flake8" @echo "test - run tests quickly with the default Python" @echo "testall - run tests on every Python version with tox" @@ -23,6 +24,9 @@ clean-pyc: find . -name '*.pyo' -exec rm -f {} + find . -name '*~' -exec rm -f {} + +isortfix: + isort --recursive --skip migrations docs + lint: tox -e lint diff --git a/docs/conf.py b/docs/conf.py index b3d97f220..9e611fdd0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,6 +17,7 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. #sys.path.insert(0, os.path.abspath('.')) +import doctest import os DIR = os.path.dirname('__file__') @@ -306,7 +307,6 @@ def django_configure(): # -- Doctest configuration ---------------------------------------- -import doctest doctest_default_flags = (0 | doctest.DONT_ACCEPT_TRUE_FOR_1 diff --git a/rest_framework_simplejwt/authentication.py b/rest_framework_simplejwt/authentication.py index e0fa2f02e..14dedc2f4 100644 --- a/rest_framework_simplejwt/authentication.py +++ b/rest_framework_simplejwt/authentication.py @@ -1,9 +1,10 @@ +from django.contrib.auth import get_user_model from django.utils.translation import gettext_lazy as _ -from rest_framework import HTTP_HEADER_ENCODING, authentication +from rest_framework import HTTP_HEADER_ENCODING, authentication, exceptions +from rest_framework.authentication import CSRFCheck from .exceptions import AuthenticationFailed, InvalidToken, TokenError from .settings import api_settings -from .state import User AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES @@ -16,6 +17,19 @@ ) +def enforce_csrf(request): + """ + Enforce CSRF validation. + """ + check = CSRFCheck() + # populates request.META['CSRF_COOKIE'], which is used in process_view() + check.process_request(request) + reason = check.process_view(request, None, (), {}) + if reason: + # CSRF failed, bail with explicit error message + raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + + class JWTAuthentication(authentication.BaseAuthentication): """ An authentication plugin that authenticates requests through a JSON web @@ -23,18 +37,31 @@ class JWTAuthentication(authentication.BaseAuthentication): """ www_authenticate_realm = 'api' + def __init__(self): + self.user_model = get_user_model() + def authenticate(self, request): header = self.get_header(request) if header is None: - return None + if not api_settings.AUTH_COOKIE: + return None + raw_token = request.COOKIES.get(api_settings.AUTH_COOKIE) or None + else: + raw_token = self.get_raw_token(header) - raw_token = self.get_raw_token(header) if raw_token is None: return None validated_token = self.get_validated_token(raw_token) - return self.get_user(validated_token), validated_token + user = self.get_user(validated_token) + if not user or not user.is_active: + return None + + if api_settings.AUTH_COOKIE: + enforce_csrf(request) + + return user, validated_token def authenticate_header(self, request): return '{0} realm="{1}"'.format( @@ -107,8 +134,8 @@ def get_user(self, validated_token): raise InvalidToken(_('Token contained no recognizable user identification')) try: - user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id}) - except User.DoesNotExist: + user = self.user_model.objects.get(**{api_settings.USER_ID_FIELD: user_id}) + except self.user_model.DoesNotExist: raise AuthenticationFailed(_('User not found'), code='user_not_found') if not user.is_active: diff --git a/rest_framework_simplejwt/models.py b/rest_framework_simplejwt/models.py index d317788d5..c7b1c1859 100644 --- a/rest_framework_simplejwt/models.py +++ b/rest_framework_simplejwt/models.py @@ -19,11 +19,10 @@ class instead of a `User` model instance. Instances of this class act as # inactive user is_active = True - _groups = EmptyManager(auth_models.Group) - _user_permissions = EmptyManager(auth_models.Permission) - def __init__(self, token): self.token = token + self._groups = EmptyManager(auth_models.Group) + self._user_permissions = EmptyManager(auth_models.Permission) def __str__(self): return 'TokenUser {}'.format(self.id) diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index af8f4fc83..29d71dc9b 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -1,12 +1,11 @@ import importlib -from django.contrib.auth import authenticate +from django.contrib.auth import authenticate, get_user_model from django.contrib.auth.models import update_last_login from django.utils.translation import gettext_lazy as _ from rest_framework import exceptions, serializers from .settings import api_settings -from .state import User from .tokens import RefreshToken, SlidingToken, UntypedToken rule_package, user_eligible_for_login = api_settings.USER_AUTHENTICATION_RULE.rsplit('.', 1) @@ -24,7 +23,7 @@ def __init__(self, *args, **kwargs): class TokenObtainSerializer(serializers.Serializer): - username_field = User.USERNAME_FIELD + username_field = get_user_model().USERNAME_FIELD default_error_messages = { 'no_active_account': _('No active account found with the given credentials') @@ -70,9 +69,10 @@ def validate(self, attrs): data = super().validate(attrs) refresh = self.get_token(self.user) + access = refresh.access_token + data['access'] = str(access) data['refresh'] = str(refresh) - data['access'] = str(refresh.access_token) if api_settings.UPDATE_LAST_LOGIN: update_last_login(None, self.user) @@ -104,7 +104,8 @@ class TokenRefreshSerializer(serializers.Serializer): def validate(self, attrs): refresh = RefreshToken(attrs['refresh']) - data = {'access': str(refresh.access_token)} + access = refresh.access_token + data = {'access': str(access)} if api_settings.ROTATE_REFRESH_TOKENS: if api_settings.BLACKLIST_AFTER_ROTATION: diff --git a/rest_framework_simplejwt/settings.py b/rest_framework_simplejwt/settings.py index 7d91f7d4e..6b951d947 100644 --- a/rest_framework_simplejwt/settings.py +++ b/rest_framework_simplejwt/settings.py @@ -37,8 +37,23 @@ 'SLIDING_TOKEN_REFRESH_EXP_CLAIM': 'refresh_exp', 'SLIDING_TOKEN_LIFETIME': timedelta(minutes=5), 'SLIDING_TOKEN_REFRESH_LIFETIME': timedelta(days=1), + + # Cookie name. Enables cookies if value is set. + 'AUTH_COOKIE': None, + # A string like "example.com", or None for standard domain cookie. + 'AUTH_COOKIE_DOMAIN': settings.CSRF_COOKIE_DOMAIN, + # Whether the auth cookies should be secure (https:// only). + 'AUTH_COOKIE_SECURE': settings.CSRF_COOKIE_SECURE, + # The path of the auth cookie. + 'AUTH_COOKIE_PATH': settings.CSRF_COOKIE_PATH, } +# Whether to set the flag restricting cookie leaks on cross-site requests. +# This can be 'Lax', 'Strict', or None to disable the flag. 'None' is supported in version 3.1 only +# CSRF_COOKIE_SAMESITE was introduced in django 2.1 https://docs.djangoproject.com/en/3.1/releases/2.1/#csrf +if hasattr(settings, 'CSRF_COOKIE_SAMESITE'): + DEFAULTS['AUTH_COOKIE_SAMESITE'] = settings.CSRF_COOKIE_SAMESITE + IMPORT_STRINGS = ( 'AUTH_TOKEN_CLASSES', 'TOKEN_USER_CLASS', diff --git a/rest_framework_simplejwt/state.py b/rest_framework_simplejwt/state.py index 9a13e393a..697a79e3d 100644 --- a/rest_framework_simplejwt/state.py +++ b/rest_framework_simplejwt/state.py @@ -1,8 +1,5 @@ -from django.contrib.auth import get_user_model - from .backends import TokenBackend from .settings import api_settings -User = get_user_model() token_backend = TokenBackend(api_settings.ALGORITHM, api_settings.SIGNING_KEY, api_settings.VERIFYING_KEY, api_settings.AUDIENCE, api_settings.ISSUER) diff --git a/rest_framework_simplejwt/views.py b/rest_framework_simplejwt/views.py index fec1edcac..a87bfdafe 100644 --- a/rest_framework_simplejwt/views.py +++ b/rest_framework_simplejwt/views.py @@ -1,9 +1,19 @@ +from django.conf import settings +from django.middleware import csrf +from django.utils.translation import gettext_lazy as _ from rest_framework import generics, status +from rest_framework.exceptions import NotAuthenticated from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.views import APIView + +from rest_framework_simplejwt.settings import api_settings +from rest_framework_simplejwt.tokens import RefreshToken from . import serializers from .authentication import AUTH_HEADER_TYPES from .exceptions import InvalidToken, TokenError +from .utils import aware_utcnow, datetime_from_epoch class TokenViewBase(generics.GenericAPIView): @@ -28,10 +38,93 @@ def post(self, request, *args, **kwargs): except TokenError as e: raise InvalidToken(e.args[0]) - return Response(serializer.validated_data, status=status.HTTP_200_OK) + data = serializer.validated_data + + # Don't return the token in the response body if the auth tokens are in a httpOnly cookie + # Only return the CSRF token + if api_settings.AUTH_COOKIE: + csrf_token = csrf.get_token(self.request) + cookie_data = self.get_cookie_data() + response = Response({'csrf_token': csrf_token}, status=status.HTTP_200_OK) + return self.set_auth_cookies(response, data, cookie_data) + + return Response(data, status=status.HTTP_200_OK) + + def get_cookie_data(self): + cookie_data = { + 'expires': self.get_refresh_token_expiration(), + 'domain': api_settings.AUTH_COOKIE_DOMAIN, + 'path': api_settings.AUTH_COOKIE_PATH, + 'secure': api_settings.AUTH_COOKIE_SECURE or None, + 'httponly': True + } + # prior to django 2.1 samesite was not supported + if hasattr(api_settings, 'AUTH_COOKIE_SAMESITE'): + cookie_data['samesite'] = api_settings.AUTH_COOKIE_SAMESITE + return cookie_data + + def set_auth_cookies(self, response, data, cookie_data): + return response + + def get_refresh_token_expiration(self): + return aware_utcnow() + api_settings.REFRESH_TOKEN_LIFETIME + + +class TokenRefreshViewBase(TokenViewBase): + def extract_token_from_cookie(self, request): + return request + + def post(self, request, *args, **kwargs): + if api_settings.AUTH_COOKIE: + request = self.extract_token_from_cookie(request) + return super().post(request, *args, **kwargs) + + +class BaseTokenCookieViewMixin: + + def extract_token_from_cookie(self, request): + """Extracts token from cookie and sets it in request.data as it would be sent by the user""" + if not request.data: + token = request.COOKIES.get(self.token_refresh_cookie_name) + if not token: + raise NotAuthenticated(detail=_('Refresh cookie not set. Try to authenticate first.')) + request.data[self.token_refresh_data_key] = token + return request + + def get_refresh_token_expiration(self): + return aware_utcnow() + api_settings.REFRESH_TOKEN_LIFETIME + +class TokenCookieViewMixin(BaseTokenCookieViewMixin): + token_refresh_view_name = 'token_refresh' + token_refresh_data_key = 'refresh' -class TokenObtainPairView(TokenViewBase): + @property + def token_refresh_cookie_name(self): + return '{}_refresh'.format(api_settings.AUTH_COOKIE) + + def set_auth_cookies(self, response, data, cookie_data): + response.set_cookie( + api_settings.AUTH_COOKIE, + data['access'], + **cookie_data + ) + if 'refresh' in data: + response.set_cookie( + '{}_refresh'.format(api_settings.AUTH_COOKIE), + data['refresh'], + **{ + **cookie_data, + **{ + 'domain': api_settings.AUTH_COOKIE_DOMAIN, + 'path': reverse(self.token_refresh_view_name) + } + } + ) + return response + + +class TokenObtainPairView(TokenCookieViewMixin, TokenViewBase): """ Takes a set of user credentials and returns an access and refresh JSON web token pair to prove the authentication of those credentials. @@ -42,18 +135,40 @@ class TokenObtainPairView(TokenViewBase): token_obtain_pair = TokenObtainPairView.as_view() -class TokenRefreshView(TokenViewBase): +class TokenRefreshView(TokenCookieViewMixin, TokenRefreshViewBase): """ Takes a refresh type JSON web token and returns an access type JSON web token if the refresh token is valid. """ serializer_class = serializers.TokenRefreshSerializer + def get_refresh_token_expiration(self): + if api_settings.ROTATE_REFRESH_TOKENS: + return super().get_refresh_token_expiration() + token = RefreshToken(self.request.data['refresh']) + return datetime_from_epoch(token.payload['exp']) + token_refresh = TokenRefreshView.as_view() -class TokenObtainSlidingView(TokenViewBase): +class SlidingTokenCookieViewMixin(BaseTokenCookieViewMixin): + token_refresh_data_key = 'token' + + @property + def token_refresh_cookie_name(self): + return api_settings.AUTH_COOKIE + + def set_auth_cookies(self, response, data, cookie_data): + response.set_cookie( + api_settings.AUTH_COOKIE, + data['token'], + **cookie_data + ) + return response + + +class TokenObtainSlidingView(SlidingTokenCookieViewMixin, TokenViewBase): """ Takes a set of user credentials and returns a sliding JSON web token to prove the authentication of those credentials. @@ -64,7 +179,7 @@ class TokenObtainSlidingView(TokenViewBase): token_obtain_sliding = TokenObtainSlidingView.as_view() -class TokenRefreshSlidingView(TokenViewBase): +class TokenRefreshSlidingView(SlidingTokenCookieViewMixin, TokenRefreshViewBase): """ Takes a sliding JSON web token and returns a new, refreshed version if the token's refresh period has not expired. @@ -84,3 +199,44 @@ class TokenVerifyView(TokenViewBase): token_verify = TokenVerifyView.as_view() + + +class TokenCookieDeleteView(APIView): + """ + Deletes httpOnly auth cookies. + Used as logout view while using AUTH_COOKIE + """ + token_refresh_view_name = 'token_refresh' + authentication_classes = () + permission_classes = () + + def post(self, request): + response = Response() + + if api_settings.AUTH_COOKIE: + self.delete_auth_cookies(response) + self.delete_csrf_cookie(response) + + return response + + def delete_auth_cookies(self, response): + response.delete_cookie( + api_settings.AUTH_COOKIE, + domain=api_settings.AUTH_COOKIE_DOMAIN, + path=api_settings.AUTH_COOKIE_PATH + ) + response.delete_cookie( + '{}_refresh'.format(api_settings.AUTH_COOKIE), + domain=api_settings.AUTH_COOKIE_DOMAIN, + path=reverse(self.token_refresh_view_name), + ) + + def delete_csrf_cookie(self, response): + response.delete_cookie( + settings.CSRF_COOKIE_NAME, + domain=api_settings.AUTH_COOKIE_DOMAIN, + path=api_settings.AUTH_COOKIE_PATH + ) + + +token_delete = TokenCookieDeleteView.as_view() diff --git a/setup.py b/setup.py index f5f8b5edd..a18e67199 100755 --- a/setup.py +++ b/setup.py @@ -1,8 +1,5 @@ #!/usr/bin/env python -from setuptools import ( - setup, - find_packages, -) +from setuptools import find_packages, setup extras_require = { 'test': [ diff --git a/tests/test_integration.py b/tests/test_integration.py index 7d2db2edc..7f00c425d 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,8 +1,10 @@ from datetime import timedelta +from django.conf import settings +from django.contrib.auth import get_user_model + from rest_framework_simplejwt.compat import reverse from rest_framework_simplejwt.settings import api_settings -from rest_framework_simplejwt.state import User from rest_framework_simplejwt.tokens import AccessToken from .utils import APIViewTestCase, override_api_settings @@ -14,8 +16,8 @@ class TestTestView(APIViewTestCase): def setUp(self): self.username = 'test_user' self.password = 'test_password' - - self.user = User.objects.create_user( + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( username=self.username, password=self.password, ) @@ -30,7 +32,7 @@ def test_wrong_auth_type(self): res = self.client.post( reverse('token_obtain_sliding'), data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }, ) @@ -50,7 +52,7 @@ def test_expired_token(self): res = self.client.post( reverse('token_obtain_pair'), data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }, ) @@ -70,7 +72,7 @@ def test_user_can_get_sliding_token_and_use_it(self): res = self.client.post( reverse('token_obtain_sliding'), data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }, ) @@ -88,7 +90,7 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self): res = self.client.post( reverse('token_obtain_pair'), data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }, ) @@ -118,3 +120,164 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self): self.assertEqual(res.status_code, 200) self.assertEqual(res.data['foo'], 'bar') + + +class TestTestViewWithCookie(APIViewTestCase): + + view_name = 'test_view' + + def setUp(self): + self.username = 'test_user' + self.password = 'test_password' + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( + username=self.username, + password=self.password, + ) + self.client = self.client_class(enforce_csrf_checks=True) + + def test_no_authorization_with_auth_cookie(self): + auth_cookie_name = 'authorization' + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.AccessToken',)): + res = self.view_get() + + self.assertEqual(res.status_code, 401) + self.assertIn('credentials were not provided', res.data['detail']) + + def test_user_can_get_access_refresh_and_delete_sliding_token_and_use_them_with_auth_cookie(self): + auth_cookie_name = 'authorization' + auth_refresh_cookie_name = '%s_refresh' % auth_cookie_name + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.SlidingToken',)): + res = self.client.post( + reverse('token_obtain_sliding'), + data={ + self.user_model.USERNAME_FIELD: self.username, + 'password': self.password, + }, + ) + self.assertNotIn('access', res.data) + self.assertGreater(len(res.cookies.get(auth_cookie_name).value), 0) + # Sliding tokens don't have a refresh token, it's a splippery slope if you ask me + self.assertIsNone(res.cookies.get(auth_refresh_cookie_name)) + self.assertEqual(res.status_code, 200) + csrf_token = res.data['csrf_token'] + + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.SlidingToken',)): + # Get on test view, this should work + res = self.view_get() + self.assertEqual(res.status_code, 200) + self.assertEqual(res.data['foo'], 'bar') + + # Refresh the token + with override_api_settings(AUTH_COOKIE=auth_cookie_name): + res = self.client.post(reverse('token_refresh_sliding')) + self.assertEqual(res.status_code, 200) + self.assertGreater(len(res.cookies.get(auth_cookie_name).value), 0) + self.assertIsNone(res.cookies.get(auth_refresh_cookie_name)) + + # Get again + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.SlidingToken',)): + res = self.view_get() + self.assertEqual(res.status_code, 200) + self.assertEqual(res.data['foo'], 'bar') + + # Try a post without CSRF + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.SlidingToken',)): + res = self.view_post(data={}) + self.assertEqual(res.status_code, 403) + + # Add CSRF + self.client.credentials(HTTP_X_CSRFTOKEN=csrf_token) + self.client.cookies[settings.CSRF_COOKIE_NAME] = csrf_token + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.SlidingToken',)): + res = self.view_post(data={}) + + self.assertEqual(res.status_code, 200) + + # Delete cookies + with override_api_settings(AUTH_COOKIE=auth_cookie_name): + res = self.client.post(reverse('token_delete')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.cookies.get(auth_cookie_name).value, '') + self.assertEqual(res.cookies.get(auth_refresh_cookie_name).value, '') + self.assertEqual(res.cookies.get(settings.CSRF_COOKIE_NAME).value, '') + + def test_user_can_get_access_refresh_and_delete_tokens_and_use_them_with_auth_cookie(self): + auth_cookie_name = 'authorization' + auth_refresh_cookie_name = '%s_refresh' % auth_cookie_name + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.SlidingToken',)): + res = self.client.post( + reverse('token_obtain_pair'), + data={ + self.user_model.USERNAME_FIELD: self.username, + 'password': self.password, + }, + ) + + # There is no reason to have the tokens in the response body we set them in the cookie + self.assertNotIn('access', res.data) + self.assertNotIn('refresh', res.data) + self.assertIsNotNone(res.data['csrf_token']) + # Make sure set cookie is called + self.assertGreater(len(res.cookies.get(auth_cookie_name).value), 0) + self.assertGreater(len(res.cookies.get(auth_refresh_cookie_name).value), 0) + # Get the csrf token + csrf_token = res.data['csrf_token'] + + # Get on test view + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.AccessToken',)): + res = self.view_get() + + self.assertEqual(res.status_code, 200) + self.assertEqual(res.data['foo'], 'bar') + + # Refresh the token + with override_api_settings(AUTH_COOKIE=auth_cookie_name, ROTATE_REFRESH_TOKENS=False): + res = self.client.post(reverse('token_refresh')) + self.assertEqual(res.status_code, 200) + # Make sure we only update the access token + self.assertGreater(len(res.cookies.get(auth_cookie_name).value), 0) + self.assertIsNone(res.cookies.get(auth_refresh_cookie_name)) + self.assertNotIn('access', res.data) + self.assertNotIn('refresh', res.data) + self.assertIn('csrf_token', res.data) + + # Now refresh token with rotation enabled + with override_api_settings(AUTH_COOKIE=auth_cookie_name, ROTATE_REFRESH_TOKENS=True): + res = self.client.post(reverse('token_refresh')) + + self.assertEqual(res.status_code, 200) + # Make sure both tokens are updated + self.assertGreater(len(res.cookies.get(auth_cookie_name).value), 0) + self.assertGreater(len(res.cookies.get(auth_refresh_cookie_name).value), 0) + self.assertNotIn('access', res.data) + self.assertNotIn('refresh', res.data) + self.assertIn('csrf_token', res.data) + + # Get on test view again and test that it stills work after a refresh + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.AccessToken',)): + res = self.view_get() + + self.assertEqual(res.status_code, 200) + self.assertEqual(res.data['foo'], 'bar') + + # Try to post, it should fail because CSRF token is not in the header + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.AccessToken',)): + res = self.view_post(data={}) + self.assertEqual(res.status_code, 403) + + # Add CSRF + self.client.credentials(HTTP_X_CSRFTOKEN=csrf_token) + self.client.cookies[settings.CSRF_COOKIE_NAME] = csrf_token + with override_api_settings(AUTH_COOKIE=auth_cookie_name, AUTH_TOKEN_CLASSES=('rest_framework_simplejwt.tokens.AccessToken',)): + res = self.view_post(data={}) + + self.assertEqual(res.status_code, 200) + + # Delete cookies + with override_api_settings(AUTH_COOKIE=auth_cookie_name): + res = self.client.post(reverse('token_delete')) + self.assertEqual(res.status_code, 200) + self.assertEqual(res.cookies.get(auth_cookie_name).value, '') + self.assertEqual(res.cookies.get(auth_refresh_cookie_name).value, '') + self.assertEqual(res.cookies.get(settings.CSRF_COOKIE_NAME).value, '') diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 7cc2326d3..5769bf8bb 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -1,6 +1,7 @@ from datetime import timedelta from unittest.mock import MagicMock, patch +from django.contrib.auth import get_user_model from django.test import TestCase from rest_framework import exceptions as drf_exceptions @@ -11,7 +12,6 @@ TokenRefreshSlidingSerializer, TokenVerifySerializer, ) from rest_framework_simplejwt.settings import api_settings -from rest_framework_simplejwt.state import User from rest_framework_simplejwt.token_blacklist.models import ( BlacklistedToken, OutstandingToken, ) @@ -30,7 +30,7 @@ def setUp(self): self.username = 'test_user' self.password = 'test_password' - self.user = User.objects.create_user( + self.user = get_user_model().objects.create_user( username=self.username, password=self.password, ) @@ -80,7 +80,7 @@ def setUp(self): self.username = 'test_user' self.password = 'test_password' - self.user = User.objects.create_user( + self.user = get_user_model().objects.create_user( username=self.username, password=self.password, ) @@ -105,7 +105,7 @@ def setUp(self): self.username = 'test_user' self.password = 'test_password' - self.user = User.objects.create_user( + self.user = get_user_model().objects.create_user( username=self.username, password=self.password, ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 8d587224d..38a28f3d9 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,12 +1,12 @@ from datetime import datetime, timedelta from unittest.mock import patch +from django.contrib.auth import get_user_model from django.test import TestCase from jose import jwt from rest_framework_simplejwt.exceptions import TokenError from rest_framework_simplejwt.settings import api_settings -from rest_framework_simplejwt.state import User from rest_framework_simplejwt.tokens import ( AccessToken, RefreshToken, SlidingToken, Token, UntypedToken, ) @@ -280,7 +280,7 @@ def test_check_exp(self): def test_for_user(self): username = 'test_user' - user = User.objects.create_user( + user = get_user_model().objects.create_user( username=username, password='test_password', ) diff --git a/tests/test_views.py b/tests/test_views.py index 8b2a182a6..e7b03ef39 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -2,11 +2,11 @@ from importlib import reload from unittest.mock import patch +from django.contrib.auth import get_user_model from django.utils import timezone from rest_framework_simplejwt import serializers from rest_framework_simplejwt.settings import api_settings -from rest_framework_simplejwt.state import User from rest_framework_simplejwt.tokens import ( AccessToken, RefreshToken, SlidingToken, ) @@ -23,8 +23,8 @@ class TestTokenObtainPairView(APIViewTestCase): def setUp(self): self.username = 'test_user' self.password = 'test_password' - - self.user = User.objects.create_user( + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( username=self.username, password=self.password, ) @@ -32,20 +32,20 @@ def setUp(self): def test_fields_missing(self): res = self.view_post(data={}) self.assertEqual(res.status_code, 400) - self.assertIn(User.USERNAME_FIELD, res.data) + self.assertIn(self.user_model.USERNAME_FIELD, res.data) self.assertIn('password', res.data) - res = self.view_post(data={User.USERNAME_FIELD: self.username}) + res = self.view_post(data={self.user_model.USERNAME_FIELD: self.username}) self.assertEqual(res.status_code, 400) self.assertIn('password', res.data) res = self.view_post(data={'password': self.password}) self.assertEqual(res.status_code, 400) - self.assertIn(User.USERNAME_FIELD, res.data) + self.assertIn(self.user_model.USERNAME_FIELD, res.data) def test_credentials_wrong(self): res = self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': 'test_user', }) self.assertEqual(res.status_code, 401) @@ -56,7 +56,7 @@ def test_user_inactive(self): self.user.save() res = self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) self.assertEqual(res.status_code, 401) @@ -64,7 +64,7 @@ def test_user_inactive(self): def test_success(self): res = self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) self.assertEqual(res.status_code, 200) @@ -73,22 +73,22 @@ def test_success(self): def test_update_last_login(self): self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) # verify last_login is not updated - user = User.objects.get(username=self.username) + user = self.user_model.objects.get(username=self.username) self.assertEqual(user.last_login, None) # verify last_login is updated with override_api_settings(UPDATE_LAST_LOGIN=True): reload(serializers) self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) - user = User.objects.get(username=self.username) + user = self.user_model.objects.get(username=self.username) self.assertIsNotNone(user.last_login) self.assertGreaterEqual(timezone.now(), user.last_login) @@ -101,8 +101,8 @@ class TestTokenRefreshView(APIViewTestCase): def setUp(self): self.username = 'test_user' self.password = 'test_password' - - self.user = User.objects.create_user( + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( username=self.username, password=self.password, ) @@ -152,8 +152,8 @@ class TestTokenObtainSlidingView(APIViewTestCase): def setUp(self): self.username = 'test_user' self.password = 'test_password' - - self.user = User.objects.create_user( + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( username=self.username, password=self.password, ) @@ -161,20 +161,20 @@ def setUp(self): def test_fields_missing(self): res = self.view_post(data={}) self.assertEqual(res.status_code, 400) - self.assertIn(User.USERNAME_FIELD, res.data) + self.assertIn(self.user_model.USERNAME_FIELD, res.data) self.assertIn('password', res.data) - res = self.view_post(data={User.USERNAME_FIELD: self.username}) + res = self.view_post(data={self.user_model.USERNAME_FIELD: self.username}) self.assertEqual(res.status_code, 400) self.assertIn('password', res.data) res = self.view_post(data={'password': self.password}) self.assertEqual(res.status_code, 400) - self.assertIn(User.USERNAME_FIELD, res.data) + self.assertIn(self.user_model.USERNAME_FIELD, res.data) def test_credentials_wrong(self): res = self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': 'test_user', }) self.assertEqual(res.status_code, 401) @@ -185,7 +185,7 @@ def test_user_inactive(self): self.user.save() res = self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) self.assertEqual(res.status_code, 401) @@ -193,7 +193,7 @@ def test_user_inactive(self): def test_success(self): res = self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) self.assertEqual(res.status_code, 200) @@ -201,22 +201,22 @@ def test_success(self): def test_update_last_login(self): self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) # verify last_login is not updated - user = User.objects.get(username=self.username) + user = self.user_model.objects.get(username=self.username) self.assertEqual(user.last_login, None) # verify last_login is updated with override_api_settings(UPDATE_LAST_LOGIN=True): reload(serializers) self.view_post(data={ - User.USERNAME_FIELD: self.username, + self.user_model.USERNAME_FIELD: self.username, 'password': self.password, }) - user = User.objects.get(username=self.username) + user = self.user_model.objects.get(username=self.username) self.assertIsNotNone(user.last_login) self.assertGreaterEqual(timezone.now(), user.last_login) @@ -229,8 +229,8 @@ class TestTokenRefreshSlidingView(APIViewTestCase): def setUp(self): self.username = 'test_user' self.password = 'test_password' - - self.user = User.objects.create_user( + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( username=self.username, password=self.password, ) @@ -294,8 +294,8 @@ class TestTokenVerifyView(APIViewTestCase): def setUp(self): self.username = 'test_user' self.password = 'test_password' - - self.user = User.objects.create_user( + self.user_model = get_user_model() + self.user = self.user_model.objects.create_user( username=self.username, password=self.password, ) @@ -321,7 +321,6 @@ def test_it_should_return_401_if_token_invalid(self): def test_it_should_return_200_if_everything_okay(self): token = RefreshToken() - res = self.view_post(data={'token': str(token)}) self.assertEqual(res.status_code, 200) self.assertEqual(len(res.data), 0) diff --git a/tests/urls.py b/tests/urls.py index 14e8a0b99..9633a600a 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -7,6 +7,7 @@ urlpatterns = [ re_path(r'^token/pair/$', jwt_views.token_obtain_pair, name='token_obtain_pair'), re_path(r'^token/refresh/$', jwt_views.token_refresh, name='token_refresh'), + re_path(r'^token/delete/$', jwt_views.token_delete, name='token_delete'), re_path(r'^token/sliding/$', jwt_views.token_obtain_sliding, name='token_obtain_sliding'), re_path(r'^token/sliding/refresh/$', jwt_views.token_refresh_sliding, name='token_refresh_sliding'), diff --git a/tests/utils.py b/tests/utils.py index ec6c7ea9f..768f53d8b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,7 @@ def authenticate_with_token(self, type, token): view_get = client_action_wrapper('get') +# Don't nest contexts it won't work!!! @contextlib.contextmanager def override_api_settings(**settings): old_settings = {} diff --git a/tests/views.py b/tests/views.py index c8a85ced6..54f951657 100644 --- a/tests/views.py +++ b/tests/views.py @@ -12,5 +12,8 @@ class TestView(APIView): def get(self, request): return Response({'foo': 'bar'}) + def post(self, request): + return Response({'foo': 'bar'}) + test_view = TestView.as_view()