88from fastapi import Depends , HTTPException , Request , status
99from fastapi .security import OAuth2PasswordBearer
1010from passlib .context import CryptContext
11+ from itsdangerous import URLSafeTimedSerializer , BadSignature , SignatureExpired
1112
1213oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "/api/v1/login" )
1314
@@ -27,6 +28,10 @@ class SecurityService:
2728 def __init__ (self ) -> None :
2829 self .pwd_context = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" )
2930 self .settings = get_settings ()
31+ self .csrf_serializer = URLSafeTimedSerializer (
32+ secret_key = self .settings .SECRET_KEY ,
33+ salt = "csrf-token"
34+ )
3035
3136 def verify_password (self , plain_password : str , hashed_password : str ) -> bool :
3237 return self .pwd_context .verify (plain_password , hashed_password ) # type: ignore
@@ -72,5 +77,71 @@ async def get_current_user(
7277 raise credentials_exception
7378 return user
7479
80+ def generate_csrf_token (self , session_id : str ) -> str :
81+ """Generate a CSRF token for the given session"""
82+ data = {
83+ "session_id" : session_id ,
84+ "timestamp" : datetime .utcnow ().isoformat ()
85+ }
86+ return self .csrf_serializer .dumps (data )
87+
88+ def validate_csrf_token (self , token : str , session_id : str ) -> bool :
89+ """Validate a CSRF token"""
90+ try :
91+ data = self .csrf_serializer .loads (token , max_age = 3600 ) # 1 hour
92+ return data .get ("session_id" ) == session_id
93+ except (BadSignature , SignatureExpired ):
94+ return False
95+
96+ def get_session_id_from_request (self , request : Request ) -> str :
97+ """Get session ID from request (using access token as session identifier)"""
98+ token = request .cookies .get ("access_token" )
99+ if token :
100+ return token [:32 ] # Use first 32 chars as session ID
101+
102+ # Fallback to client fingerprint
103+ client_ip = request .client .host if request .client else "unknown"
104+ user_agent = request .headers .get ("user-agent" , "unknown" )
105+ return f"{ client_ip } :{ user_agent } " [:32 ]
106+
75107
76108security_service = SecurityService ()
109+
110+
111+ def validate_csrf_token (request : Request ) -> str :
112+ """FastAPI dependency to validate CSRF token"""
113+ # Skip CSRF validation for safe methods
114+ if request .method in ["GET" , "HEAD" , "OPTIONS" ]:
115+ return "skip"
116+
117+ # Skip CSRF validation for auth endpoints
118+ if request .url .path in ["/api/v1/login" , "/api/v1/register" , "/api/v1/logout" ]:
119+ return "skip"
120+
121+ # Skip CSRF validation for non-API endpoints
122+ if not request .url .path .startswith ("/api/" ):
123+ return "skip"
124+
125+ # Check if user is authenticated first (has access_token cookie)
126+ access_token = request .cookies .get ("access_token" )
127+ if not access_token :
128+ # If not authenticated, skip CSRF validation (auth will be handled by other dependencies)
129+ return "skip"
130+
131+ # Get CSRF token from request
132+ csrf_token = request .headers .get ("X-CSRF-Token" )
133+ if not csrf_token :
134+ raise HTTPException (
135+ status_code = status .HTTP_403_FORBIDDEN ,
136+ detail = "CSRF token missing"
137+ )
138+
139+ # Validate CSRF token
140+ session_id = security_service .get_session_id_from_request (request )
141+ if not security_service .validate_csrf_token (csrf_token , session_id ):
142+ raise HTTPException (
143+ status_code = status .HTTP_403_FORBIDDEN ,
144+ detail = "CSRF token invalid"
145+ )
146+
147+ return csrf_token
0 commit comments