Skip to content

Commit 9b06293

Browse files
hodossyHodossy, SzabolcsAndrew-Chen-Wang
authored
Add blacklist view to log out users (jazzband#306)
* feat: add blacklist view to log out users * fix: import order * test: check if blacklisted tokens cannot be used Co-authored-by: Hodossy, Szabolcs <[email protected]> Co-authored-by: Andrew Chen Wang <[email protected]>
1 parent c9e989e commit 9b06293

File tree

5 files changed

+168
-1
lines changed

5 files changed

+168
-1
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,15 @@ def validate(self, attrs):
151151
raise ValidationError("Token is blacklisted")
152152

153153
return {}
154+
155+
156+
class TokenBlacklistSerializer(serializers.Serializer):
157+
refresh = serializers.CharField()
158+
159+
def validate(self, attrs):
160+
refresh = RefreshToken(attrs['refresh'])
161+
try:
162+
refresh.blacklist()
163+
except AttributeError:
164+
pass
165+
return {}

rest_framework_simplejwt/views.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,14 @@ class TokenVerifyView(TokenViewBase):
8484

8585

8686
token_verify = TokenVerifyView.as_view()
87+
88+
89+
class TokenBlacklistView(TokenViewBase):
90+
"""
91+
Takes a token and blacklists it. Must be used with the
92+
`rest_framework_simplejwt.token_blacklist` app installed.
93+
"""
94+
serializer_class = serializers.TokenBlacklistSerializer
95+
96+
97+
token_blacklist = TokenBlacklistView.as_view()

tests/test_serializers.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from rest_framework_simplejwt.exceptions import TokenError
99
from rest_framework_simplejwt.serializers import (
10-
TokenObtainPairSerializer, TokenObtainSerializer,
10+
TokenBlacklistSerializer, TokenObtainPairSerializer, TokenObtainSerializer,
1111
TokenObtainSlidingSerializer, TokenRefreshSerializer,
1212
TokenRefreshSlidingSerializer, TokenVerifySerializer,
1313
)
@@ -365,3 +365,76 @@ def test_it_should_return_given_token_if_everything_ok(self):
365365
self.assertTrue(s.is_valid())
366366

367367
self.assertEqual(len(s.validated_data), 0)
368+
369+
370+
class TestTokenBlacklistSerializer(TestCase):
371+
def test_it_should_raise_token_error_if_token_invalid(self):
372+
token = RefreshToken()
373+
del token['exp']
374+
375+
s = TokenBlacklistSerializer(data={'refresh': str(token)})
376+
377+
with self.assertRaises(TokenError) as e:
378+
s.is_valid()
379+
380+
self.assertIn("has no 'exp' claim", e.exception.args[0])
381+
382+
token.set_exp(lifetime=-timedelta(days=1))
383+
384+
s = TokenBlacklistSerializer(data={'refresh': str(token)})
385+
386+
with self.assertRaises(TokenError) as e:
387+
s.is_valid()
388+
389+
self.assertIn('invalid or expired', e.exception.args[0])
390+
391+
def test_it_should_raise_token_error_if_token_has_wrong_type(self):
392+
token = RefreshToken()
393+
token[api_settings.TOKEN_TYPE_CLAIM] = 'wrong_type'
394+
395+
s = TokenBlacklistSerializer(data={'refresh': str(token)})
396+
397+
with self.assertRaises(TokenError) as e:
398+
s.is_valid()
399+
400+
self.assertIn("wrong type", e.exception.args[0])
401+
402+
def test_it_should_return_nothing_if_everything_ok(self):
403+
refresh = RefreshToken()
404+
refresh['test_claim'] = 'arst'
405+
406+
# Serializer validates
407+
s = TokenBlacklistSerializer(data={'refresh': str(refresh)})
408+
409+
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
410+
411+
with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
412+
fake_aware_utcnow.return_value = now
413+
self.assertTrue(s.is_valid())
414+
415+
self.assertDictEqual(s.validated_data, {})
416+
417+
def test_it_should_blacklist_refresh_token_if_everything_ok(self):
418+
self.assertEqual(OutstandingToken.objects.count(), 0)
419+
self.assertEqual(BlacklistedToken.objects.count(), 0)
420+
421+
refresh = RefreshToken()
422+
423+
refresh['test_claim'] = 'arst'
424+
425+
old_jti = refresh['jti']
426+
427+
# Serializer validates
428+
ser = TokenBlacklistSerializer(data={'refresh': str(refresh)})
429+
430+
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
431+
432+
with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
433+
fake_aware_utcnow.return_value = now
434+
self.assertTrue(ser.is_valid())
435+
436+
self.assertEqual(OutstandingToken.objects.count(), 1)
437+
self.assertEqual(BlacklistedToken.objects.count(), 1)
438+
439+
# Assert old refresh token is blacklisted
440+
self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti)

tests/test_views.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,72 @@ def test_it_should_ignore_token_type(self):
335335
res = self.view_post(data={'token': str(token)})
336336
self.assertEqual(res.status_code, 200)
337337
self.assertEqual(len(res.data), 0)
338+
339+
340+
class TestTokenBlacklistView(APIViewTestCase):
341+
view_name = 'token_blacklist'
342+
343+
def setUp(self):
344+
self.username = 'test_user'
345+
self.password = 'test_password'
346+
347+
self.user = User.objects.create_user(
348+
username=self.username,
349+
password=self.password,
350+
)
351+
352+
def test_fields_missing(self):
353+
res = self.view_post(data={})
354+
self.assertEqual(res.status_code, 400)
355+
self.assertIn('refresh', res.data)
356+
357+
def test_it_should_return_401_if_token_invalid(self):
358+
token = RefreshToken()
359+
del token['exp']
360+
361+
res = self.view_post(data={'refresh': str(token)})
362+
self.assertEqual(res.status_code, 401)
363+
self.assertEqual(res.data['code'], 'token_not_valid')
364+
365+
token.set_exp(lifetime=-timedelta(seconds=1))
366+
367+
res = self.view_post(data={'refresh': str(token)})
368+
self.assertEqual(res.status_code, 401)
369+
self.assertEqual(res.data['code'], 'token_not_valid')
370+
371+
def test_it_should_return_if_everything_ok(self):
372+
refresh = RefreshToken()
373+
refresh['test_claim'] = 'arst'
374+
375+
# View returns 200
376+
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
377+
378+
with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
379+
fake_aware_utcnow.return_value = now
380+
381+
res = self.view_post(data={'refresh': str(refresh)})
382+
383+
self.assertEqual(res.status_code, 200)
384+
385+
self.assertDictEqual(res.data, {})
386+
387+
def test_it_should_return_401_if_token_is_blacklisted(self):
388+
refresh = RefreshToken()
389+
refresh['test_claim'] = 'arst'
390+
391+
# View returns 200
392+
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
393+
394+
with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
395+
fake_aware_utcnow.return_value = now
396+
397+
res = self.view_post(data={'refresh': str(refresh)})
398+
399+
self.assertEqual(res.status_code, 200)
400+
401+
self.view_name = 'token_refresh'
402+
res = self.view_post(data={'refresh': str(refresh)})
403+
# make sure other tests are not affected
404+
del self.view_name
405+
406+
self.assertEqual(res.status_code, 401)

tests/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@
1313

1414
re_path(r'^token/verify/$', jwt_views.token_verify, name='token_verify'),
1515

16+
re_path(r'^token/blacklist/$', jwt_views.token_blacklist, name='token_blacklist'),
17+
1618
re_path(r'^test-view/$', views.test_view, name='test_view'),
1719
]

0 commit comments

Comments
 (0)