|
1 | 1 | from datetime import timedelta |
| 2 | +from importlib import reload |
| 3 | +from freezegun import freeze_time |
2 | 4 |
|
3 | 5 | from django.contrib.auth import get_user_model |
4 | 6 | from django.urls import reverse |
5 | 7 | from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED |
6 | 8 |
|
7 | 9 | from rest_framework_simplejwt.settings import api_settings |
8 | | -from rest_framework_simplejwt.tokens import AccessToken |
| 10 | +from rest_framework_simplejwt.tokens import AccessToken, RefreshToken |
| 11 | +from rest_framework_simplejwt.utils import aware_utcnow |
9 | 12 |
|
10 | 13 | from .utils import APIViewTestCase, override_api_settings |
11 | 14 |
|
@@ -127,3 +130,124 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self): |
127 | 130 |
|
128 | 131 | self.assertEqual(res.status_code, HTTP_200_OK) |
129 | 132 | self.assertEqual(res.data["foo"], "bar") |
| 133 | + |
| 134 | + @override_api_settings( |
| 135 | + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), |
| 136 | + TOKEN_FAMILY_CHECK_ON_ACCESS=True, |
| 137 | + ) |
| 138 | + def test_access_token_performs_family_blacklist_check_when_enabled(self): |
| 139 | + res = self.client.post( |
| 140 | + reverse("token_obtain_pair"), |
| 141 | + data={ |
| 142 | + User.USERNAME_FIELD: self.username, |
| 143 | + "password": self.password, |
| 144 | + }, |
| 145 | + ) |
| 146 | + |
| 147 | + access = res.data["access"] |
| 148 | + refresh = res.data["refresh"] |
| 149 | + |
| 150 | + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) |
| 151 | + |
| 152 | + res = self.view_get() |
| 153 | + |
| 154 | + self.assertEqual(res.status_code, HTTP_200_OK) |
| 155 | + self.assertEqual(res.data["foo"], "bar") |
| 156 | + |
| 157 | + RefreshToken(str(refresh)).blacklist_family() |
| 158 | + |
| 159 | + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) |
| 160 | + |
| 161 | + res = self.view_get() |
| 162 | + |
| 163 | + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) |
| 164 | + self.assertEqual("token_not_valid", res.data["code"]) |
| 165 | + |
| 166 | + error_msg = res.data.get("messages")[0].get("message") |
| 167 | + self.assertIn("family", error_msg) |
| 168 | + self.assertIn("blacklisted", error_msg) |
| 169 | + |
| 170 | + @override_api_settings( |
| 171 | + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), |
| 172 | + TOKEN_FAMILY_CHECK_ON_ACCESS=True, |
| 173 | + # We use a smaller family lifetime than the default access token lifetime |
| 174 | + # so we dont have to reload the tokens and serializers modules |
| 175 | + TOKEN_FAMILY_LIFETIME=timedelta(minutes=2), |
| 176 | + ) |
| 177 | + def test_access_token_performs_family_expiration_check_when_enabled(self): |
| 178 | + |
| 179 | + with freeze_time(aware_utcnow() - timedelta(minutes=2)): |
| 180 | + res = self.client.post( |
| 181 | + reverse("token_obtain_pair"), |
| 182 | + data={ |
| 183 | + User.USERNAME_FIELD: self.username, |
| 184 | + "password": self.password, |
| 185 | + }, |
| 186 | + ) |
| 187 | + |
| 188 | + access = res.data["access"] |
| 189 | + |
| 190 | + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) |
| 191 | + |
| 192 | + res = self.view_get() |
| 193 | + |
| 194 | + self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) |
| 195 | + self.assertEqual("token_not_valid", res.data["code"]) |
| 196 | + |
| 197 | + error_msg = res.data.get("messages")[0].get("message") |
| 198 | + self.assertIn("family", error_msg) |
| 199 | + self.assertIn("expired", error_msg) |
| 200 | + |
| 201 | + @override_api_settings( |
| 202 | + AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",), |
| 203 | + TOKEN_FAMILY_CHECK_ON_ACCESS=False, |
| 204 | + # We use a smaller family lifetime than the default access token lifetime |
| 205 | + # so we dont have to reload the tokens and serializers modules |
| 206 | + TOKEN_FAMILY_LIFETIME=timedelta(minutes=2), |
| 207 | + ) |
| 208 | + def test_access_token_does_not_performs_family_checks_when_disabled(self): |
| 209 | + res = self.client.post( |
| 210 | + reverse("token_obtain_pair"), |
| 211 | + data={ |
| 212 | + User.USERNAME_FIELD: self.username, |
| 213 | + "password": self.password, |
| 214 | + }, |
| 215 | + ) |
| 216 | + |
| 217 | + access = res.data["access"] |
| 218 | + refresh = res.data["refresh"] |
| 219 | + |
| 220 | + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) |
| 221 | + res = self.view_get() |
| 222 | + |
| 223 | + self.assertEqual(res.status_code, HTTP_200_OK) |
| 224 | + self.assertEqual(res.data["foo"], "bar") |
| 225 | + |
| 226 | + # blacklisting the token family |
| 227 | + RefreshToken(str(refresh)).blacklist_family() |
| 228 | + |
| 229 | + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) |
| 230 | + res = self.view_get() |
| 231 | + |
| 232 | + # response must be 200_OK since the family check for the access token is disabled |
| 233 | + self.assertEqual(res.status_code, HTTP_200_OK) |
| 234 | + self.assertEqual(res.data["foo"], "bar") |
| 235 | + |
| 236 | + # testing for family expiration now |
| 237 | + with freeze_time(aware_utcnow() - timedelta(minutes=2)): |
| 238 | + res = self.client.post( |
| 239 | + reverse("token_obtain_pair"), |
| 240 | + data={ |
| 241 | + User.USERNAME_FIELD: self.username, |
| 242 | + "password": self.password, |
| 243 | + }, |
| 244 | + ) |
| 245 | + |
| 246 | + access = res.data["access"] |
| 247 | + |
| 248 | + self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access) |
| 249 | + res = self.view_get() |
| 250 | + |
| 251 | + # response must be 200_OK since the family check for the access token is disabled |
| 252 | + self.assertEqual(res.status_code, HTTP_200_OK) |
| 253 | + self.assertEqual(res.data["foo"], "bar") |
0 commit comments