 from typing import Any
 
 import redis.asyncio as redis
+from async_lru import alru_cache
 from cachetools import TTLCache
 from cryptography.fernet import Fernet, InvalidToken
 from cryptography.hazmat.primitives import hashes
@@ -22,9 +23,9 @@ class TokenStore:
 
     def __init__(self) -> None:
         self._client: redis.Redis | None = None
-        # Cache decrypted payloads for 1 day (86400s) to reduce Redis hits
-        # Max size 5000 allows many active users without eviction
-        self._payload_cache: TTLCache = TTLCache(maxsize=5000, ttl=86400)
+        # Negative cache for missing tokens to avoid repeated Redis GETs
+        # when external probes request non-existent tokens.
+        self._missing_tokens: TTLCache = TTLCache(maxsize=10000, ttl=3600)
         # per-request redis call counter (context-local)
         self._redis_calls_var: contextvars.ContextVar[int] = contextvars.ContextVar("watchly_redis_calls", default=0)
 
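
The new `_missing_tokens` field implements a simple negative cache: a miss against Redis is remembered for up to an hour so repeated lookups of the same non-existent token skip the network round trip. Below is a minimal, self-contained sketch of that pattern using `cachetools.TTLCache`; the names `fetch_from_redis` and `get_payload` are hypothetical stand-ins for illustration only, not code from this PR.

```python
import asyncio

from cachetools import TTLCache

# Remember up to 10,000 missing tokens, each for at most one hour.
_missing: TTLCache = TTLCache(maxsize=10_000, ttl=3600)


async def fetch_from_redis(token: str) -> dict | None:
    # Hypothetical stand-in for the real Redis GET; pretend the token is absent.
    return None


async def get_payload(token: str) -> dict | None:
    if token in _missing:        # known-missing: skip the Redis round trip
        return None
    payload = await fetch_from_redis(token)
    if payload is None:
        _missing[token] = True   # remember the miss; the entry expires after ttl
    return payload


asyncio.run(get_payload("probe-token"))
```

Because entries expire on their own, a token that later appears in Redis is retried automatically once the TTL lapses; writes and deletes in this diff also evict the entry explicitly.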
@@ -66,15 +67,59 @@ def decrypt_token(self, enc: str) -> str:
     async def _get_client(self) -> redis.Redis:
         if self._client is None:
             # Add socket timeouts to avoid hanging on Redis operations
+            import traceback
+
+            logger.info("Creating shared Redis client")
+            # Limit the number of pooled connections to avoid unbounded growth;
+            # `max_connections` is forwarded to ConnectionPool.from_url.
             self._client = redis.from_url(
                 settings.REDIS_URL,
                 decode_responses=True,
                 encoding="utf-8",
                 socket_connect_timeout=5,
                 socket_timeout=5,
+                max_connections=getattr(settings, "REDIS_MAX_CONNECTIONS", 100),
+                health_check_interval=30,
+                socket_keepalive=True,
             )
+            # If _get_client is called multiple times in different contexts, it
+            # could indicate multiple processes/threads or a bug opening
+            # additional clients; log a stack trace for debugging.
+            if getattr(self, "_creation_count", None) is None:
+                self._creation_count = 1
+            else:
+                self._creation_count += 1
+                logger.warning(
+                    f"Redis client creation invoked again (count={self._creation_count})."
+                    f" Stack:\n{''.join(traceback.format_stack())}"
+                )
         return self._client
 
+    async def close(self) -> None:
+        """Close and disconnect the shared Redis client (call on shutdown)."""
+        if self._client is None:
+            return
+        try:
+            logger.info("Closing shared Redis client")
+            # Close the client, then disconnect the underlying pool
+            try:
+                await self._client.close()
+            except Exception:
+                pass
+            try:
+                pool = getattr(self._client, "connection_pool", None)
+                if pool is not None:
+                    # connection_pool.disconnect may be a coroutine in some redis implementations
+                    disconnect = getattr(pool, "disconnect", None)
+                    if disconnect:
+                        res = disconnect()
+                        if hasattr(res, "__await__"):
+                            await res
+            except Exception:
+                pass
+        finally:
+            self._client = None
+
     def _format_key(self, token: str) -> str:
         """Format Redis key from token."""
         return f"{self.KEY_PREFIX}{token}"
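
The docstring says `close()` should be called on shutdown, but the application wiring is not part of this diff. A possible hook, assuming a FastAPI app and a module-level `token_store` instance (both assumptions, the import path below is hypothetical), would be a lifespan handler:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.token_store import token_store  # hypothetical module path


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # Release the shared client and its pooled connections on process exit.
    await token_store.close()


app = FastAPI(lifespan=lifespan)
```

Whatever framework is actually in use, the point is the same: call `close()` exactly once when the process stops so the bounded connection pool is drained cleanly.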
@@ -109,30 +154,49 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str:
         self._incr_calls()
         await client.set(key, json_str)
 
-        # Update cache with the payload
-        self._payload_cache[token] = payload
+        # Invalidate async LRU cached reads so future reads use the updated payload
+        try:
+            self.get_user_data.cache_clear()
+        except Exception:
+            pass
+
+        # Remove the token from the negative cache so the new value is read next time
+        try:
+            if token in self._missing_tokens:
+                del self._missing_tokens[token]
+        except Exception:
+            pass
 
         return token
 
+    @alru_cache(maxsize=5000)
     async def get_user_data(self, token: str) -> dict[str, Any] | None:
-        if token in self._payload_cache:
-            logger.info(f"[REDIS] Using cached redis data {token}")
-            return self._payload_cache[token]
-        logger.info(f"[REDIS]Caching Failed. Fetching data from redis for {token}")
+        # Short-circuit for tokens known to be missing
+        try:
+            if token in self._missing_tokens:
+                logger.debug(f"[REDIS] Negative cache hit for missing token {token}")
+                return None
+        except Exception:
+            pass
 
+        logger.debug(f"[REDIS] Cache miss. Fetching data from redis for {token}")
         key = self._format_key(token)
         client = await self._get_client()
         self._incr_calls()
         data_raw = await client.get(key)
 
         if not data_raw:
+            # Remember the negative result briefly
+            try:
+                self._missing_tokens[token] = True
+            except Exception:
+                pass
             return None
 
         try:
             data = json.loads(data_raw)
             if data.get("authKey"):
                 data["authKey"] = self.decrypt_token(data["authKey"])
-            self._payload_cache[token] = data
             return data
         except (json.JSONDecodeError, InvalidToken):
             return None
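
For reviewers unfamiliar with `async_lru`: `alru_cache` memoizes a coroutine by its arguments and exposes `cache_clear()`, which is what `store_user_data` and `delete_token` now call to invalidate reads. As with `functools.lru_cache`, a `None` result is cached like any other value. The sketch below demonstrates this on a plain coroutine (a simplification of the instance method used in this diff); the `lookup` function and the call counter are illustrative only.

```python
import asyncio

from async_lru import alru_cache

calls = 0


@alru_cache(maxsize=5000)
async def lookup(token: str) -> str | None:
    global calls
    calls += 1
    return None if token == "missing" else f"payload-for-{token}"


async def main() -> None:
    await lookup("abc")
    await lookup("abc")       # served from the LRU cache; the body does not run again
    print(calls)              # -> 1
    lookup.cache_clear()      # same invalidation hook the store/delete paths use
    await lookup("abc")
    print(calls)              # -> 2


asyncio.run(main())
```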
@@ -147,9 +211,17 @@ async def delete_token(self, token: str = None, key: str = None) -> None:
         self._incr_calls()
         await client.delete(key)
 
-        # Invalidate local cache
-        if token and token in self._payload_cache:
-            del self._payload_cache[token]
+        # Invalidate async LRU cached reads
+        try:
+            self.get_user_data.cache_clear()
+        except Exception:
+            pass
+        # Remove from negative cache as token is deleted
+        try:
+            if token and token in self._missing_tokens:
+                del self._missing_tokens[token]
+        except Exception:
+            pass
 
     async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str, dict[str, Any]]]:
         try:
@@ -185,9 +257,8 @@ async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str,
                             payload["authKey"] = self.decrypt_token(payload["authKey"])
                         except Exception:
                             pass
-                        # Update L1 cache (token only)
+                        # Token payload ready for consumer
                         tok = k[len(self.KEY_PREFIX) :] if k.startswith(self.KEY_PREFIX) else k  # noqa
-                        self._payload_cache[tok] = payload
                         yield k, payload
                     buffer.clear()
 
@@ -213,7 +284,6 @@ async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str,
                     except Exception:
                         pass
                     tok = k[len(self.KEY_PREFIX) :] if k.startswith(self.KEY_PREFIX) else k  # noqa
-                    self._payload_cache[tok] = payload
                     yield k, payload
         except (redis.RedisError, OSError) as exc:
             logger.warning(f"Failed to scan credential tokens: {exc}")