Commit 5382408

feat: Enhance Redis client management and implement rate limiting for token requests
1 parent 213531a commit 5382408

File tree

6 files changed, +172 -17 lines changed

  app/api/endpoints/health.py
  app/core/app.py
  app/core/config.py
  app/core/version.py
  app/services/catalog_updater.py
  app/services/token_store.py

app/api/endpoints/health.py

Lines changed: 28 additions & 0 deletions

@@ -1,8 +1,36 @@
 from fastapi import APIRouter
+from loguru import logger
+
+from app.services.token_store import token_store
 
 router = APIRouter(tags=["health"])
 
 
 @router.get("/health", summary="Simple readiness probe")
 async def health_check() -> dict[str, str]:
     return {"status": "ok"}
+
+
+@router.get("/metrics", summary="Runtime metrics (lightweight)")
+async def metrics() -> dict:
+    """Return lightweight runtime metrics useful for diagnosing Redis connection growth."""
+    try:
+        client = await token_store._get_client()
+    except Exception as exc:
+        logger.warning(f"Failed to fetch Redis client for metrics: {exc}")
+        return {"redis": "unavailable"}
+
+    metrics: dict = {}
+    try:
+        info = await client.info(section="clients")
+        metrics["redis_connected_clients"] = int(info.get("connected_clients", 0))
+    except Exception as exc:
+        logger.warning(f"Failed to read Redis INFO clients: {exc}")
+        metrics["redis_connected_clients"] = "error"
+
+    try:
+        metrics["per_request_redis_calls_last"] = token_store.get_call_count()
+    except Exception:
+        metrics["per_request_redis_calls_last"] = "error"
+
+    return metrics
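A quick way to exercise the new endpoint once the service is up -- a minimal sketch, assuming the addon listens on http://localhost:8000 and that httpx is available (neither is part of this commit):

# Minimal sketch: poll the new /metrics endpoint while load-testing to watch
# Redis connection growth. The base URL is an assumption, not a repo default.
import asyncio

import httpx


async def poll_metrics() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get("/metrics")
        resp.raise_for_status()
        # Values fall back to "error"/"unavailable" per the handler above
        print(resp.json())


asyncio.run(poll_metrics())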

app/core/app.py

Lines changed: 36 additions & 0 deletions

@@ -3,6 +3,7 @@
 from contextlib import asynccontextmanager
 from pathlib import Path
 
+from cachetools import TTLCache
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import HTMLResponse
@@ -61,6 +62,12 @@ def _on_done(t: asyncio.Task):
         await catalog_updater.stop()
         catalog_updater = None
         logger.info("Background catalog updates stopped")
+    # Close shared token store Redis client
+    try:
+        await token_store.close()
+        logger.info("TokenStore Redis client closed")
+    except Exception as exc:
+        logger.warning(f"Failed to close TokenStore Redis client: {exc}")
 
 
 if settings.APP_ENV != "development":
@@ -73,6 +80,8 @@ def _on_done(t: asyncio.Task):
     description="Stremio catalog addon for movie and series recommendations",
     version=__version__,
     lifespan=lifespan,
+    docs_url=None if settings.APP_ENV != "development" else "/docs",
+    redoc_url=None if settings.APP_ENV != "development" else "/redoc",
 )
 
 app.add_middleware(
@@ -84,6 +93,33 @@ def _on_done(t: asyncio.Task):
 )
 
 
+# Simple IP-based rate limiter for repeated probes of missing tokens.
+# Tracks recent failure counts per IP to avoid expensive repeated requests.
+_ip_failure_cache: TTLCache = TTLCache(maxsize=10000, ttl=600)
+_IP_FAILURE_THRESHOLD = 8
+
+
+@app.middleware("http")
+async def block_missing_token_middleware(request: Request, call_next):
+    # Extract first path segment which is commonly the token in addon routes
+    path = request.url.path.lstrip("/")
+    seg = path.split("/", 1)[0] if path else ""
+    try:
+        # If token is known-missing, short-circuit and track IP failures
+        if seg and seg in token_store._missing_tokens:
+            ip = request.client.host if request.client else "unknown"
+            try:
+                _ip_failure_cache[ip] = _ip_failure_cache.get(ip, 0) + 1
+            except Exception:
+                pass
+            if _ip_failure_cache.get(ip, 0) > _IP_FAILURE_THRESHOLD:
+                return HTMLResponse(content="Too many requests", status_code=429)
+            return HTMLResponse(content="Invalid token", status_code=401)
+    except Exception:
+        pass
+    return await call_next(request)
+
+
 # Middleware to track per-request Redis calls and attach as response header for diagnostics
 @app.middleware("http")
 async def redis_calls_middleware(request: Request, call_next):
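The escalation from 401 to 429 can be checked in isolation. A minimal, self-contained sketch that mirrors the middleware above, with a plain set standing in for token_store._missing_tokens and a made-up token name:

# Minimal sketch of the 401 -> 429 escalation; runs without Redis or the
# token store. "deadbeef" and the /catalog path are illustrative only.
from cachetools import TTLCache
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.testclient import TestClient

app = FastAPI()
missing_tokens = {"deadbeef"}  # stand-in for token_store._missing_tokens
failures: TTLCache = TTLCache(maxsize=10000, ttl=600)
THRESHOLD = 8


@app.middleware("http")
async def block_missing_token(request: Request, call_next):
    seg = request.url.path.lstrip("/").split("/", 1)[0]
    if seg in missing_tokens:
        ip = request.client.host if request.client else "unknown"
        failures[ip] = failures.get(ip, 0) + 1
        if failures[ip] > THRESHOLD:
            return HTMLResponse("Too many requests", status_code=429)
        return HTMLResponse("Invalid token", status_code=401)
    return await call_next(request)


client = TestClient(app)
codes = [client.get("/deadbeef/catalog").status_code for _ in range(10)]
print(codes)  # first eight probes -> 401, then 429 once the threshold is crossed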

app/core/config.py

Lines changed: 6 additions & 0 deletions

@@ -20,6 +20,12 @@ class Settings(BaseSettings):
     ADDON_ID: str = "com.bimal.watchly"
     ADDON_NAME: str = "Watchly"
     REDIS_URL: str = "redis://redis:6379/0"
+    # Maximum number of connections Redis client will open per process
+    # Set conservatively to avoid unbounded connection growth under high concurrency
+    REDIS_MAX_CONNECTIONS: int = 20
+    # If total connected clients reported by Redis exceeds this, background
+    # Redis-heavy jobs will back off. Tune according to your Redis capacity.
+    REDIS_CONNECTIONS_THRESHOLD: int = 100
     REDIS_TOKEN_KEY: str = "watchly:token:"
     TOKEN_SALT: str = "change-me"
     TOKEN_TTL_SECONDS: int = 0  # 0 = never expire
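Since Settings extends BaseSettings, the two new limits should be tunable per deployment through environment variables. A minimal sketch, with illustrative values only:

# Minimal sketch: override the new limits via the environment before the
# config module is imported. Values here are examples, not recommendations.
import os

os.environ["REDIS_MAX_CONNECTIONS"] = "50"         # per-process pool cap
os.environ["REDIS_CONNECTIONS_THRESHOLD"] = "200"  # back-off trigger for background jobs

from app.core.config import Settings  # module path as in this repo

settings = Settings()
print(settings.REDIS_MAX_CONNECTIONS, settings.REDIS_CONNECTIONS_THRESHOLD)  # 50 200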

app/core/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "1.1.0"
+__version__ = "1.1.1"

app/services/catalog_updater.py

Lines changed: 15 additions & 0 deletions

@@ -143,6 +143,21 @@ async def _update_safe(key: str, payload: dict[str, Any]) -> None:
             logger.error(f"Background refresh failed for {redact_token(key)}: {exc}", exc_info=True)
 
         try:
+            # Check Redis connected clients and back off if overloaded
+            try:
+                client = await token_store._get_client()
+                info = await client.info(section="clients")
+                connected = int(info.get("connected_clients", 0))
+                threshold = getattr(settings, "REDIS_CONNECTIONS_THRESHOLD", 1000)
+                if connected > threshold:
+                    logger.warning(
+                        f"Redis connected clients {connected} exceed threshold {threshold}; skipping"
+                        " background refresh."
+                    )
+                    return
+            except Exception as exc:
+                logger.warning(f"Failed to check Redis client info before refresh: {exc}")
+
             async for key, payload in token_store.iter_payloads():
                 # Extract token from redis key prefix
                 prefix = token_store.KEY_PREFIX
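Worth noting: INFO's clients section reports connected_clients for the whole Redis server, not just this process's pool, so the threshold guards overall server load. A standalone sketch of the same check, assuming a local Redis on the default port:

# Minimal sketch of the back-off check in isolation. The URL and threshold
# are assumptions; the real code reads settings.REDIS_CONNECTIONS_THRESHOLD.
import asyncio

import redis.asyncio as redis


async def should_back_off(url: str = "redis://localhost:6379/0", threshold: int = 100) -> bool:
    client = redis.from_url(url, decode_responses=True)
    try:
        info = await client.info(section="clients")
        # connected_clients counts all clients on the server, not just ours
        return int(info.get("connected_clients", 0)) > threshold
    finally:
        await client.close()


print(asyncio.run(should_back_off()))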

app/services/token_store.py

Lines changed: 86 additions & 16 deletions

@@ -5,6 +5,7 @@
 from typing import Any
 
 import redis.asyncio as redis
+from async_lru import alru_cache
 from cachetools import TTLCache
 from cryptography.fernet import Fernet, InvalidToken
 from cryptography.hazmat.primitives import hashes
@@ -22,9 +23,9 @@ class TokenStore:
 
     def __init__(self) -> None:
         self._client: redis.Redis | None = None
-        # Cache decrypted payloads for 1 day (86400s) to reduce Redis hits
-        # Max size 5000 allows many active users without eviction
-        self._payload_cache: TTLCache = TTLCache(maxsize=5000, ttl=86400)
+        # Negative cache for missing tokens to avoid repeated Redis GETs
+        # when external probes request non-existent tokens.
+        self._missing_tokens: TTLCache = TTLCache(maxsize=10000, ttl=3600)
         # per-request redis call counter (context-local)
         self._redis_calls_var: contextvars.ContextVar[int] = contextvars.ContextVar("watchly_redis_calls", default=0)
 
@@ -66,15 +67,59 @@ def decrypt_token(self, enc: str) -> str:
     async def _get_client(self) -> redis.Redis:
         if self._client is None:
             # Add socket timeouts to avoid hanging on Redis operations
+            import traceback
+
+            logger.info("Creating shared Redis client")
+            # Limit the number of pooled connections to avoid unbounded growth;
+            # `max_connections` is forwarded to ConnectionPool.from_url
             self._client = redis.from_url(
                 settings.REDIS_URL,
                 decode_responses=True,
                 encoding="utf-8",
                 socket_connect_timeout=5,
                 socket_timeout=5,
+                max_connections=getattr(settings, "REDIS_MAX_CONNECTIONS", 100),
+                health_check_interval=30,
+                socket_keepalive=True,
             )
+            # If _get_client is called multiple times in different contexts it
+            # could indicate multiple processes/threads or a bug opening
+            # additional clients; log a stack trace for debugging.
+            if getattr(self, "_creation_count", None) is None:
+                self._creation_count = 1
+            else:
+                self._creation_count += 1
+                logger.warning(
+                    f"Redis client creation invoked again (count={self._creation_count})."
+                    f" Stack:\n{''.join(traceback.format_stack())}"
+                )
         return self._client
 
+    async def close(self) -> None:
+        """Close and disconnect the shared Redis client (call on shutdown)."""
+        if self._client is None:
+            return
+        try:
+            logger.info("Closing shared Redis client")
+            # Close the client, then disconnect the underlying pool
+            try:
+                await self._client.close()
+            except Exception:
+                pass
+            try:
+                pool = getattr(self._client, "connection_pool", None)
+                if pool is not None:
+                    # connection_pool.disconnect may be a coroutine in some redis implementations
+                    disconnect = getattr(pool, "disconnect", None)
+                    if disconnect:
+                        res = disconnect()
+                        if hasattr(res, "__await__"):
+                            await res
            except Exception:
                pass
        finally:
            self._client = None
+
     def _format_key(self, token: str) -> str:
         """Format Redis key from token."""
         return f"{self.KEY_PREFIX}{token}"
@@ -109,30 +154,49 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str:
         self._incr_calls()
         await client.set(key, json_str)
 
-        # Update cache with the payload
-        self._payload_cache[token] = payload
+        # Invalidate async LRU cached reads so future reads use the updated payload
+        try:
+            self.get_user_data.cache_clear()
+        except Exception:
+            pass
+
+        # Ensure we remove from the negative cache so the new value is read next time
+        try:
+            if token in self._missing_tokens:
+                del self._missing_tokens[token]
+        except Exception:
+            pass
 
         return token
 
+    @alru_cache(maxsize=5000)
     async def get_user_data(self, token: str) -> dict[str, Any] | None:
-        if token in self._payload_cache:
-            logger.info(f"[REDIS] Using cached redis data {token}")
-            return self._payload_cache[token]
-        logger.info(f"[REDIS]Caching Failed. Fetching data from redis for {token}")
+        # Short-circuit for tokens known to be missing
+        try:
+            if token in self._missing_tokens:
+                logger.debug(f"[REDIS] Negative cache hit for missing token {token}")
+                return None
+        except Exception:
+            pass
 
+        logger.debug(f"[REDIS] Cache miss. Fetching data from redis for {token}")
         key = self._format_key(token)
         client = await self._get_client()
         self._incr_calls()
         data_raw = await client.get(key)
 
         if not data_raw:
+            # remember the negative result briefly
+            try:
+                self._missing_tokens[token] = True
+            except Exception:
+                pass
             return None
 
         try:
             data = json.loads(data_raw)
             if data.get("authKey"):
                 data["authKey"] = self.decrypt_token(data["authKey"])
-            self._payload_cache[token] = data
             return data
         except (json.JSONDecodeError, InvalidToken):
             return None
@@ -147,9 +211,17 @@ async def delete_token(self, token: str = None, key: str = None) -> None:
         self._incr_calls()
         await client.delete(key)
 
-        # Invalidate local cache
-        if token and token in self._payload_cache:
-            del self._payload_cache[token]
+        # Invalidate async LRU cached reads
+        try:
+            self.get_user_data.cache_clear()
+        except Exception:
+            pass
+        # Remove from the negative cache as the token is deleted
+        try:
+            if token and token in self._missing_tokens:
+                del self._missing_tokens[token]
+        except Exception:
+            pass
 
     async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str, dict[str, Any]]]:
         try:
@@ -185,9 +257,8 @@ async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str, dict[str, Any]]]:
                     payload["authKey"] = self.decrypt_token(payload["authKey"])
                 except Exception:
                     pass
-                # Update L1 cache (token only)
+                # Token payload ready for consumer
                 tok = k[len(self.KEY_PREFIX) :] if k.startswith(self.KEY_PREFIX) else k  # noqa
-                self._payload_cache[tok] = payload
                 yield k, payload
             buffer.clear()
@@ -213,7 +284,6 @@ async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str, dict[str, Any]]]:
             except Exception:
                 pass
             tok = k[len(self.KEY_PREFIX) :] if k.startswith(self.KEY_PREFIX) else k  # noqa
-            self._payload_cache[tok] = payload
             yield k, payload
         except (redis.RedisError, OSError) as exc:
             logger.warning(f"Failed to scan credential tokens: {exc}")
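One subtlety in the new design: @alru_cache memoizes negative results (None) as well as positive ones, so _missing_tokens mainly covers reads that reach the function body, e.g. after a cache_clear() or an LRU eviction. A self-contained sketch of the two-layer pattern, with a plain dict standing in for Redis:

# Minimal sketch of the caching pattern token_store now uses: an async LRU
# for reads plus a short-TTL negative cache for known-missing keys.
import asyncio

from async_lru import alru_cache
from cachetools import TTLCache

_backing: dict[str, str] = {"tok-1": "payload-1"}  # stands in for Redis
_missing: TTLCache = TTLCache(maxsize=10000, ttl=3600)


@alru_cache(maxsize=5000)
async def get(token: str) -> str | None:
    if token in _missing:       # negative-cache hit: skip the backend entirely
        return None
    value = _backing.get(token)
    if value is None:
        _missing[token] = True  # remember the miss briefly
    return value


async def main() -> None:
    print(await get("tok-1"))   # backend read, then memoized by alru_cache
    print(await get("nope"))    # backend miss, recorded in the negative cache
    _backing["nope"] = "late"   # writes must invalidate both layers...
    get.cache_clear()
    del _missing["nope"]
    print(await get("nope"))    # ...so the new value is visible


asyncio.run(main())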
