Skip to content

Commit f4e9886

Browse files
feat(serializers): Add full user validation to sliding token refresh
Implements the same user validation logic (active status, password change) in to ensure consistent behavior with the standard .
1 parent f7fe3db commit f4e9886

File tree

2 files changed

+120
-8
lines changed

2 files changed

+120
-8
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ class TokenRefreshSerializer(serializers.Serializer):
113113
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
114114
refresh = self.token_class(attrs["refresh"])
115115

116-
data = {"access": str(refresh.access_token)}
117-
118-
# We handle user-related validation in a single, efficient block.
119116
user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
120117
if user_id:
121118
try:
@@ -154,6 +151,7 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
154151
code="password_changed",
155152
)
156153

154+
data = {"access": str(refresh.access_token)}
157155
if api_settings.ROTATE_REFRESH_TOKENS:
158156
if api_settings.BLACKLIST_AFTER_ROTATION:
159157
try:
@@ -178,8 +176,50 @@ class TokenRefreshSlidingSerializer(serializers.Serializer):
178176
token = serializers.CharField()
179177
token_class = SlidingToken
180178

179+
default_error_messages = {
180+
"no_active_account": _("No active account found for the given token."),
181+
"password_changed": _("The user's password has been changed."),
182+
}
183+
181184
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
182185
token = self.token_class(attrs["token"])
186+
user_id = token.payload.get(api_settings.USER_ID_CLAIM, None)
187+
if user_id:
188+
try:
189+
user = get_user_model().objects.get(
190+
**{api_settings.USER_ID_FIELD: user_id}
191+
)
192+
except get_user_model().DoesNotExist:
193+
# This handles the case where the user has been deleted.
194+
raise AuthenticationFailed(
195+
self.error_messages["no_active_account"], "no_active_account"
196+
)
197+
198+
if not api_settings.USER_AUTHENTICATION_RULE(user):
199+
raise AuthenticationFailed(
200+
self.error_messages["no_active_account"], "no_active_account"
201+
)
202+
203+
if api_settings.CHECK_REVOKE_TOKEN:
204+
token_hash = token.payload.get(api_settings.REVOKE_TOKEN_CLAIM)
205+
user_hash = get_md5_hash_password(user.password)
206+
207+
if token_hash != user_hash:
208+
# If the password has changed, we blacklist the token
209+
# to prevent any further use.
210+
if (
211+
"rest_framework_simplejwt.token_blacklist"
212+
in settings.INSTALLED_APPS
213+
):
214+
try:
215+
token.blacklist()
216+
except AttributeError:
217+
pass
218+
219+
raise AuthenticationFailed(
220+
self.error_messages["password_changed"],
221+
code="password_changed",
222+
)
183223

184224
# Check that the timestamp in the "refresh_exp" claim has not
185225
# passed

tests/test_serializers.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,15 @@ def test_it_should_produce_a_json_web_token_when_valid(self):
188188

189189

190190
class TestTokenRefreshSlidingSerializer(TestCase):
191+
def setUp(self):
192+
self.username = "test_user"
193+
self.password = "test_password"
194+
195+
self.user = User.objects.create_user(
196+
username=self.username,
197+
password=self.password,
198+
)
199+
191200
def test_it_should_not_validate_if_token_invalid(self):
192201
token = SlidingToken()
193202
del token["exp"]
@@ -269,6 +278,74 @@ def test_it_should_update_token_exp_claim_if_everything_ok(self):
269278

270279
self.assertTrue(old_exp < new_exp)
271280

281+
def test_it_should_raise_error_for_deleted_users(self):
282+
token = SlidingToken.for_user(self.user)
283+
self.user.delete()
284+
285+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
286+
287+
# It should raise AuthenticationFailed instead of ObjectDoesNotExist
288+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
289+
s.is_valid()
290+
291+
self.assertEqual(e.exception.get_codes(), "no_active_account")
292+
293+
def test_it_should_raise_error_for_inactive_users(self):
294+
token = SlidingToken.for_user(self.user)
295+
self.user.is_active = False
296+
self.user.save()
297+
298+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
299+
300+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
301+
s.is_valid()
302+
303+
self.assertEqual(e.exception.get_codes(), "no_active_account")
304+
305+
@override_api_settings(
306+
CHECK_REVOKE_TOKEN=True,
307+
REVOKE_TOKEN_CLAIM="hash_password",
308+
BLACKLIST_AFTER_ROTATION=False,
309+
)
310+
def test_sliding_token_should_fail_after_password_change(self):
311+
"""
312+
Tests that sliding token refresh fails if CHECK_REVOKE_TOKEN is True and the
313+
user's password has changed.
314+
"""
315+
token = SlidingToken.for_user(self.user)
316+
self.user.set_password("new_password")
317+
self.user.save()
318+
319+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
320+
321+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
322+
s.is_valid(raise_exception=True)
323+
324+
self.assertEqual(e.exception.get_codes(), "password_changed")
325+
326+
@override_api_settings(
327+
CHECK_REVOKE_TOKEN=True,
328+
REVOKE_TOKEN_CLAIM="hash_password",
329+
BLACKLIST_AFTER_ROTATION=True,
330+
)
331+
def test_sliding_token_should_blacklist_after_password_change(self):
332+
"""
333+
Tests that if sliding token refresh fails due to a password change, the
334+
offending token is blacklisted.
335+
"""
336+
token = SlidingToken.for_user(self.user)
337+
self.user.set_password("new_password")
338+
self.user.save()
339+
340+
s = TokenRefreshSlidingSerializer(data={"token": str(token)})
341+
with self.assertRaises(drf_exceptions.AuthenticationFailed):
342+
s.is_valid(raise_exception=True)
343+
344+
# Check that the token is now in the blacklist
345+
jti = token[api_settings.JTI_CLAIM]
346+
self.assertTrue(OutstandingToken.objects.filter(jti=jti).exists())
347+
self.assertTrue(BlacklistedToken.objects.filter(token__jti=jti).exists())
348+
272349

273350
class TestTokenRefreshSerializer(TestCase):
274351
def setUp(self):
@@ -515,11 +592,6 @@ def test_refresh_token_should_blacklist_after_password_change(self):
515592
Tests that if token refresh fails due to a password change, the
516593
offending refresh token is blacklisted.
517594
"""
518-
from rest_framework_simplejwt.token_blacklist.models import (
519-
BlacklistedToken,
520-
OutstandingToken,
521-
)
522-
523595
refresh = RefreshToken.for_user(self.user)
524596
self.user.set_password("new_password")
525597
self.user.save()

0 commit comments

Comments
 (0)