Skip to content

Commit 5a6c4cb

Browse files
committed
fix: update blacklist() type hint, blacklist_family() type hint and return values, and TokenVerifySerializer check. resolves #911
1 parent 9097e21 commit 5a6c4cb

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@
1212
from .settings import api_settings
1313
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken, FamilyMixin
1414
from .cache import blacklist_cache
15+
from .token_blacklist.models import BlacklistedToken
1516

1617
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
1718

18-
if api_settings.BLACKLIST_AFTER_ROTATION:
19-
from .token_blacklist.models import BlacklistedToken
20-
2119

2220
class PasswordField(serializers.CharField):
2321
def __init__(self, *args, **kwargs) -> None:
@@ -192,7 +190,7 @@ def validate(self, attrs: dict[str, None]) -> dict[Any, Any]:
192190
token = UntypedToken(attrs["token"])
193191

194192
if (
195-
api_settings.BLACKLIST_AFTER_ROTATION
193+
token.get(api_settings.TOKEN_TYPE_CLAIM) == RefreshToken.token_type
196194
and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
197195
):
198196
jti = token.get(api_settings.JTI_CLAIM)

rest_framework_simplejwt/tokens.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def check_blacklist(self) -> None:
289289
if BlacklistedToken.objects.filter(token__jti=jti).exists():
290290
raise RefreshTokenBlacklistedError(_("Token is blacklisted"))
291291

292-
def blacklist(self) -> BlacklistedToken:
292+
def blacklist(self) -> tuple[BlacklistedToken, bool]:
293293
"""
294294
Ensures this token is included in the outstanding token list and
295295
adds it to the blacklist.
@@ -390,7 +390,7 @@ def verify(self, *args, **kwargs) -> None:
390390

391391
super().verify(*args, **kwargs) # type: ignore
392392

393-
def blacklist_family(self) -> BlacklistedTokenFamily:
393+
def blacklist_family(self) -> tuple[BlacklistedTokenFamily, bool]:
394394
"""
395395
Blacklists the token family.
396396
"""
@@ -414,7 +414,7 @@ def blacklist_family(self) -> BlacklistedTokenFamily:
414414
if blacklist_cache.is_families_cache_enabled:
415415
blacklist_cache.add_token_family(family_id)
416416

417-
return blacklisted_fam
417+
return blacklisted_fam, created
418418

419419
def get_family_id(self) -> Optional[str]:
420420
return self.payload.get(api_settings.TOKEN_FAMILY_CLAIM, None)

tests/test_token_blacklist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def setUp(self):
312312
super().setUp()
313313

314314
@override_api_settings(BLACKLIST_AFTER_ROTATION=True)
315-
def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled(
315+
def test_token_verify_serializer_should_honour_blacklist_if_rotation_enabled(
316316
self,
317317
):
318318
refresh_token = RefreshToken.for_user(self.user)
@@ -322,14 +322,14 @@ def test_token_verify_serializer_should_honour_blacklist_if_blacklisting_enabled
322322
self.assertFalse(serializer.is_valid())
323323

324324
@override_api_settings(BLACKLIST_AFTER_ROTATION=False)
325-
def test_token_verify_serializer_should_not_honour_blacklist_if_blacklisting_not_enabled(
325+
def test_token_verify_serializer_should_honour_blacklist_if_rotation_not_enabled(
326326
self,
327327
):
328328
refresh_token = RefreshToken.for_user(self.user)
329329
refresh_token.blacklist()
330330

331331
serializer = TokenVerifySerializer(data={"token": str(refresh_token)})
332-
self.assertTrue(serializer.is_valid())
332+
self.assertFalse(serializer.is_valid())
333333

334334

335335
class TestBigAutoFieldIDMigration(MigrationTestCase):

tests/test_token_family.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,13 @@ def test_token_family_can_be_manually_blacklisted(self):
243243
self.assertEqual(TokenFamily.objects.count(), 1)
244244

245245
# Add family to blacklist
246-
blacklisted_fam = token.blacklist_family()
246+
blacklisted_fam, created = token.blacklist_family()
247247

248248
# Should not add family to tokenfamily list if already present
249249
self.assertEqual(TokenFamily.objects.count(), 1)
250250

251251
# Should return blacklist record
252+
self.assertTrue(created)
252253
self.assertEqual(blacklisted_fam.family.family_id, token.get_family_id())
253254

254255
with self.assertRaises(TokenError) as cm:

0 commit comments

Comments
 (0)