1212
1313from app .core .config import settings
1414from app .core .security import redact_token
15+ from app .services .redis_service import redis_service
1516
1617
1718class TokenStore :
@@ -20,12 +21,9 @@ class TokenStore:
2021 KEY_PREFIX = settings .REDIS_TOKEN_KEY
2122
2223 def __init__ (self ) -> None :
23- self ._client : redis .Redis | None = None
2424 # Negative cache for missing tokens to avoid repeated Redis GETs
2525 # when external probes request non-existent tokens.
2626 self ._missing_tokens : TTLCache = TTLCache (maxsize = 10000 , ttl = 86400 )
27- if not settings .REDIS_URL :
28- logger .warning ("REDIS_URL is not set. Token storage will fail until a Redis instance is configured." )
2927
3028 if not settings .TOKEN_SALT or settings .TOKEN_SALT == "change-me" :
3129 logger .warning (
@@ -34,10 +32,8 @@ def __init__(self) -> None:
3432
3533 def _ensure_secure_salt (self ) -> None :
3634 if not settings .TOKEN_SALT or settings .TOKEN_SALT == "change-me" :
37- logger .error ("Refusing to store credentials because TOKEN_SALT is unset or using the insecure default." )
38- raise RuntimeError (
39- "Server misconfiguration: TOKEN_SALT must be set to a non-default value before storing" " credentials."
40- )
35+ logger .error ("TOKEN_SALT is unset or using the insecure default." )
36+ raise RuntimeError ("TOKEN_SALT must be set to a non-default value before storing credentials." )
4137
4238 def _get_cipher (self ) -> Fernet :
4339 salt = b"x7FDf9kypzQ1LmR32b8hWv49sKq2Pd8T"
@@ -59,59 +55,6 @@ def decrypt_token(self, enc: str) -> str:
5955 cipher = self ._get_cipher ()
6056 return cipher .decrypt (enc .encode ("utf-8" )).decode ("utf-8" )
6157
62- async def _get_client (self ) -> redis .Redis :
63- if self ._client is None :
64- # Add socket timeouts to avoid hanging on Redis operations
65- import traceback
66-
67- logger .info ("Creating shared Redis client" )
68- # Limit the number of pooled connections to avoid unbounded growth
69- # `max_connections` is forwarded to ConnectionPool.from_url
70- self ._client = redis .from_url (
71- settings .REDIS_URL ,
72- decode_responses = True ,
73- encoding = "utf-8" ,
74- socket_connect_timeout = 5 ,
75- socket_timeout = 5 ,
76- max_connections = getattr (settings , "REDIS_MAX_CONNECTIONS" , 100 ),
77- health_check_interval = 30 ,
78- socket_keepalive = True ,
79- )
80- if getattr (self , "_creation_count" , None ) is None :
81- self ._creation_count = 1
82- else :
83- self ._creation_count += 1
84- logger .warning (
85- f"Redis client creation invoked again (count={ self ._creation_count } )."
86- f" Stack:\n { '' .join (traceback .format_stack ())} "
87- )
88- return self ._client
89-
90- async def close (self ) -> None :
91- """Close and disconnect the shared Redis client (call on shutdown)."""
92- if self ._client is None :
93- return
94- try :
95- logger .info ("Closing shared Redis client" )
96- # Close client and disconnect underlying pool
97- try :
98- await self ._client .close ()
99- except Exception as e :
100- logger .debug (f"Silent failure closing redis client: { e } " )
101- try :
102- pool = getattr (self ._client , "connection_pool" , None )
103- if pool is not None :
104- # connection_pool.disconnect may be a coroutine in some redis implementations
105- disconnect = getattr (pool , "disconnect" , None )
106- if disconnect :
107- res = disconnect ()
108- if hasattr (res , "__await__" ):
109- await res
110- except Exception as e :
111- logger .debug (f"Silent failure disconnecting redis pool: { e } " )
112- finally :
113- self ._client = None
114-
11558 def _format_key (self , token : str ) -> str :
11659 """Format Redis key from token."""
11760 return f"{ self .KEY_PREFIX } { token } "
@@ -145,13 +88,12 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str:
14588 # Do not store plaintext passwords
14689 raise RuntimeError ("PASSWORD_ENCRYPT_FAILED" )
14790
148- client = await self ._get_client ()
14991 json_str = json .dumps (storage_data )
15092
15193 if settings .TOKEN_TTL_SECONDS and settings .TOKEN_TTL_SECONDS > 0 :
152- await client . setex (key , settings .TOKEN_TTL_SECONDS , json_str )
94+ await redis_service . set (key , json_str , settings .TOKEN_TTL_SECONDS )
15395 else :
154- await client .set (key , json_str )
96+ await redis_service .set (key , json_str )
15597
15698 # Invalidate async LRU cache for fresh reads on subsequent requests
15799 try :
@@ -193,8 +135,7 @@ async def get_user_data(self, token: str) -> dict[str, Any] | None:
193135
194136 logger .debug (f"[REDIS] Cache miss. Fetching data from redis for { token } " )
195137 key = self ._format_key (token )
196- client = await self ._get_client ()
197- data_raw = await client .get (key )
138+ data_raw = await redis_service .get (key )
198139
199140 if not data_raw :
200141 # remember negative result briefly
@@ -232,8 +173,7 @@ async def delete_token(self, token: str = None, key: str = None) -> None:
232173 if token :
233174 key = self ._format_key (token )
234175
235- client = await self ._get_client ()
236- await client .delete (key )
176+ await redis_service .delete (key )
237177
238178 # Invalidate async LRU cache so future reads reflect deletion
239179 try :
@@ -261,7 +201,7 @@ async def count_users(self) -> int:
261201 Cached for 12 hours to avoid frequent Redis scans.
262202 """
263203 try :
264- client = await self . _get_client ()
204+ client = await redis_service . get_client ()
265205 except (redis .RedisError , OSError ) as exc :
266206 logger .warning (f"Cannot count users; Redis unavailable: { exc } " )
267207 return 0
0 commit comments