11"""Admin panel routes."""
22
3+ import asyncio
34import csv
45import io
6+ import logging
57import os
68import re
79import secrets
10+ from collections import OrderedDict
811from datetime import UTC , date , datetime , timedelta
912from typing import Any
1013from urllib .parse import urlencode
5255from polar_flow_server .services .scheduler import get_scheduler
5356from polar_flow_server .services .sync import SyncService
5457
55- # In-memory OAuth state storage (for self-hosted single-instance use)
56- # In production SaaS, use Redis or database with TTL
57- _oauth_states : dict [str , datetime ] = {}
58+ logger = logging .getLogger (__name__ )
59+
60+
61+ # =============================================================================
62+ # Bounded TTL Cache for OAuth States (prevents memory exhaustion)
63+ # =============================================================================
64+
65+
66+ class BoundedTTLCache :
67+ """Simple bounded cache with TTL for OAuth states.
68+
69+ Prevents memory exhaustion attacks by limiting max entries.
70+ Automatically evicts expired entries on access.
71+ Thread-safe via asyncio lock.
72+ """
73+
74+ def __init__ (self , maxsize : int = 100 , ttl_minutes : int = 10 ) -> None :
75+ self ._cache : OrderedDict [str , datetime ] = OrderedDict ()
76+ self ._maxsize = maxsize
77+ self ._ttl = timedelta (minutes = ttl_minutes )
78+ self ._lock = asyncio .Lock ()
79+
80+ async def set (self , key : str , expires_at : datetime | None = None ) -> None :
81+ """Add or update a key with expiry time."""
82+ async with self ._lock :
83+ self ._cleanup_expired ()
84+ # If at max, evict oldest entry and log warning
85+ if len (self ._cache ) >= self ._maxsize :
86+ logger .warning (f"OAuth state cache full ({ self ._maxsize } ), evicting oldest entries" )
87+ while len (self ._cache ) >= self ._maxsize :
88+ self ._cache .popitem (last = False )
89+ self ._cache [key ] = expires_at or (datetime .now (UTC ) + self ._ttl )
90+
91+ async def get (self , key : str ) -> datetime | None :
92+ """Get expiry time for a key, or None if not found/expired."""
93+ async with self ._lock :
94+ self ._cleanup_expired ()
95+ return self ._cache .get (key )
96+
97+ async def pop (self , key : str ) -> datetime | None :
98+ """Remove and return expiry time for a key."""
99+ async with self ._lock :
100+ return self ._cache .pop (key , None )
101+
102+ async def contains (self , key : str ) -> bool :
103+ """Check if key exists (async version of __contains__)."""
104+ async with self ._lock :
105+ self ._cleanup_expired ()
106+ return key in self ._cache
107+
108+ def _cleanup_expired (self ) -> None :
109+ """Remove expired entries. Must be called with lock held."""
110+ now = datetime .now (UTC )
111+ # Use dict comprehension for atomic update
112+ self ._cache = OrderedDict ((k , exp ) for k , exp in self ._cache .items () if exp >= now )
113+
114+
115+ # OAuth state storage with bounded size (prevents memory exhaustion)
116+ _oauth_states = BoundedTTLCache (maxsize = 100 , ttl_minutes = 10 )
117+
118+
119+ # =============================================================================
120+ # Login Rate Limiting (prevents brute force attacks)
121+ # =============================================================================
122+
123+
124+ class LoginRateLimiter :
125+ """Simple in-memory rate limiter for login attempts.
126+
127+ Tracks failed attempts by IP address and locks out after threshold.
128+ Thread-safe via asyncio lock.
129+ """
130+
131+ def __init__ (
132+ self , max_attempts : int = 5 , lockout_minutes : int = 15 , cleanup_interval : int = 100
133+ ) -> None :
134+ self ._attempts : dict [str , list [datetime ]] = {}
135+ self ._lockouts : dict [str , datetime ] = {}
136+ self ._max_attempts = max_attempts
137+ self ._lockout_duration = timedelta (minutes = lockout_minutes )
138+ self ._attempt_window = timedelta (minutes = 15 )
139+ self ._cleanup_counter = 0
140+ self ._cleanup_interval = cleanup_interval
141+ self ._lock = asyncio .Lock ()
142+
143+ async def is_locked_out (self , ip : str ) -> bool :
144+ """Check if IP is currently locked out."""
145+ async with self ._lock :
146+ self ._maybe_cleanup ()
147+ lockout_until = self ._lockouts .get (ip )
148+ if lockout_until and lockout_until > datetime .now (UTC ):
149+ return True
150+ # Clear expired lockout
151+ if lockout_until :
152+ del self ._lockouts [ip ]
153+ return False
154+
155+ async def record_failure (self , ip : str ) -> bool :
156+ """Record a failed login attempt. Returns True if now locked out."""
157+ async with self ._lock :
158+ now = datetime .now (UTC )
159+ self ._maybe_cleanup ()
160+
161+ # Get recent attempts within window
162+ attempts = self ._attempts .get (ip , [])
163+ cutoff = now - self ._attempt_window
164+ attempts = [t for t in attempts if t > cutoff ]
165+ attempts .append (now )
166+ self ._attempts [ip ] = attempts
167+
168+ # Check if should lock out
169+ if len (attempts ) >= self ._max_attempts :
170+ self ._lockouts [ip ] = now + self ._lockout_duration
171+ logger .warning (f"Login rate limit exceeded for IP { ip } , locked out" )
172+ return True
173+ return False
174+
175+ async def record_success (self , ip : str ) -> None :
176+ """Clear attempts on successful login."""
177+ async with self ._lock :
178+ self ._attempts .pop (ip , None )
179+ self ._lockouts .pop (ip , None )
180+
181+ def _maybe_cleanup (self ) -> None :
182+ """Periodically clean up old entries. Must be called with lock held."""
183+ self ._cleanup_counter += 1
184+ if self ._cleanup_counter < self ._cleanup_interval :
185+ return
186+ self ._cleanup_counter = 0
187+
188+ now = datetime .now (UTC )
189+ cutoff = now - self ._attempt_window
190+
191+ # Atomic cleanup using dict comprehension
192+ self ._attempts = {
193+ ip : [t for t in attempts if t > cutoff ]
194+ for ip , attempts in self ._attempts .items ()
195+ if any (t > cutoff for t in attempts )
196+ }
197+ self ._lockouts = {ip : exp for ip , exp in self ._lockouts .items () if exp >= now }
198+
199+
200+ # Global rate limiter instance
201+ _login_rate_limiter = LoginRateLimiter (max_attempts = 5 , lockout_minutes = 15 )
58202
59203# Simple email validation pattern
60204_EMAIL_PATTERN = re .compile (r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" )
@@ -347,11 +491,50 @@ async def login_form(request: Request[Any, Any, Any], session: AsyncSession) ->
347491 )
348492
349493
494+ def _get_client_ip (request : Request [Any , Any , Any ]) -> str :
495+ """Get client IP from request, checking proxy headers only from trusted sources.
496+
497+ Only trusts X-Forwarded-For/X-Real-IP when request comes from localhost
498+ (i.e., from a reverse proxy like nginx/Coolify running on the same host).
499+ This prevents IP spoofing attacks where attackers set fake headers.
500+ """
501+ client = request .client
502+ direct_ip = client .host if client else "unknown"
503+
504+ # Only trust proxy headers if request comes from localhost (reverse proxy)
505+ trusted_proxies = {"127.0.0.1" , "::1" , "localhost" }
506+ if direct_ip in trusted_proxies :
507+ # Check X-Forwarded-For header (set by proxies like nginx, Coolify)
508+ forwarded_for = request .headers .get ("x-forwarded-for" )
509+ if forwarded_for :
510+ # Take the first IP (original client)
511+ return forwarded_for .split ("," )[0 ].strip ()
512+ # Check X-Real-IP header
513+ real_ip = request .headers .get ("x-real-ip" )
514+ if real_ip :
515+ return real_ip
516+
517+ return direct_ip
518+
519+
350520@post ("/login" , sync_to_thread = False )
351521async def login_submit (
352522 request : Request [Any , Any , Any ], session : AsyncSession
353523) -> Template | Redirect :
354524 """Process login form submission."""
525+ client_ip = _get_client_ip (request )
526+
527+ # Check if IP is locked out due to too many failed attempts
528+ if await _login_rate_limiter .is_locked_out (client_ip ):
529+ return Template (
530+ template_name = "admin/login.html" ,
531+ context = {
532+ "error" : "Too many failed attempts. Please try again later." ,
533+ "email" : "" ,
534+ "csrf_token" : _get_csrf_token (request ),
535+ },
536+ )
537+
355538 form_data = await request .form ()
356539 email = form_data .get ("email" , "" ).strip ()
357540 password = form_data .get ("password" , "" )
@@ -368,6 +551,8 @@ async def login_submit(
368551
369552 admin = await authenticate_admin (str (email ), str (password ), session )
370553 if not admin :
554+ # Record failed attempt
555+ await _login_rate_limiter .record_failure (client_ip )
371556 return Template (
372557 template_name = "admin/login.html" ,
373558 context = {
@@ -377,6 +562,8 @@ async def login_submit(
377562 },
378563 )
379564
565+ # Successful login - clear any failed attempts
566+ await _login_rate_limiter .record_success (client_ip )
380567 login_admin (request , admin )
381568 return Redirect (path = "/admin" , status_code = HTTP_303_SEE_OTHER )
382569
@@ -777,15 +964,9 @@ async def oauth_authorize(request: Request[Any, Any, Any], session: AsyncSession
777964 # No OAuth credentials configured, redirect to setup
778965 return Redirect (path = "/admin" , status_code = HTTP_303_SEE_OTHER )
779966
780- # Generate CSRF state token
967+ # Generate CSRF state token (BoundedTTLCache handles cleanup and size limits)
781968 state = secrets .token_urlsafe (32 )
782- _oauth_states [state ] = datetime .now (UTC ) + timedelta (minutes = 10 )
783-
784- # Clean up expired states
785- now = datetime .now (UTC )
786- expired = [s for s , exp in _oauth_states .items () if exp < now ]
787- for s in expired :
788- del _oauth_states [s ]
969+ await _oauth_states .set (state )
789970
790971 # Build authorization URL with state for CSRF protection
791972 base_url = _get_base_url (request )
@@ -819,20 +1000,19 @@ async def oauth_callback(
8191000 )
8201001
8211002 # Validate CSRF state token
822- if not state or state not in _oauth_states :
1003+ if not state or not await _oauth_states . contains ( state ) :
8231004 return Template (
8241005 template_name = "admin/partials/sync_error.html" ,
8251006 context = {"error" : "Invalid OAuth state - possible CSRF attack. Please try again." },
8261007 )
8271008
828- # Check state hasn't expired and remove it (one-time use)
829- if _oauth_states [ state ] < datetime . now ( UTC ):
830- del _oauth_states [ state ]
1009+ # Get and remove state (one-time use) - also checks expiry
1010+ state_expires = await _oauth_states . pop ( state )
1011+ if state_expires and state_expires < datetime . now ( UTC ):
8311012 return Template (
8321013 template_name = "admin/partials/sync_error.html" ,
8331014 context = {"error" : "OAuth state expired. Please try again." },
8341015 )
835- del _oauth_states [state ]
8361016
8371017 # Get OAuth credentials from database
8381018 stmt = select (AppSettings ).where (AppSettings .id == 1 )
0 commit comments