77from app .schemas .user import UserInDB
88from fastapi import Depends , HTTPException , Request , status
99from fastapi .security import OAuth2PasswordBearer
10- from itsdangerous import BadSignature , SignatureExpired , URLSafeTimedSerializer
10+ from itsdangerous import URLSafeTimedSerializer
1111from passlib .context import CryptContext
1212
1313oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "/api/v1/login" )
@@ -28,10 +28,6 @@ class SecurityService:
2828 def __init__ (self ) -> None :
2929 self .pwd_context = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" )
3030 self .settings = get_settings ()
31- self .csrf_serializer = URLSafeTimedSerializer (
32- secret_key = self .settings .SECRET_KEY ,
33- salt = "csrf-token"
34- )
3531
3632 def verify_password (self , plain_password : str , hashed_password : str ) -> bool :
3733 return self .pwd_context .verify (plain_password , hashed_password ) # type: ignore
@@ -77,39 +73,25 @@ async def get_current_user(
7773 raise credentials_exception
7874 return user
7975
80- def generate_csrf_token (self , session_id : str ) -> str :
81- """Generate a CSRF token for the given session"""
82- data = {
83- "session_id" : session_id ,
84- "timestamp" : datetime .utcnow ().isoformat ()
85- }
86- return self .csrf_serializer .dumps (data )
76+ def generate_csrf_token (self ) -> str :
77+ """Generate a CSRF token using secure random"""
78+ import secrets
79+ return secrets .token_urlsafe (32 )
8780
88- def validate_csrf_token (self , token : str , session_id : str ) -> bool :
89- """Validate a CSRF token"""
90- try :
91- data = self .csrf_serializer .loads (token , max_age = 3600 ) # 1 hour
92- return bool (data .get ("session_id" ) == session_id )
93- except (BadSignature , SignatureExpired ):
81+ def validate_csrf_token (self , header_token : str , cookie_token : str ) -> bool :
82+ """Validate CSRF token using double-submit cookie pattern"""
83+ if not header_token or not cookie_token :
9484 return False
95-
96- def get_session_id_from_request (self , request : Request ) -> str :
97- """Get session ID from request (using access token as session identifier)"""
98- token = request .cookies .get ("access_token" )
99- if token :
100- return token [:32 ] # Use first 32 chars as session ID
101-
102- # Fallback to client fingerprint
103- client_ip = request .client .host if request .client else "unknown"
104- user_agent = request .headers .get ("user-agent" , "unknown" )
105- return f"{ client_ip } :{ user_agent } " [:32 ]
85+ # Constant-time comparison to prevent timing attacks
86+ import hmac
87+ return hmac .compare_digest (header_token , cookie_token )
10688
10789
10890security_service = SecurityService ()
10991
11092
11193def validate_csrf_token (request : Request ) -> str :
112- """FastAPI dependency to validate CSRF token"""
94+ """FastAPI dependency to validate CSRF token using double-submit cookie pattern """
11395 # Skip CSRF validation for safe methods
11496 if request .method in ["GET" , "HEAD" , "OPTIONS" ]:
11597 return "skip"
@@ -128,20 +110,21 @@ def validate_csrf_token(request: Request) -> str:
128110 # If not authenticated, skip CSRF validation (auth will be handled by other dependencies)
129111 return "skip"
130112
131- # Get CSRF token from request
132- csrf_token = request .headers .get ("X-CSRF-Token" )
133- if not csrf_token :
113+ # Get CSRF token from header and cookie
114+ header_token = request .headers .get ("X-CSRF-Token" )
115+ cookie_token = request .cookies .get ("csrf_token" )
116+
117+ if not header_token :
134118 raise HTTPException (
135119 status_code = status .HTTP_403_FORBIDDEN ,
136120 detail = "CSRF token missing"
137121 )
138122
139- # Validate CSRF token
140- session_id = security_service .get_session_id_from_request (request )
141- if not security_service .validate_csrf_token (csrf_token , session_id ):
123+ # Validate using double-submit cookie pattern
124+ if not security_service .validate_csrf_token (header_token , cookie_token ):
142125 raise HTTPException (
143126 status_code = status .HTTP_403_FORBIDDEN ,
144127 detail = "CSRF token invalid"
145128 )
146129
147- return csrf_token
130+ return header_token
0 commit comments