Skip to content

Commit e488fac

Browse files
committed
feat: integrate caching logic for blacklisted refresh tokens and families
- Added cache check for blacklisted refresh tokens in TokenVerifySerializer - Enhanced BlacklistMixin and FamilyMixin to support caching for blacklisted tokens and families.
1 parent 09215d6 commit e488fac

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

rest_framework_simplejwt/serializers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .settings import api_settings
1414
from .tokens import RefreshToken, SlidingToken, Token, UntypedToken, FamilyMixin
1515
from .utils import get_md5_hash_password
16+
from .cache import blacklist_cache
1617

1718
AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser)
1819

@@ -266,6 +267,12 @@ def validate(self, attrs: dict[str, None]) -> dict[Any, Any]:
266267
and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
267268
):
268269
jti = token.get(api_settings.JTI_CLAIM)
270+
if (
271+
blacklist_cache.is_refresh_tokens_cache_enabled
272+
and blacklist_cache.is_refresh_token_blacklisted(jti)
273+
):
274+
raise ValidationError(_("Token is blacklisted"))
275+
269276
if BlacklistedToken.objects.filter(token__jti=jti).exists():
270277
raise ValidationError(_("Token is blacklisted"))
271278

rest_framework_simplejwt/tokens.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_md5_hash_password,
2828
logger,
2929
)
30+
from .cache import blacklist_cache
3031

3132
if TYPE_CHECKING:
3233
from .backends import TokenBackend
@@ -278,6 +279,12 @@ def check_blacklist(self) -> None:
278279
"""
279280
jti = self.payload[api_settings.JTI_CLAIM]
280281

282+
if (
283+
blacklist_cache.is_refresh_tokens_cache_enabled and
284+
blacklist_cache.is_refresh_token_blacklisted(jti)
285+
):
286+
raise RefreshTokenBlacklistedError(_("Token is blacklisted"))
287+
281288
if BlacklistedToken.objects.filter(token__jti=jti).exists():
282289
raise RefreshTokenBlacklistedError(_("Token is blacklisted"))
283290

@@ -306,7 +313,12 @@ def blacklist(self) -> BlacklistedToken:
306313
},
307314
)
308315

309-
return BlacklistedToken.objects.get_or_create(token=token)
316+
blacklisted_token, created = BlacklistedToken.objects.get_or_create(token=token)
317+
318+
if blacklist_cache.is_refresh_tokens_cache_enabled:
319+
blacklist_cache.add_refresh_token(jti)
320+
321+
return blacklisted_token, created
310322

311323
def outstand(self) -> Optional[OutstandingToken]:
312324
"""
@@ -396,7 +408,12 @@ def blacklist_family(self) -> BlacklistedTokenFamily:
396408
)
397409

398410
# Blacklist the entire family
399-
return BlacklistedTokenFamily.objects.get_or_create(family=family)[0]
411+
blacklisted_fam, created = BlacklistedTokenFamily.objects.get_or_create(family=family)
412+
413+
if blacklist_cache.is_families_cache_enabled:
414+
blacklist_cache.add_token_family(family_id)
415+
416+
return blacklisted_fam
400417

401418
def get_family_id(self) -> Optional[str]:
402419
return self.payload.get(api_settings.TOKEN_FAMILY_CLAIM, None)
@@ -442,6 +459,12 @@ def check_family_blacklist(token: T) -> None:
442459
logger.warning(f"Token of user:{user_id} does not have a family_id. Skipping family blacklist check.")
443460
return
444461

462+
if (
463+
blacklist_cache.is_families_cache_enabled and
464+
blacklist_cache.is_token_family_blacklisted(family_id)
465+
):
466+
raise TokenError(_("Token family is blacklisted"))
467+
445468
if BlacklistedTokenFamily.objects.filter(family__family_id=family_id).exists():
446469
raise TokenError(_("Token family is blacklisted"))
447470

0 commit comments

Comments
 (0)