Skip to content

Commit d2cd59d

Browse files
Improve testing (#688)
* Support `override_api_settings` as decorator * Update test_authentication * black formatting test_authentication * Use drf status instead of literal status * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_integration * Update test_serializers * Update test_integration * Update test_token_blacklist * Update test_tokens * Update test_views * add `setUpTestData` to `TestToken` * fix typo `self` should be `cls` --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c65036c commit d2cd59d

File tree

7 files changed

+154
-155
lines changed

7 files changed

+154
-155
lines changed

tests/test_authentication.py

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,19 @@ def test_get_header(self):
4040
)
4141
self.assertEqual(self.backend.get_header(request), self.fake_header)
4242

43-
# Should work with the x_access_token
44-
with override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN"):
45-
# Should pull correct header off request when using X_ACCESS_TOKEN
46-
request = self.factory.get(
47-
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header
48-
)
49-
self.assertEqual(self.backend.get_header(request), self.fake_header)
50-
51-
# Should work for unicode headers when using
52-
request = self.factory.get(
53-
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
54-
)
55-
self.assertEqual(self.backend.get_header(request), self.fake_header)
43+
@override_api_settings(AUTH_HEADER_NAME="HTTP_X_ACCESS_TOKEN")
44+
def test_get_header_x_access_token(self):
45+
# Should pull correct header off request when using X_ACCESS_TOKEN
46+
request = self.factory.get("/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header)
47+
self.assertEqual(self.backend.get_header(request), self.fake_header)
48+
49+
# Should work for unicode headers when using
50+
request = self.factory.get(
51+
"/test-url/", HTTP_X_ACCESS_TOKEN=self.fake_header.decode("utf-8")
52+
)
53+
self.assertEqual(self.backend.get_header(request), self.fake_header)
5654

5755
def test_get_raw_token(self):
58-
# Should return None if header lacks correct type keyword
59-
with override_api_settings(AUTH_HEADER_TYPES="JWT"):
60-
reload(authentication)
61-
self.assertIsNone(self.backend.get_raw_token(self.fake_header))
6256
reload(authentication)
6357

6458
# Should return None if an empty AUTHORIZATION header is sent
@@ -74,14 +68,21 @@ def test_get_raw_token(self):
7468
# Otherwise, should return unvalidated token in header
7569
self.assertEqual(self.backend.get_raw_token(self.fake_header), self.fake_token)
7670

71+
@override_api_settings(AUTH_HEADER_TYPES="JWT")
72+
def test_get_raw_token_incorrect_header_keyword(self):
73+
# Should return None if header lacks correct type keyword
74+
# AUTH_HEADER_TYPES is "JWT", but header is "Bearer"
75+
reload(authentication)
76+
self.assertIsNone(self.backend.get_raw_token(self.fake_header))
77+
78+
@override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer"))
79+
def test_get_raw_token_multi_header_keyword(self):
7780
# Should return token if header has one of many valid token types
78-
with override_api_settings(AUTH_HEADER_TYPES=("JWT", "Bearer")):
79-
reload(authentication)
80-
self.assertEqual(
81-
self.backend.get_raw_token(self.fake_header),
82-
self.fake_token,
83-
)
8481
reload(authentication)
82+
self.assertEqual(
83+
self.backend.get_raw_token(self.fake_header),
84+
self.fake_token,
85+
)
8586

8687
def test_get_validated_token(self):
8788
# Should raise InvalidToken if token not valid
@@ -96,36 +97,39 @@ def test_get_validated_token(self):
9697
self.backend.get_validated_token(str(token)).payload, token.payload
9798
)
9899

100+
@override_api_settings(
101+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
102+
)
103+
def test_get_validated_token_reject_unknown_token(self):
99104
# Should not accept tokens not included in AUTH_TOKEN_CLASSES
100105
sliding_token = SlidingToken()
101-
with override_api_settings(
102-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
103-
):
104-
with self.assertRaises(InvalidToken) as e:
105-
self.backend.get_validated_token(str(sliding_token))
106-
107-
messages = e.exception.detail["messages"]
108-
self.assertEqual(1, len(messages))
109-
self.assertEqual(
110-
{
111-
"token_class": "AccessToken",
112-
"token_type": "access",
113-
"message": "Token has wrong type",
114-
},
115-
messages[0],
116-
)
106+
with self.assertRaises(InvalidToken) as e:
107+
self.backend.get_validated_token(str(sliding_token))
108+
109+
messages = e.exception.detail["messages"]
110+
self.assertEqual(1, len(messages))
111+
self.assertEqual(
112+
{
113+
"token_class": "AccessToken",
114+
"token_type": "access",
115+
"message": "Token has wrong type",
116+
},
117+
messages[0],
118+
)
117119

120+
@override_api_settings(
121+
AUTH_TOKEN_CLASSES=(
122+
"rest_framework_simplejwt.tokens.AccessToken",
123+
"rest_framework_simplejwt.tokens.SlidingToken",
124+
),
125+
)
126+
def test_get_validated_token_accept_known_token(self):
118127
# Should accept tokens included in AUTH_TOKEN_CLASSES
119128
access_token = AccessToken()
120129
sliding_token = SlidingToken()
121-
with override_api_settings(
122-
AUTH_TOKEN_CLASSES=(
123-
"rest_framework_simplejwt.tokens.AccessToken",
124-
"rest_framework_simplejwt.tokens.SlidingToken",
125-
)
126-
):
127-
self.backend.get_validated_token(str(access_token))
128-
self.backend.get_validated_token(str(sliding_token))
130+
131+
self.backend.get_validated_token(str(access_token))
132+
self.backend.get_validated_token(str(sliding_token))
129133

130134
def test_get_user(self):
131135
payload = {"some_other_id": "foo"}

tests/test_integration.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from django.contrib.auth import get_user_model
44
from django.urls import reverse
5+
from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED
56

67
from rest_framework_simplejwt.settings import api_settings
78
from rest_framework_simplejwt.tokens import AccessToken
@@ -26,7 +27,7 @@ def setUp(self):
2627
def test_no_authorization(self):
2728
res = self.view_get()
2829

29-
self.assertEqual(res.status_code, 401)
30+
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
3031
self.assertIn("credentials were not provided", res.data["detail"])
3132

3233
def test_wrong_auth_type(self):
@@ -43,9 +44,12 @@ def test_wrong_auth_type(self):
4344

4445
res = self.view_get()
4546

46-
self.assertEqual(res.status_code, 401)
47+
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
4748
self.assertIn("credentials were not provided", res.data["detail"])
4849

50+
@override_api_settings(
51+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
52+
)
4953
def test_expired_token(self):
5054
old_lifetime = AccessToken.lifetime
5155
AccessToken.lifetime = timedelta(seconds=0)
@@ -63,14 +67,14 @@ def test_expired_token(self):
6367
access = res.data["access"]
6468
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)
6569

66-
with override_api_settings(
67-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
68-
):
69-
res = self.view_get()
70+
res = self.view_get()
7071

71-
self.assertEqual(res.status_code, 401)
72+
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED)
7273
self.assertEqual("token_not_valid", res.data["code"])
7374

75+
@override_api_settings(
76+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",),
77+
)
7478
def test_user_can_get_sliding_token_and_use_it(self):
7579
res = self.client.post(
7680
reverse("token_obtain_sliding"),
@@ -83,14 +87,14 @@ def test_user_can_get_sliding_token_and_use_it(self):
8387
token = res.data["token"]
8488
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], token)
8589

86-
with override_api_settings(
87-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.SlidingToken",)
88-
):
89-
res = self.view_get()
90+
res = self.view_get()
9091

91-
self.assertEqual(res.status_code, 200)
92+
self.assertEqual(res.status_code, HTTP_200_OK)
9293
self.assertEqual(res.data["foo"], "bar")
9394

95+
@override_api_settings(
96+
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",),
97+
)
9498
def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
9599
res = self.client.post(
96100
reverse("token_obtain_pair"),
@@ -105,12 +109,9 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
105109

106110
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)
107111

108-
with override_api_settings(
109-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
110-
):
111-
res = self.view_get()
112+
res = self.view_get()
112113

113-
self.assertEqual(res.status_code, 200)
114+
self.assertEqual(res.status_code, HTTP_200_OK)
114115
self.assertEqual(res.data["foo"], "bar")
115116

116117
res = self.client.post(
@@ -122,10 +123,7 @@ def test_user_can_get_access_and_refresh_tokens_and_use_them(self):
122123

123124
self.authenticate_with_token(api_settings.AUTH_HEADER_TYPES[0], access)
124125

125-
with override_api_settings(
126-
AUTH_TOKEN_CLASSES=("rest_framework_simplejwt.tokens.AccessToken",)
127-
):
128-
res = self.view_get()
126+
res = self.view_get()
129127

130-
self.assertEqual(res.status_code, 200)
128+
self.assertEqual(res.status_code, HTTP_200_OK)
131129
self.assertEqual(res.data["foo"], "bar")

tests/test_serializers.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ def test_it_should_return_access_token_if_everything_ok(self):
285285
access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
286286
)
287287

288+
@override_api_settings(
289+
ROTATE_REFRESH_TOKENS=True,
290+
BLACKLIST_AFTER_ROTATION=False,
291+
)
288292
def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
289293
refresh = RefreshToken()
290294

@@ -298,14 +302,9 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
298302

299303
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
300304

301-
with override_api_settings(
302-
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=False
303-
):
304-
with patch(
305-
"rest_framework_simplejwt.tokens.aware_utcnow"
306-
) as fake_aware_utcnow:
307-
fake_aware_utcnow.return_value = now
308-
self.assertTrue(ser.is_valid())
305+
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
306+
fake_aware_utcnow.return_value = now
307+
self.assertTrue(ser.is_valid())
309308

310309
access = AccessToken(ser.validated_data["access"])
311310
new_refresh = RefreshToken(ser.validated_data["refresh"])
@@ -324,6 +323,10 @@ def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
324323
datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME),
325324
)
326325

326+
@override_api_settings(
327+
ROTATE_REFRESH_TOKENS=True,
328+
BLACKLIST_AFTER_ROTATION=True,
329+
)
327330
def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted(
328331
self,
329332
):
@@ -342,14 +345,9 @@ def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_black
342345

343346
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
344347

345-
with override_api_settings(
346-
ROTATE_REFRESH_TOKENS=True, BLACKLIST_AFTER_ROTATION=True
347-
):
348-
with patch(
349-
"rest_framework_simplejwt.tokens.aware_utcnow"
350-
) as fake_aware_utcnow:
351-
fake_aware_utcnow.return_value = now
352-
self.assertTrue(ser.is_valid())
348+
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
349+
fake_aware_utcnow.return_value = now
350+
self.assertTrue(ser.is_valid())
353351

354352
access = AccessToken(ser.validated_data["access"])
355353
new_refresh = RefreshToken(ser.validated_data["refresh"])

tests/test_token_blacklist.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,25 +237,25 @@ def setUp(self):
237237

238238
super().setUp()
239239

240+
@override_api_settings(BLACKLIST_AFTER_ROTATION=True)
240241
def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled(
241242
self,
242243
):
243-
with override_api_settings(BLACKLIST_AFTER_ROTATION=True):
244-
refresh_token = RefreshToken.for_user(self.user)
245-
refresh_token.blacklist()
244+
refresh_token = RefreshToken.for_user(self.user)
245+
refresh_token.blacklist()
246246

247-
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
248-
self.assertFalse(serializer.is_valid())
247+
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
248+
self.assertFalse(serializer.is_valid())
249249

250+
@override_api_settings(BLACKLIST_AFTER_ROTATION=False)
250251
def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled(
251252
self,
252253
):
253-
with override_api_settings(BLACKLIST_AFTER_ROTATION=False):
254-
refresh_token = RefreshToken.for_user(self.user)
255-
refresh_token.blacklist()
254+
refresh_token = RefreshToken.for_user(self.user)
255+
refresh_token.blacklist()
256256

257-
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
258-
self.assertTrue(serializer.is_valid())
257+
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
258+
self.assertTrue(serializer.is_valid())
259259

260260

261261
class TestBigAutoFieldIDMigration(MigrationTestCase):

tests/test_tokens.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class TestToken(TestCase):
3131
def setUp(self):
3232
self.token = MyToken()
3333

34+
@classmethod
35+
def setUpTestData(cls):
36+
cls.username = "test_user"
37+
cls.user = User.objects.create_user(
38+
username=cls.username,
39+
password="test_password",
40+
)
41+
3442
def test_init_no_token_type_or_lifetime(self):
3543
class MyTestToken(Token):
3644
pass
@@ -225,14 +233,14 @@ def test_set_jti(self):
225233
self.assertIn("jti", token)
226234
self.assertNotEqual(old_jti, token["jti"])
227235

236+
@override_api_settings(JTI_CLAIM=None)
228237
def test_optional_jti(self):
229-
with override_api_settings(JTI_CLAIM=None):
230-
token = MyToken()
238+
token = MyToken()
231239
self.assertNotIn("jti", token)
232240

241+
@override_api_settings(TOKEN_TYPE_CLAIM=None)
233242
def test_optional_type_token(self):
234-
with override_api_settings(TOKEN_TYPE_CLAIM=None):
235-
token = MyToken()
243+
token = MyToken()
236244
self.assertNotIn("type", token)
237245

238246
def test_set_exp(self):
@@ -355,25 +363,19 @@ def test_check_token_if_wrong_type_leeway(self):
355363
token.token_backend.leeway = 0
356364

357365
def test_for_user(self):
358-
username = "test_user"
359-
user = User.objects.create_user(
360-
username=username,
361-
password="test_password",
362-
)
366+
token = MyToken.for_user(self.user)
363367

364-
token = MyToken.for_user(user)
365-
366-
user_id = getattr(user, api_settings.USER_ID_FIELD)
368+
user_id = getattr(self.user, api_settings.USER_ID_FIELD)
367369
if not isinstance(user_id, int):
368370
user_id = str(user_id)
369371

370372
self.assertEqual(token[api_settings.USER_ID_CLAIM], user_id)
371373

374+
@override_api_settings(USER_ID_FIELD="username")
375+
def test_for_user_with_username(self):
372376
# Test with non-int user id
373-
with override_api_settings(USER_ID_FIELD="username"):
374-
token = MyToken.for_user(user)
375-
376-
self.assertEqual(token[api_settings.USER_ID_CLAIM], username)
377+
token = MyToken.for_user(self.user)
378+
self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username)
377379

378380
def test_get_token_backend(self):
379381
token = MyToken()

0 commit comments

Comments
 (0)