Skip to content

Commit d691bc6

Browse files
committed
fix: update blacklist() type hint, blacklist_family() type hint and return values, and TokenVerifySerializer check. resolves #911
1 parent afb06dd commit d691bc6

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken, FamilyMixin
1515
from .utils import get_md5_hash_password
1616
from .cache import blacklist_cache
17+
from .token_blacklist.models import BlacklistedToken
1718

1819
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
1920

20-
if api_settings.BLACKLIST_AFTER_ROTATION:
21-
from .token_blacklist.models import BlacklistedToken
22-
2321

2422
class PasswordField(serializers.CharField):
2523
def __init__(self, *args, **kwargs) -> None:
@@ -263,7 +261,7 @@ def validate(self, attrs: dict[str, None]) -> dict[Any, Any]:
263261
token = UntypedToken(attrs["token"])
264262

265263
if (
266-
api_settings.BLACKLIST_AFTER_ROTATION
264+
token.get(api_settings.TOKEN_TYPE_CLAIM) == RefreshToken.token_type
267265
and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
268266
):
269267
jti = token.get(api_settings.JTI_CLAIM)

rest_framework_simplejwt/tokens.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def check_blacklist(self) -> None:
288288
if BlacklistedToken.objects.filter(token__jti=jti).exists():
289289
raise RefreshTokenBlacklistedError(_("Token is blacklisted"))
290290

291-
def blacklist(self) -> BlacklistedToken:
291+
def blacklist(self) -> tuple[BlacklistedToken, bool]:
292292
"""
293293
Ensures this token is included in the outstanding token list and
294294
adds it to the blacklist.
@@ -389,7 +389,7 @@ def verify(self, *args, **kwargs) -> None:
389389

390390
super().verify(*args, **kwargs) # type: ignore
391391

392-
def blacklist_family(self) -> BlacklistedTokenFamily:
392+
def blacklist_family(self) -> tuple[BlacklistedTokenFamily, bool]:
393393
"""
394394
Blacklists the token family.
395395
"""
@@ -413,7 +413,7 @@ def blacklist_family(self) -> BlacklistedTokenFamily:
413413
if blacklist_cache.is_families_cache_enabled:
414414
blacklist_cache.add_token_family(family_id)
415415

416-
return blacklisted_fam
416+
return blacklisted_fam, created
417417

418418
def get_family_id(self) -> Optional[str]:
419419
return self.payload.get(api_settings.TOKEN_FAMILY_CLAIM, None)

tests/test_token_blacklist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def setUp(self):
312312
super().setUp()
313313

314314
@override_api_settings(BLACKLIST_AFTER_ROTATION=True)
315-
def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled(
315+
def test_token_verify_serializer_should_honour_blacklist_if_rotation_enabled(
316316
self,
317317
):
318318
refresh_token = RefreshToken.for_user(self.user)
@@ -322,14 +322,14 @@ def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled
322322
self.assertFalse(serializer.is_valid())
323323

324324
@override_api_settings(BLACKLIST_AFTER_ROTATION=False)
325-
def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled(
325+
def test_token_verify_serializer_should_honour_blacklist_if_rotation_not_enabled(
326326
self,
327327
):
328328
refresh_token = RefreshToken.for_user(self.user)
329329
refresh_token.blacklist()
330330

331331
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
332-
self.assertTrue(serializer.is_valid())
332+
self.assertFalse(serializer.is_valid())
333333

334334

335335
class TestBigAutoFieldIDMigration(MigrationTestCase):

tests/test_token_family.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,13 @@ def test_token_family_can_be_manually_blacklisted(self):
243243
self.assertEqual(TokenFamily.objects.count(), 1)
244244

245245
# Add family to blacklist
246-
blacklisted_fam = token.blacklist_family()
246+
blacklisted_fam, created = token.blacklist_family()
247247

248248
# Should not add family to tokenfamily list if already present
249249
self.assertEqual(TokenFamily.objects.count(), 1)
250250

251251
# Should return blacklist record
252+
self.assertTrue(created)
252253
self.assertEqual(blacklisted_fam.family.family_id, token.get_family_id())
253254

254255
with self.assertRaises(TokenError) as cm:

0 commit comments

Comments
 (0)