Skip to content

Commit cebedb6

Browse files
committed
Harden revoke access token for password changes
1 parent 00de028 commit cebedb6

File tree

5 files changed

+70
-46
lines changed

5 files changed

+70
-46
lines changed

rest_framework_simplejwt/authentication.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from django.contrib.auth import get_user_model
44
from django.contrib.auth.models import AbstractBaseUser
5+
from django.utils.crypto import constant_time_compare
56
from django.utils.translation import gettext_lazy as _
67
from rest_framework import HTTP_HEADER_ENCODING, authentication
78
from rest_framework.request import Request
@@ -10,7 +11,7 @@
1011
from .models import TokenUser
1112
from .settings import api_settings
1213
from .tokens import Token
13-
from .utils import get_md5_hash_password
14+
from .utils import get_fallback_token_auth_hash, get_token_auth_hash
1415

1516
AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES
1617

@@ -139,9 +140,17 @@ def get_user(self, validated_token: Token) -> AuthUser:
139140
raise AuthenticationFailed(_("User is inactive"), code="user_inactive")
140141

141142
if api_settings.CHECK_REVOKE_TOKEN:
142-
if validated_token.get(
143-
api_settings.REVOKE_TOKEN_CLAIM
144-
) != get_md5_hash_password(user.password):
143+
validation_claim = validated_token.get(api_settings.REVOKE_TOKEN_CLAIM)
144+
if (
145+
validation_claim is None
146+
or not constant_time_compare(
147+
validation_claim, get_token_auth_hash(user)
148+
)
149+
and not any(
150+
constant_time_compare(validation_claim, fallback_auth_hash)
151+
for fallback_auth_hash in get_fallback_token_auth_hash(user)
152+
)
153+
):
145154
raise AuthenticationFailed(
146155
_("The user's password has been changed."), code="password_changed"
147156
)

rest_framework_simplejwt/tokens.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
datetime_from_epoch,
2323
datetime_to_epoch,
2424
format_lazy,
25-
get_md5_hash_password,
25+
get_token_auth_hash,
2626
logger,
2727
)
2828

@@ -250,9 +250,7 @@ def for_user(cls: type[T], user: AuthUser) -> T:
250250
token[api_settings.USER_ID_CLAIM] = user_id
251251

252252
if api_settings.CHECK_REVOKE_TOKEN:
253-
token[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(
254-
user.password
255-
)
253+
token[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user)
256254

257255
return token
258256

rest_framework_simplejwt/utils.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,44 @@
1-
import hashlib
21
import logging
32
from calendar import timegm
43
from datetime import datetime, timezone
5-
from typing import Callable
4+
from typing import TYPE_CHECKING, Callable, TypeVar
65

76
from django.conf import settings
7+
from django.contrib.auth.models import AbstractBaseUser
8+
from django.utils.crypto import salted_hmac
89
from django.utils.functional import lazy
910

11+
if TYPE_CHECKING:
12+
from .models import TokenUser
1013

11-
def get_md5_hash_password(password: str) -> str:
14+
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
15+
16+
17+
def _get_token_auth_hash(user: "AuthUser", secret=None) -> str:
18+
key_salt = "rest_framework_simplejwt.utils.get_token_auth_hash"
19+
return salted_hmac(key_salt, user.password, secret=secret).hexdigest()
20+
21+
22+
def get_token_auth_hash(user: "AuthUser") -> str:
1223
"""
13-
Returns MD5 hash of the given password
24+
Return an HMAC of the given user password field.
1425
"""
15-
return hashlib.md5(password.encode()).hexdigest().upper()
26+
if hasattr(user, "get_session_auth_hash"):
27+
return user.get_session_auth_hash()
28+
return _get_token_auth_hash(user)
29+
30+
31+
def get_fallback_token_auth_hash(user: "AuthUser") -> str:
32+
"""
33+
Yields a sequence of fallback HMACs of the given user password field.
34+
"""
35+
if hasattr(user, "get_session_auth_fallback_hash"):
36+
yield from user.get_session_auth_fallback_hash()
37+
38+
fallback_keys = getattr(settings, "SECRET_KEY_FALLBACKS", [])
39+
yield from (
40+
_get_token_auth_hash(user, fallback_secret) for fallback_secret in fallback_keys
41+
)
1642

1743

1844
def make_utc(dt: datetime) -> datetime:

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ def pytest_configure():
1616
},
1717
SITE_ID=1,
1818
SECRET_KEY="not very secret in tests",
19+
SECRET_KEY_FALLBACKS=[
20+
"old not very secure secret",
21+
"other old not very secure secret",
22+
],
1923
USE_I18N=True,
2024
STATIC_URL="/static/",
2125
ROOT_URLCONF="tests.urls",

tests/test_authentication.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from rest_framework_simplejwt.models import TokenUser
1111
from rest_framework_simplejwt.settings import api_settings
1212
from rest_framework_simplejwt.tokens import AccessToken, SlidingToken
13-
from rest_framework_simplejwt.utils import get_md5_hash_password
13+
from rest_framework_simplejwt.utils import _get_token_auth_hash, get_token_auth_hash
1414

1515
from .utils import override_api_settings
1616

@@ -145,21 +145,19 @@ def test_get_user(self):
145145
with self.assertRaises(AuthenticationFailed):
146146
self.backend.get_user(payload)
147147

148-
u = User.objects.create_user(username="markhamill")
149-
u.is_active = False
150-
u.save()
148+
user = User.objects.create_user(username="markhamill", is_active=False)
151149

152-
payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)
150+
payload[api_settings.USER_ID_CLAIM] = getattr(user, api_settings.USER_ID_FIELD)
153151

154152
# Should raise exception if user is inactive
155153
with self.assertRaises(AuthenticationFailed):
156154
self.backend.get_user(payload)
157155

158-
u.is_active = True
159-
u.save()
156+
user.is_active = True
157+
user.save()
160158

161159
# Otherwise, should return correct user
162-
self.assertEqual(self.backend.get_user(payload).id, u.id)
160+
self.assertEqual(self.backend.get_user(payload).id, user.id)
163161

164162
@override_api_settings(
165163
CHECK_USER_IS_ACTIVE=False,
@@ -190,40 +188,29 @@ def test_get_inactive_user(self):
190188
CHECK_REVOKE_TOKEN=True, REVOKE_TOKEN_CLAIM="revoke_token_claim"
191189
)
192190
def test_get_user_with_check_revoke_token(self):
193-
payload = {"some_other_id": "foo"}
194-
195-
# Should raise error if no recognizable user identification
196-
with self.assertRaises(InvalidToken):
197-
self.backend.get_user(payload)
198-
199-
payload[api_settings.USER_ID_CLAIM] = 42
200-
201-
# Should raise exception if user not found
202-
with self.assertRaises(AuthenticationFailed):
203-
self.backend.get_user(payload)
204-
205-
u = User.objects.create_user(username="markhamill")
206-
u.is_active = False
207-
u.save()
191+
user = User.objects.create_user(username="markhamill")
192+
payload = {
193+
api_settings.USER_ID_CLAIM: getattr(user, api_settings.USER_ID_FIELD)
194+
}
208195

209-
payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)
210-
211-
# Should raise exception if user is inactive
196+
# Should raise exception if claim is missing
212197
with self.assertRaises(AuthenticationFailed):
213198
self.backend.get_user(payload)
214199

215-
u.is_active = True
216-
u.save()
217-
218-
# Should raise exception if hash password is different
200+
payload[api_settings.REVOKE_TOKEN_CLAIM] = "differenthash"
201+
# Should raise exception if claim is different
219202
with self.assertRaises(AuthenticationFailed):
220203
self.backend.get_user(payload)
221204

222-
if api_settings.CHECK_REVOKE_TOKEN:
223-
payload[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(u.password)
205+
payload[api_settings.REVOKE_TOKEN_CLAIM] = _get_token_auth_hash(
206+
user, "other old not very secure secret"
207+
)
208+
# Should return correct user if claim was signed with an old key
209+
self.assertEqual(self.backend.get_user(payload).id, user.id)
224210

211+
payload[api_settings.REVOKE_TOKEN_CLAIM] = get_token_auth_hash(user)
225212
# Otherwise, should return correct user
226-
self.assertEqual(self.backend.get_user(payload).id, u.id)
213+
self.assertEqual(self.backend.get_user(payload).id, user.id)
227214

228215

229216
class TestJWTStatelessUserAuthentication(TestCase):

0 commit comments

Comments
 (0)