|
27 | 27 | get_md5_hash_password, |
28 | 28 | logger, |
29 | 29 | ) |
| 30 | +from .cache import blacklist_cache |
30 | 31 |
|
31 | 32 | if TYPE_CHECKING: |
32 | 33 | from .backends import TokenBackend |
@@ -279,6 +280,12 @@ def check_blacklist(self) -> None: |
279 | 280 | """ |
280 | 281 | jti = self.payload[api_settings.JTI_CLAIM] |
281 | 282 |
|
| 283 | + if ( |
| 284 | + blacklist_cache.is_refresh_tokens_cache_enabled and |
| 285 | + blacklist_cache.is_refresh_token_blacklisted(jti) |
| 286 | + ): |
| 287 | + raise RefreshTokenBlacklistedError(_("Token is blacklisted")) |
| 288 | + |
282 | 289 | if BlacklistedToken.objects.filter(token__jti=jti).exists(): |
283 | 290 | raise RefreshTokenBlacklistedError(_("Token is blacklisted")) |
284 | 291 |
|
@@ -307,7 +314,12 @@ def blacklist(self) -> BlacklistedToken: |
307 | 314 | }, |
308 | 315 | ) |
309 | 316 |
|
310 | | - return BlacklistedToken.objects.get_or_create(token=token) |
| 317 | + blacklisted_token, created = BlacklistedToken.objects.get_or_create(token=token) |
| 318 | + |
| 319 | + if blacklist_cache.is_refresh_tokens_cache_enabled: |
| 320 | + blacklist_cache.add_refresh_token(jti) |
| 321 | + |
| 322 | + return blacklisted_token, created |
311 | 323 |
|
312 | 324 | def outstand(self) -> Optional[OutstandingToken]: |
313 | 325 | """ |
@@ -397,7 +409,12 @@ def blacklist_family(self) -> BlacklistedTokenFamily: |
397 | 409 | ) |
398 | 410 |
|
399 | 411 | # Blacklist the entire family |
400 | | - return BlacklistedTokenFamily.objects.get_or_create(family=family)[0] |
| 412 | + blacklisted_fam, created = BlacklistedTokenFamily.objects.get_or_create(family=family) |
| 413 | + |
| 414 | + if blacklist_cache.is_families_cache_enabled: |
| 415 | + blacklist_cache.add_token_family(family_id) |
| 416 | + |
| 417 | + return blacklisted_fam |
401 | 418 |
|
402 | 419 | def get_family_id(self) -> Optional[str]: |
403 | 420 | return self.payload.get(api_settings.TOKEN_FAMILY_CLAIM, None) |
@@ -443,6 +460,12 @@ def check_family_blacklist(token: T) -> None: |
443 | 460 | logger.warning(f"Token of user:{user_id} does not have a family_id. Skipping family blacklist check.") |
444 | 461 | return |
445 | 462 |
|
| 463 | + if ( |
| 464 | + blacklist_cache.is_families_cache_enabled and |
| 465 | + blacklist_cache.is_token_family_blacklisted(family_id) |
| 466 | + ): |
| 467 | + raise TokenError(_("Token family is blacklisted")) |
| 468 | + |
446 | 469 | if BlacklistedTokenFamily.objects.filter(family__family_id=family_id).exists(): |
447 | 470 | raise TokenError(_("Token family is blacklisted")) |
448 | 471 |
|
|
0 commit comments