Skip to content

Commit 52df0b1

Browse files
committed
feat: integrate caching logic for blacklisted refresh tokens and families
- Added cache check for blacklisted refresh tokens in TokenVerifySerializer - Enhanced BlacklistMixin and FamilyMixin to support caching for blacklisted tokens and families.
1 parent 0c9eb31 commit 52df0b1

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .models import TokenUser
1212
from .settings import api_settings
1313
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken, FamilyMixin
14+
from .cache import blacklist_cache
1415

1516
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
1617

@@ -195,6 +196,12 @@ def validate(self, attrs: dict[str, None]) -> dict[Any, Any]:
195196
and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
196197
):
197198
jti = token.get(api_settings.JTI_CLAIM)
199+
if (
200+
blacklist_cache.is_refresh_tokens_cache_enabled
201+
and blacklist_cache.is_refresh_token_blacklisted(jti)
202+
):
203+
raise ValidationError(_("Token is blacklisted"))
204+
198205
if BlacklistedToken.objects.filter(token__jti=jti).exists():
199206
raise ValidationError(_("Token is blacklisted"))
200207

rest_framework_simplejwt/tokens.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_md5_hash_password,
2828
logger,
2929
)
30+
from .cache import blacklist_cache
3031

3132
if TYPE_CHECKING:
3233
from .backends import TokenBackend
@@ -279,6 +280,12 @@ def check_blacklist(self) -> None:
279280
"""
280281
jti = self.payload[api_settings.JTI_CLAIM]
281282

283+
if (
284+
blacklist_cache.is_refresh_tokens_cache_enabled and
285+
blacklist_cache.is_refresh_token_blacklisted(jti)
286+
):
287+
raise RefreshTokenBlacklistedError(_("Token is blacklisted"))
288+
282289
if BlacklistedToken.objects.filter(token__jti=jti).exists():
283290
raise RefreshTokenBlacklistedError(_("Token is blacklisted"))
284291

@@ -307,7 +314,12 @@ def blacklist(self) -> BlacklistedToken:
307314
},
308315
)
309316

310-
return BlacklistedToken.objects.get_or_create(token=token)
317+
blacklisted_token, created = BlacklistedToken.objects.get_or_create(token=token)
318+
319+
if blacklist_cache.is_refresh_tokens_cache_enabled:
320+
blacklist_cache.add_refresh_token(jti)
321+
322+
return blacklisted_token, created
311323

312324
def outstand(self) -> Optional[OutstandingToken]:
313325
"""
@@ -397,7 +409,12 @@ def blacklist_family(self) -> BlacklistedTokenFamily:
397409
)
398410

399411
# Blacklist the entire family
400-
return BlacklistedTokenFamily.objects.get_or_create(family=family)[0]
412+
blacklisted_fam, created = BlacklistedTokenFamily.objects.get_or_create(family=family)
413+
414+
if blacklist_cache.is_families_cache_enabled:
415+
blacklist_cache.add_token_family(family_id)
416+
417+
return blacklisted_fam
401418

402419
def get_family_id(self) -> Optional[str]:
403420
return self.payload.get(api_settings.TOKEN_FAMILY_CLAIM, None)
@@ -443,6 +460,12 @@ def check_family_blacklist(token: T) -> None:
443460
logger.warning(f"Token of user:{user_id} does not have a family_id. Skipping family blacklist check.")
444461
return
445462

463+
if (
464+
blacklist_cache.is_families_cache_enabled and
465+
blacklist_cache.is_token_family_blacklisted(family_id)
466+
):
467+
raise TokenError(_("Token family is blacklisted"))
468+
446469
if BlacklistedTokenFamily.objects.filter(family__family_id=family_id).exists():
447470
raise TokenError(_("Token family is blacklisted"))
448471

0 commit comments

Comments
 (0)