diff --git a/README.md b/README.md index 43677a5..a84fb9e 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().d ```bash # With per-user API key (includes rate limit headers) curl -H "X-API-Key: pfk_your_api_key_here" \ - http://localhost:8000/users/{user_id}/sleep?days=7 + http://localhost:8000/api/v1/users/{user_id}/sleep?days=7 # Response headers include: # X-RateLimit-Limit: 1000 @@ -211,7 +211,7 @@ Store `api_key` and `polar_user_id` for this user. Use the API key for all data ```bash curl -H "X-API-Key: pfk_abc123..." \ - "https://your-polar-server.com/users/12345678/sleep?days=7" + "https://your-polar-server.com/api/v1/users/12345678/sleep?days=7" ``` ### Polar Admin Setup @@ -226,43 +226,43 @@ https://your-polar-server.com/oauth/callback ```bash # Get key info -GET /users/{user_id}/api-key/info +GET /api/v1/users/{user_id}/api-key/info X-API-Key: pfk_... # Regenerate key (invalidates old key) -POST /users/{user_id}/api-key/regenerate +POST /api/v1/users/{user_id}/api-key/regenerate X-API-Key: pfk_... # Revoke key -POST /users/{user_id}/api-key/revoke +POST /api/v1/users/{user_id}/api-key/revoke X-API-Key: pfk_... ``` ## API Endpoints ```bash -# Health check +# Health check (no auth required) curl http://localhost:8000/health # Get sleep data (last 7 days) curl -H "X-API-Key: pfk_..." \ - "http://localhost:8000/users/{user_id}/sleep?days=7" + "http://localhost:8000/api/v1/users/{user_id}/sleep?days=7" # Get activity data curl -H "X-API-Key: pfk_..." \ - "http://localhost:8000/users/{user_id}/activity?days=7" + "http://localhost:8000/api/v1/users/{user_id}/activity?days=7" # Get nightly recharge (HRV) curl -H "X-API-Key: pfk_..." \ - "http://localhost:8000/users/{user_id}/recharge?days=7" + "http://localhost:8000/api/v1/users/{user_id}/recharge?days=7" # Get exercises curl -H "X-API-Key: pfk_..." 
\ - "http://localhost:8000/users/{user_id}/exercises?days=30" + "http://localhost:8000/api/v1/users/{user_id}/exercises?days=30" # Export summary curl -H "X-API-Key: pfk_..." \ - "http://localhost:8000/users/{user_id}/export/summary?days=30" + "http://localhost:8000/api/v1/users/{user_id}/export/summary?days=30" ``` ## Development diff --git a/src/polar_flow_server/admin/routes.py b/src/polar_flow_server/admin/routes.py index 68cc1f5..fba2067 100644 --- a/src/polar_flow_server/admin/routes.py +++ b/src/polar_flow_server/admin/routes.py @@ -1,10 +1,13 @@ """Admin panel routes.""" +import asyncio import csv import io +import logging import os import re import secrets +from collections import OrderedDict from datetime import UTC, date, datetime, timedelta from typing import Any from urllib.parse import urlencode @@ -52,9 +55,150 @@ from polar_flow_server.services.scheduler import get_scheduler from polar_flow_server.services.sync import SyncService -# In-memory OAuth state storage (for self-hosted single-instance use) -# In production SaaS, use Redis or database with TTL -_oauth_states: dict[str, datetime] = {} +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Bounded TTL Cache for OAuth States (prevents memory exhaustion) +# ============================================================================= + + +class BoundedTTLCache: + """Simple bounded cache with TTL for OAuth states. + + Prevents memory exhaustion attacks by limiting max entries. + Automatically evicts expired entries on access. + Thread-safe via asyncio lock. 
+ """ + + def __init__(self, maxsize: int = 100, ttl_minutes: int = 10) -> None: + self._cache: OrderedDict[str, datetime] = OrderedDict() + self._maxsize = maxsize + self._ttl = timedelta(minutes=ttl_minutes) + self._lock = asyncio.Lock() + + async def set(self, key: str, expires_at: datetime | None = None) -> None: + """Add or update a key with expiry time.""" + async with self._lock: + self._cleanup_expired() + # If at max, evict oldest entry and log warning + if len(self._cache) >= self._maxsize: + logger.warning(f"OAuth state cache full ({self._maxsize}), evicting oldest entries") + while len(self._cache) >= self._maxsize: + self._cache.popitem(last=False) + self._cache[key] = expires_at or (datetime.now(UTC) + self._ttl) + + async def get(self, key: str) -> datetime | None: + """Get expiry time for a key, or None if not found/expired.""" + async with self._lock: + self._cleanup_expired() + return self._cache.get(key) + + async def pop(self, key: str) -> datetime | None: + """Remove and return expiry time for a key.""" + async with self._lock: + return self._cache.pop(key, None) + + async def contains(self, key: str) -> bool: + """Check if key exists (async version of __contains__).""" + async with self._lock: + self._cleanup_expired() + return key in self._cache + + def _cleanup_expired(self) -> None: + """Remove expired entries. Must be called with lock held.""" + now = datetime.now(UTC) + # Use dict comprehension for atomic update + self._cache = OrderedDict((k, exp) for k, exp in self._cache.items() if exp >= now) + + +# OAuth state storage with bounded size (prevents memory exhaustion) +_oauth_states = BoundedTTLCache(maxsize=100, ttl_minutes=10) + + +# ============================================================================= +# Login Rate Limiting (prevents brute force attacks) +# ============================================================================= + + +class LoginRateLimiter: + """Simple in-memory rate limiter for login attempts. 
+ + Tracks failed attempts by IP address and locks out after threshold. + Thread-safe via asyncio lock. + """ + + def __init__( + self, max_attempts: int = 5, lockout_minutes: int = 15, cleanup_interval: int = 100 + ) -> None: + self._attempts: dict[str, list[datetime]] = {} + self._lockouts: dict[str, datetime] = {} + self._max_attempts = max_attempts + self._lockout_duration = timedelta(minutes=lockout_minutes) + self._attempt_window = timedelta(minutes=15) + self._cleanup_counter = 0 + self._cleanup_interval = cleanup_interval + self._lock = asyncio.Lock() + + async def is_locked_out(self, ip: str) -> bool: + """Check if IP is currently locked out.""" + async with self._lock: + self._maybe_cleanup() + lockout_until = self._lockouts.get(ip) + if lockout_until and lockout_until > datetime.now(UTC): + return True + # Clear expired lockout + if lockout_until: + del self._lockouts[ip] + return False + + async def record_failure(self, ip: str) -> bool: + """Record a failed login attempt. Returns True if now locked out.""" + async with self._lock: + now = datetime.now(UTC) + self._maybe_cleanup() + + # Get recent attempts within window + attempts = self._attempts.get(ip, []) + cutoff = now - self._attempt_window + attempts = [t for t in attempts if t > cutoff] + attempts.append(now) + self._attempts[ip] = attempts + + # Check if should lock out + if len(attempts) >= self._max_attempts: + self._lockouts[ip] = now + self._lockout_duration + logger.warning(f"Login rate limit exceeded for IP {ip}, locked out") + return True + return False + + async def record_success(self, ip: str) -> None: + """Clear attempts on successful login.""" + async with self._lock: + self._attempts.pop(ip, None) + self._lockouts.pop(ip, None) + + def _maybe_cleanup(self) -> None: + """Periodically clean up old entries. 
Must be called with lock held.""" + self._cleanup_counter += 1 + if self._cleanup_counter < self._cleanup_interval: + return + self._cleanup_counter = 0 + + now = datetime.now(UTC) + cutoff = now - self._attempt_window + + # Atomic cleanup using dict comprehension + self._attempts = { + ip: [t for t in attempts if t > cutoff] + for ip, attempts in self._attempts.items() + if any(t > cutoff for t in attempts) + } + self._lockouts = {ip: exp for ip, exp in self._lockouts.items() if exp >= now} + + +# Global rate limiter instance +_login_rate_limiter = LoginRateLimiter(max_attempts=5, lockout_minutes=15) # Simple email validation pattern _EMAIL_PATTERN = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") @@ -347,11 +491,50 @@ async def login_form(request: Request[Any, Any, Any], session: AsyncSession) -> ) +def _get_client_ip(request: Request[Any, Any, Any]) -> str: + """Get client IP from request, checking proxy headers only from trusted sources. + + Only trusts X-Forwarded-For/X-Real-IP when request comes from localhost + (i.e., from a reverse proxy like nginx/Coolify running on the same host). + This prevents IP spoofing attacks where attackers set fake headers. 
+ """ + client = request.client + direct_ip = client.host if client else "unknown" + + # Only trust proxy headers if request comes from localhost (reverse proxy) + trusted_proxies = {"127.0.0.1", "::1", "localhost"} + if direct_ip in trusted_proxies: + # Check X-Forwarded-For header (set by proxies like nginx, Coolify) + forwarded_for = request.headers.get("x-forwarded-for") + if forwarded_for: + # Take the last IP: our proxy appended it; earlier entries are client-supplied + return forwarded_for.split(",")[-1].strip() + # Check X-Real-IP header + real_ip = request.headers.get("x-real-ip") + if real_ip: + return real_ip + + return direct_ip + + @post("/login", sync_to_thread=False) async def login_submit( request: Request[Any, Any, Any], session: AsyncSession ) -> Template | Redirect: """Process login form submission.""" + client_ip = _get_client_ip(request) + + # Check if IP is locked out due to too many failed attempts + if await _login_rate_limiter.is_locked_out(client_ip): + return Template( + template_name="admin/login.html", + context={ + "error": "Too many failed attempts. 
Please try again later.", + "email": "", + "csrf_token": _get_csrf_token(request), + }, + ) + form_data = await request.form() email = form_data.get("email", "").strip() password = form_data.get("password", "") @@ -368,6 +551,8 @@ async def login_submit( admin = await authenticate_admin(str(email), str(password), session) if not admin: + # Record failed attempt + await _login_rate_limiter.record_failure(client_ip) return Template( template_name="admin/login.html", context={ @@ -377,6 +562,8 @@ async def login_submit( }, ) + # Successful login - clear any failed attempts + await _login_rate_limiter.record_success(client_ip) login_admin(request, admin) return Redirect(path="/admin", status_code=HTTP_303_SEE_OTHER) @@ -777,15 +964,9 @@ async def oauth_authorize(request: Request[Any, Any, Any], session: AsyncSession # No OAuth credentials configured, redirect to setup return Redirect(path="/admin", status_code=HTTP_303_SEE_OTHER) - # Generate CSRF state token + # Generate CSRF state token (BoundedTTLCache handles cleanup and size limits) state = secrets.token_urlsafe(32) - _oauth_states[state] = datetime.now(UTC) + timedelta(minutes=10) - - # Clean up expired states - now = datetime.now(UTC) - expired = [s for s, exp in _oauth_states.items() if exp < now] - for s in expired: - del _oauth_states[s] + await _oauth_states.set(state) # Build authorization URL with state for CSRF protection base_url = _get_base_url(request) @@ -819,20 +1000,19 @@ async def oauth_callback( ) # Validate CSRF state token - if not state or state not in _oauth_states: + if not state or not await _oauth_states.contains(state): return Template( template_name="admin/partials/sync_error.html", context={"error": "Invalid OAuth state - possible CSRF attack. 
Please try again."}, ) - # Check state hasn't expired and remove it (one-time use) - if _oauth_states[state] < datetime.now(UTC): - del _oauth_states[state] + # Get and remove state (one-time use); None means it was already consumed or evicted + state_expires = await _oauth_states.pop(state) + if state_expires is None or state_expires < datetime.now(UTC): return Template( template_name="admin/partials/sync_error.html", context={"error": "OAuth state expired. Please try again."}, ) - del _oauth_states[state] # Get OAuth credentials from database stmt = select(AppSettings).where(AppSettings.id == 1) diff --git a/src/polar_flow_server/api/__init__.py b/src/polar_flow_server/api/__init__.py index fccc1ea..bdbca81 100644 --- a/src/polar_flow_server/api/__init__.py +++ b/src/polar_flow_server/api/__init__.py @@ -1,5 +1,7 @@ """API routes.""" +from litestar import Router + from polar_flow_server.api.baselines import baselines_router from polar_flow_server.api.data import data_router from polar_flow_server.api.health import health_router @@ -9,17 +11,24 @@ from polar_flow_server.api.sleep import sleep_router from polar_flow_server.api.sync import sync_router -# All API routers -api_routers = [ - health_router, +# Versioned API routers (user data endpoints) +# These get the /api/v1 prefix +_v1_routers = [ sleep_router, sync_router, data_router, baselines_router, # Analytics baselines patterns_router, # Pattern detection and anomalies insights_router, # Unified insights API - oauth_router, # OAuth flow and code exchange keys_router, # Key management (regenerate, revoke, status) ] +api_v1_router = Router(path="/api/v1", route_handlers=_v1_routers) + +# Export: health (root), oauth (root), v1 (prefixed) +# - health_router: /health - no auth needed, no version prefix +# - oauth_router: /oauth/* - external OAuth flow, no version prefix +# - api_v1_router: /api/v1/* - all user data endpoints +api_routers = [health_router, oauth_router, api_v1_router] + __all__ = ["api_routers"] diff --git 
a/src/polar_flow_server/api/keys.py b/src/polar_flow_server/api/keys.py index 009f34e..620461c 100644 --- a/src/polar_flow_server/api/keys.py +++ b/src/polar_flow_server/api/keys.py @@ -8,11 +8,13 @@ - SaaS OAuth start (for external clients like Laravel) """ +import asyncio import logging import secrets +from collections import OrderedDict from datetime import UTC, datetime, timedelta from typing import Any -from urllib.parse import urlencode +from urllib.parse import urlencode, urlparse from litestar import Router, get, post from litestar.connection import Request @@ -32,17 +34,125 @@ revoke_api_key, ) from polar_flow_server.core.auth import per_user_api_key_guard +from polar_flow_server.core.config import settings from polar_flow_server.models.api_key import APIKey from polar_flow_server.models.settings import AppSettings from polar_flow_server.models.user import User -# Store OAuth states with their callback URLs -# Format: {state: {"expires_at": datetime, "callback_url": str, "client_id": str | None}} -_saas_oauth_states: dict[str, dict[str, Any]] = {} - logger = logging.getLogger(__name__) +# ============================================================================= +# Bounded TTL Cache for OAuth States (prevents memory exhaustion) +# ============================================================================= + + +class BoundedOAuthStateCache: + """Bounded cache for SaaS OAuth states with TTL. + + Prevents memory exhaustion attacks by limiting max entries. + Stores callback_url and client_id along with expiry. + Thread-safe via asyncio lock. 
+ """ + + def __init__(self, maxsize: int = 100, ttl_minutes: int = 10) -> None: + self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict() + self._maxsize = maxsize + self._ttl = timedelta(minutes=ttl_minutes) + self._lock = asyncio.Lock() + + async def set(self, key: str, callback_url: str, client_id: str | None = None) -> None: + """Add a new OAuth state with its associated data.""" + async with self._lock: + self._cleanup_expired() + # If at max, evict oldest entry and log warning + if len(self._cache) >= self._maxsize: + logger.warning( + f"SaaS OAuth state cache full ({self._maxsize}), evicting oldest entries" + ) + while len(self._cache) >= self._maxsize: + self._cache.popitem(last=False) + self._cache[key] = { + "expires_at": datetime.now(UTC) + self._ttl, + "callback_url": callback_url, + "client_id": client_id, + } + + async def get(self, key: str) -> dict[str, Any] | None: + """Get state data, or None if not found/expired.""" + async with self._lock: + self._cleanup_expired() + return self._cache.get(key) + + async def pop(self, key: str) -> dict[str, Any] | None: + """Remove and return state data.""" + async with self._lock: + return self._cache.pop(key, None) + + async def contains(self, key: str) -> bool: + """Check if key exists (async version of __contains__).""" + async with self._lock: + self._cleanup_expired() + return key in self._cache + + def _cleanup_expired(self) -> None: + """Remove expired entries. 
Must be called with lock held.""" + now = datetime.now(UTC) + # Rebuild the OrderedDict without expired entries (insertion order is preserved) + self._cache = OrderedDict((k, v) for k, v in self._cache.items() if v["expires_at"] >= now) + + +# OAuth state storage with bounded size (prevents memory exhaustion) +_saas_oauth_states = BoundedOAuthStateCache(maxsize=100, ttl_minutes=10) + + +# ============================================================================= +# Callback URL Validation +# ============================================================================= + + +def _is_localhost(netloc: str) -> bool: + """Check if netloc's host is localhost (ignoring port and userinfo).""" + # urlparse handles userinfo, ports and "[::1]" brackets; split(":") did not + host = (urlparse(f"//{netloc}").hostname or "").lower() + return host in {"localhost", "127.0.0.1", "::1"} + + +def _validate_callback_url(callback_url: str) -> tuple[bool, str]: + """Validate that callback_url is well-formed and secure. + + Returns (is_valid, error_message). + """ + # Length check to prevent DoS via extremely long URLs + if len(callback_url) > 2048: + return False, "URL too long (max 2048 characters)" + + try: + parsed = urlparse(callback_url) + except Exception: + return False, "Invalid URL format" + + # Must have scheme and netloc + if not parsed.scheme or not parsed.netloc: + return False, "URL must include scheme and host (e.g., https://example.com/callback)" + + # Only allow http and https schemes + if parsed.scheme not in {"http", "https"}: + return False, "Only http and https schemes are allowed" + + # Production check: use deployment_mode or presence of base_url + is_production = settings.deployment_mode.value == "saas" or settings.base_url is not None + + if parsed.scheme == "http": + if is_production: + return False, "HTTPS required for callback URLs in production" + # In development, only allow http for actual localhost + if not _is_localhost(parsed.netloc): + return False, "HTTP only allowed for localhost in development" + + return True, "" + + # 
============================================================================== # Request/Response Models # ============================================================================== @@ -121,6 +231,15 @@ async def oauth_start_saas( callback_url: Where to redirect after OAuth (your app's callback endpoint) client_id: Optional client identifier for validation during exchange """ + # Validate callback URL (includes length check) + is_valid, error_msg = _validate_callback_url(callback_url) + if not is_valid: + raise NotAuthorizedException(f"Invalid callback_url: {error_msg}") + + # Validate client_id length to prevent DoS + if client_id and len(client_id) > 255: + raise NotAuthorizedException("client_id too long (max 255 characters)") + # Get OAuth credentials from database stmt = select(AppSettings).where(AppSettings.id == 1) result = await session.execute(stmt) @@ -129,19 +248,9 @@ async def oauth_start_saas( if not app_settings or not app_settings.polar_client_id: raise NotFoundException("OAuth credentials not configured on server") - # Generate CSRF state token + # Generate CSRF state token (BoundedOAuthStateCache handles cleanup and size limits) state = secrets.token_urlsafe(32) - _saas_oauth_states[state] = { - "expires_at": datetime.now(UTC) + timedelta(minutes=10), - "callback_url": callback_url, - "client_id": client_id, - } - - # Clean up expired states - now = datetime.now(UTC) - expired = [s for s, data in _saas_oauth_states.items() if data["expires_at"] < now] - for s in expired: - del _saas_oauth_states[s] + await _saas_oauth_states.set(state, callback_url, client_id) # Build authorization URL - extract host/scheme from request headers # Coolify/nginx sets x-forwarded-* headers @@ -183,31 +292,34 @@ async def oauth_callback_saas( from polar_flow_server.core.security import token_encryption - # Handle errors + # Handle errors - try to redirect to callback with error if we have state if error or not code: - # Redirect to callback with error if we have 
oauth_state - if oauth_state and oauth_state in _saas_oauth_states: - callback_url = _saas_oauth_states[oauth_state]["callback_url"] - del _saas_oauth_states[oauth_state] - error_params = urlencode({"error": error or "no_code", "status": "failed"}) - return Redirect(path=f"{callback_url}?{error_params}", status_code=HTTP_303_SEE_OTHER) + if oauth_state: + state_data = await _saas_oauth_states.pop(oauth_state) + if state_data: + callback_url = state_data["callback_url"] + error_params = urlencode({"error": error or "no_code", "status": "failed"}) + return Redirect( + path=f"{callback_url}?{error_params}", status_code=HTTP_303_SEE_OTHER + ) raise NotAuthorizedException(f"OAuth authorization failed: {error or 'No code received'}") - # Validate oauth_state - if not oauth_state or oauth_state not in _saas_oauth_states: + # Validate oauth_state exists + if not oauth_state or not await _saas_oauth_states.contains(oauth_state): raise NotAuthorizedException("Invalid OAuth state - possible CSRF attack") - state_data = _saas_oauth_states[oauth_state] + # Get and remove state data (one-time use) + state_data = await _saas_oauth_states.pop(oauth_state) + if not state_data: + raise NotAuthorizedException("OAuth state not found") # Check expiry if state_data["expires_at"] < datetime.now(UTC): - del _saas_oauth_states[oauth_state] raise NotAuthorizedException("OAuth state expired. 
Please try again.") - # Get callback info and remove oauth_state (one-time use) + # Get callback info callback_url = state_data["callback_url"] stored_client_id = state_data["client_id"] - del _saas_oauth_states[oauth_state] # Get OAuth credentials stmt = select(AppSettings).where(AppSettings.id == 1) diff --git a/src/polar_flow_server/app.py b/src/polar_flow_server/app.py index 204a0df..07da495 100644 --- a/src/polar_flow_server/app.py +++ b/src/polar_flow_server/app.py @@ -98,28 +98,36 @@ def create_app() -> Litestar: # In production with multiple instances, use Redis instead session_store = MemoryStore() - # Session middleware config + # Session middleware config with explicit security settings + is_debug = settings.log_level == "DEBUG" session_config = ServerSideSessionConfig( key=settings.get_session_secret(), store="session_store", max_age=86400, # 24 hours + secure=not is_debug, # HTTPS only in production + httponly=True, # Prevent JS access + samesite="lax", # CSRF protection ) # CSRF protection config + # Note: HTMX and our JS already send CSRF token in X-CSRF-Token header, + # so most admin routes can (and should) require CSRF validation. csrf_config = CSRFConfig( secret=settings.get_session_secret(), cookie_name="csrf_token", header_name="X-CSRF-Token", exclude=[ - "/admin/login", # Login form - entry point, no session yet - "/admin/setup", # Setup flow - entry point, no session yet - "/admin/oauth/callback", # OAuth callback from Polar (admin dashboard) - "/admin/settings", # Settings pages (reset-oauth, etc.) 
- "/admin/sync", # Sync trigger from dashboard - "/admin/logout", # Logout action - "/admin/api-keys/", # API key management (uses session auth) - "/oauth/", # OAuth endpoints for SaaS (callback, exchange, start) - "/users/", # API routes use API key auth, not sessions + # Entry points (no session yet) + "/admin/login", + "/admin/setup", + # External OAuth callbacks (redirects from Polar) + "/admin/oauth/callback", + "/oauth/", # SaaS OAuth flow (callback, exchange, start) + # Safe to exclude (just destroys session) + "/admin/logout", + # API routes use API key auth, not CSRF + "/api/v1/users/", + # Health check (no auth needed) "/health", ], ) @@ -148,7 +156,7 @@ def create_app() -> Litestar: middleware=[session_config.middleware, RateLimitHeadersMiddleware], csrf_config=csrf_config, stores={"session_store": session_store}, - debug=settings.log_level == "DEBUG", + debug=is_debug, ) diff --git a/src/polar_flow_server/templates/base.html b/src/polar_flow_server/templates/base.html index 42dce52..578df75 100644 --- a/src/polar_flow_server/templates/base.html +++ b/src/polar_flow_server/templates/base.html @@ -4,7 +4,11 @@ {% block title %}polar-flow-server{% endblock %} - + + + - - + + +