Skip to content

Commit 670baff

Browse files
feat(serializers): Add full user validation to sliding token refresh
Implements the same user validation logic (active status, password change) in to ensure consistent behavior with the standard .
1 parent c1fd2ea commit 670baff

File tree

2 files changed

+120
-8
lines changed

2 files changed

+120
-8
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ class TokenRefreshSerializer(serializers.Serializer):
113113
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
114114
refresh = self.token_class(attrs["refresh"])
115115

116-
data = {"access": str(refresh.access_token)}
117-
118-
# We handle user-related validation in a single, efficient block.
119116
user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
120117
if user_id:
121118
try:
@@ -154,6 +151,7 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
154151
code="password_changed",
155152
)
156153

154+
data = {"access": str(refresh.access_token)}
157155
if api_settings.ROTATE_REFRESH_TOKENS:
158156
if api_settings.BLACKLIST_AFTER_ROTATION:
159157
try:
@@ -178,8 +176,50 @@ class TokenRefreshSlidingSerializer(serializers.Serializer):
178176
token = serializers.CharField()
179177
token_class = SlidingToken
180178

179+
default_error_messages = {
180+
"no_active_account": _("No active account found for the given token."),
181+
"password_changed": _("The user's password has been changed."),
182+
}
183+
181184
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
182185
token = self.token_class(attrs["token"])
186+
user_id = token.payload.get(api_settings.USER_ID_CLAIM, None)
187+
if user_id:
188+
try:
189+
user = get_user_model().objects.get(
190+
**{api_settings.USER_ID_FIELD: user_id}
191+
)
192+
except get_user_model().DoesNotExist:
193+
# This handles the case where the user has been deleted.
194+
raise AuthenticationFailed(
195+
self.error_messages["no_active_account"], "no_active_account"
196+
)
197+
198+
if not api_settings.USER_AUTHENTICATION_RULE(user):
199+
raise AuthenticationFailed(
200+
self.error_messages["no_active_account"], "no_active_account"
201+
)
202+
203+
if api_settings.CHECK_REVOKE_TOKEN:
204+
token_hash = token.payload.get(api_settings.REVOKE_TOKEN_CLAIM)
205+
user_hash = get_md5_hash_password(user.password)
206+
207+
if token_hash != user_hash:
208+
# If the password has changed, we blacklist the token
209+
# to prevent any further use.
210+
if (
211+
"rest_framework_simplejwt.token_blacklist"
212+
in settings.INSTALLED_APPS
213+
):
214+
try:
215+
token.blacklist()
216+
except AttributeError:
217+
pass
218+
219+
raise AuthenticationFailed(
220+
self.error_messages["password_changed"],
221+
code="password_changed",
222+
)
183223

184224
# Check that the timestamp in the "refresh_exp" claim has not
185225
# passed

tests/test_serializers.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,15 @@ def test_it_should_produce_a_json_web_token_when_valid(self):
187187

188188

189189
class TestTokenRefreshSlidingSerializer(TestCase):
190+
def setUp(self):
191+
self.username = "test_user"
192+
self.password = "test_password"
193+
194+
self.user = User.objects.create_user(
195+
username=self.username,
196+
password=self.password,
197+
)
198+
190199
def test_it_should_not_validate_if_token_invalid(self):
191200
token = SlidingToken()
192201
del token["exp"]
@@ -268,6 +277,74 @@ def test_it_should_update_token_exp_claim_if_everything_ok(self):
268277

269278
self.assertTrue(old_exp < new_exp)
270279

280+
def test_it_should_raise_error_for_deleted_users(self):
281+
token = SlidingToken.for_user(self.user)
282+
self.user.delete()
283+
284+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
285+
286+
# It should raise AuthenticationFailed instead of ObjectDoesNotExist
287+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
288+
s.is_valid()
289+
290+
self.assertEqual(e.exception.get_codes(), "no_active_account")
291+
292+
def test_it_should_raise_error_for_inactive_users(self):
293+
token = SlidingToken.for_user(self.user)
294+
self.user.is_active = False
295+
self.user.save()
296+
297+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
298+
299+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
300+
s.is_valid()
301+
302+
self.assertEqual(e.exception.get_codes(), "no_active_account")
303+
304+
@override_api_settings(
305+
CHECK_REVOKE_TOKEN=True,
306+
REVOKE_TOKEN_CLAIM="hash_password",
307+
BLACKLIST_AFTER_ROTATION=False,
308+
)
309+
def test_sliding_token_should_fail_after_password_change(self):
310+
"""
311+
Tests that sliding token refresh fails if CHECK_REVOKE_TOKEN is True and the
312+
user's password has changed.
313+
"""
314+
token = SlidingToken.for_user(self.user)
315+
self.user.set_password("new_password")
316+
self.user.save()
317+
318+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
319+
320+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
321+
s.is_valid(raise_exception=True)
322+
323+
self.assertEqual(e.exception.get_codes(), "password_changed")
324+
325+
@override_api_settings(
326+
CHECK_REVOKE_TOKEN=True,
327+
REVOKE_TOKEN_CLAIM="hash_password",
328+
BLACKLIST_AFTER_ROTATION=True,
329+
)
330+
def test_sliding_token_should_blacklist_after_password_change(self):
331+
"""
332+
Tests that if sliding token refresh fails due to a password change, the
333+
offending token is blacklisted.
334+
"""
335+
token = SlidingToken.for_user(self.user)
336+
self.user.set_password("new_password")
337+
self.user.save()
338+
339+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
340+
with self.assertRaises(drf_exceptions.AuthenticationFailed):
341+
s.is_valid(raise_exception=True)
342+
343+
# Check that the token is now in the blacklist
344+
jti = token[api_settings.JTI_CLAIM]
345+
self.assertTrue(OutstandingToken.objects.filter(jti=jti).exists())
346+
self.assertTrue(BlacklistedToken.objects.filter(token__jti=jti).exists())
347+
271348

272349
class TestTokenRefreshSerializer(TestCase):
273350
def setUp(self):
@@ -514,11 +591,6 @@ def test_refresh_token_should_blacklist_after_password_change(self):
514591
Tests that if token refresh fails due to a password change, the
515592
offending refresh token is blacklisted.
516593
"""
517-
from rest_framework_simplejwt.token_blacklist.models import (
518-
BlacklistedToken,
519-
OutstandingToken,
520-
)
521-
522594
refresh = RefreshToken.for_user(self.user)
523595
self.user.set_password("new_password")
524596
self.user.save()

0 commit comments

Comments
 (0)