Skip to content

Commit 3fe1cb3

Browse files
feat(auth): Revoke refresh token on password change
1 parent 6c45510 commit 3fe1cb3

File tree

2 files changed

+86
-14
lines changed

2 files changed

+86
-14
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .models import TokenUser
1111
from .settings import api_settings
1212
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken
13+
from .utils import get_md5_hash_password
1314

1415
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
1516

@@ -111,18 +112,6 @@ class TokenRefreshSerializer(serializers.Serializer):
111112
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
112113
refresh = self.token_class(attrs["refresh"])
113114

114-
user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
115-
if user_id and (
116-
user := get_user_model().objects.get(
117-
**{api_settings.USER_ID_FIELD: user_id}
118-
)
119-
):
120-
if not api_settings.USER_AUTHENTICATION_RULE(user):
121-
raise AuthenticationFailed(
122-
self.error_messages["no_active_account"],
123-
"no_active_account",
124-
)
125-
126115
data = {"access": str(refresh.access_token)}
127116

128117
if api_settings.ROTATE_REFRESH_TOKENS:
@@ -142,6 +131,39 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
142131

143132
data["refresh"] = str(refresh)
144133

134+
# We handle user-related validation in a single, efficient block.
135+
user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
136+
if user_id:
137+
try:
138+
user = get_user_model().objects.get(**{api_settings.USER_ID_FIELD: user_id})
139+
except get_user_model().DoesNotExist:
140+
# This handles the case where the user has been deleted.
141+
raise AuthenticationFailed(
142+
self.error_messages["no_active_account"], "no_active_account"
143+
)
144+
145+
if not api_settings.USER_AUTHENTICATION_RULE(user):
146+
raise AuthenticationFailed(
147+
self.error_messages["no_active_account"], "no_active_account"
148+
)
149+
150+
if api_settings.CHECK_REVOKE_TOKEN:
151+
token_hash = refresh.payload.get(api_settings.REVOKE_TOKEN_CLAIM)
152+
user_hash = get_md5_hash_password(user.password)
153+
154+
if token_hash != user_hash:
155+
# If the password has changed, we blacklist the token
156+
# to prevent any further use.
157+
if "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS:
158+
try:
159+
refresh.blacklist()
160+
except AttributeError:
161+
pass
162+
163+
raise AuthenticationFailed(
164+
_("The user's password has been changed."), code="password_changed"
165+
)
166+
145167
return data
146168

147169

tests/test_serializers.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,11 @@ def test_it_should_raise_error_for_deleted_users(self):
286286

287287
s = TokenRefreshSerializer(data={"refresh": str(refresh)})
288288

289-
with self.assertRaises(django_exceptions.ObjectDoesNotExist) as e:
289+
# It should raise AuthenticationFailed instead of ObjectDoesNotExist
290+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
290291
s.is_valid()
291292

292-
self.assertIn("does not exist", str(e.exception))
293+
self.assertEqual(e.exception.get_codes(), "no_active_account")
293294

294295
def test_it_should_raise_error_for_inactive_users(self):
295296
refresh = RefreshToken.for_user(self.user)
@@ -483,6 +484,55 @@ def test_blacklist_app_not_installed_should_pass(self):
483484
reload(tokens)
484485
reload(serializers)
485486

487+
@override_api_settings(
488+
CHECK_REVOKE_TOKEN=True,
489+
REVOKE_TOKEN_CLAIM="hash_password",
490+
BLACKLIST_AFTER_ROTATION=False,
491+
)
492+
def test_refresh_token_should_fail_after_password_change(self):
493+
"""
494+
Tests that token refresh fails if CHECK_REVOKE_TOKEN is True and the
495+
user's password has changed.
496+
"""
497+
refresh = RefreshToken.for_user(self.user)
498+
self.user.set_password("new_password")
499+
self.user.save()
500+
501+
s = TokenRefreshSerializer(data={"refresh": str(refresh)})
502+
503+
with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
504+
s.is_valid(raise_exception=True)
505+
506+
self.assertEqual(e.exception.get_codes(), "password_changed")
507+
508+
@override_api_settings(
509+
CHECK_REVOKE_TOKEN=True,
510+
REVOKE_TOKEN_CLAIM="hash_password",
511+
BLACKLIST_AFTER_ROTATION=True,
512+
)
513+
def test_refresh_token_should_blacklist_after_password_change(self):
514+
"""
515+
Tests that if token refresh fails due to a password change, the
516+
offending refresh token is blacklisted.
517+
"""
518+
from rest_framework_simplejwt.token_blacklist.models import (
519+
BlacklistedToken,
520+
OutstandingToken,
521+
)
522+
523+
refresh = RefreshToken.for_user(self.user)
524+
self.user.set_password("new_password")
525+
self.user.save()
526+
527+
s = TokenRefreshSerializer(data={"refresh": str(refresh)})
528+
with self.assertRaises(drf_exceptions.AuthenticationFailed):
529+
s.is_valid(raise_exception=True)
530+
531+
# Check that the token is now in the blacklist
532+
jti = refresh[api_settings.JTI_CLAIM]
533+
self.assertTrue(OutstandingToken.objects.filter(jti=jti).exists())
534+
self.assertTrue(BlacklistedToken.objects.filter(token__jti=jti).exists())
535+
486536

487537
class TestTokenVerifySerializer(TestCase):
488538
def test_it_should_raise_token_error_if_token_invalid(self):

0 commit comments

Comments
 (0)