1+ import base64
12import json
23from collections .abc import AsyncIterator
34from typing import Any
45
56import redis .asyncio as redis
67from cachetools import TTLCache
8+ from cryptography .fernet import Fernet
9+ from cryptography .hazmat .primitives import hashes
10+ from cryptography .hazmat .primitives .kdf .pbkdf2 import PBKDF2HMAC
711from loguru import logger
812
913from app .core .config import settings
@@ -23,6 +27,38 @@ def __init__(self) -> None:
2327 if not settings .REDIS_URL :
2428 logger .warning ("REDIS_URL is not set. Token storage will fail until a Redis instance is configured." )
2529
30+ if not settings .TOKEN_SALT or settings .TOKEN_SALT == "change-me" :
31+ logger .warning (
32+ "TOKEN_SALT is missing or using the default placeholder. Set a strong value to secure tokens."
33+ )
34+
35+ def _ensure_secure_salt (self ) -> None :
36+ if not settings .TOKEN_SALT or settings .TOKEN_SALT == "change-me" :
37+ logger .error ("Refusing to store credentials because TOKEN_SALT is unset or using the insecure default." )
38+ raise RuntimeError (
39+ "Server misconfiguration: TOKEN_SALT must be set to a non-default value before storing credentials."
40+ )
41+
42+ def _get_cipher (self ) -> Fernet :
43+ """Get or create Fernet cipher instance based on TOKEN_SALT."""
44+ if self ._cipher is None :
45+ kdf = PBKDF2HMAC (
46+ algorithm = hashes .SHA256 (),
47+ length = 32 ,
48+ salt = b"" , # empty salt
49+ iterations = 200_000 ,
50+ )
51+
52+ key = base64 .urlsafe_b64encode (kdf .derive (settings .TOKEN_SALT .encode ("utf-8" )))
53+ self ._cipher = Fernet (key )
54+ return self ._cipher
55+
56+ def encrypt_token (self , token : str ) -> str :
57+ return self ._cipher .encrypt (token .encode ("utf-8" )).decode ("utf-8" )
58+
59+ def decrypt_token (self , enc : str ) -> str :
60+ return self ._cipher .decrypt (enc .encode ("utf-8" )).decode ("utf-8" )
61+
2662 async def _get_client (self ) -> redis .Redis :
2763 if self ._client is None :
2864 self ._client = redis .from_url (settings .REDIS_URL , decode_responses = True , encoding = "utf-8" )
@@ -39,6 +75,7 @@ def get_user_id_from_token(self, token: str) -> str:
3975 return token .strip () if token else ""
4076
4177 async def store_user_data (self , user_id : str , payload : dict [str , Any ]) -> str :
78+ self ._ensure_secure_salt ()
4279 token = self .get_token_from_user_id (user_id )
4380 key = self ._format_key (token )
4481
@@ -48,6 +85,9 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str:
4885 # Store user_id in payload for convenience
4986 storage_data ["user_id" ] = user_id
5087
88+ if storage_data .get ("authKey" ):
89+ storage_data ["authKey" ] = self .encrypt_token (storage_data ["authKey" ])
90+
5191 client = await self ._get_client ()
5292 json_str = json .dumps (storage_data )
5393
@@ -74,6 +114,8 @@ async def get_user_data(self, token: str) -> dict[str, Any] | None:
74114
75115 try :
76116 data = json .loads (data_raw )
117+ if data .get ("authKey" ):
118+ data ["authKey" ] = self .decrypt_token (data ["authKey" ])
77119 self ._payload_cache [token ] = data
78120 return data
79121 except json .JSONDecodeError :
0 commit comments