77from app .schemas .user import UserInDB
88from fastapi import Depends , HTTPException , Request , status
99from fastapi .security import OAuth2PasswordBearer
10+ from itsdangerous import BadSignature , SignatureExpired , URLSafeTimedSerializer
1011from passlib .context import CryptContext
11- from itsdangerous import URLSafeTimedSerializer , BadSignature , SignatureExpired
1212
1313oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "/api/v1/login" )
1414
@@ -89,7 +89,7 @@ def validate_csrf_token(self, token: str, session_id: str) -> bool:
8989 """Validate a CSRF token"""
9090 try :
9191 data = self .csrf_serializer .loads (token , max_age = 3600 ) # 1 hour
92- return data .get ("session_id" ) == session_id
92+ return bool ( data .get ("session_id" ) == session_id )
9393 except (BadSignature , SignatureExpired ):
9494 return False
9595
@@ -98,7 +98,7 @@ def get_session_id_from_request(self, request: Request) -> str:
9898 token = request .cookies .get ("access_token" )
9999 if token :
100100 return token [:32 ] # Use first 32 chars as session ID
101-
101+
102102 # Fallback to client fingerprint
103103 client_ip = request .client .host if request .client else "unknown"
104104 user_agent = request .headers .get ("user-agent" , "unknown" )
@@ -113,35 +113,35 @@ def validate_csrf_token(request: Request) -> str:
113113 # Skip CSRF validation for safe methods
114114 if request .method in ["GET" , "HEAD" , "OPTIONS" ]:
115115 return "skip"
116-
116+
117117 # Skip CSRF validation for auth endpoints
118118 if request .url .path in ["/api/v1/login" , "/api/v1/register" , "/api/v1/logout" ]:
119119 return "skip"
120-
120+
121121 # Skip CSRF validation for non-API endpoints
122122 if not request .url .path .startswith ("/api/" ):
123123 return "skip"
124-
124+
125125 # Check if user is authenticated first (has access_token cookie)
126126 access_token = request .cookies .get ("access_token" )
127127 if not access_token :
128128 # If not authenticated, skip CSRF validation (auth will be handled by other dependencies)
129129 return "skip"
130-
130+
131131 # Get CSRF token from request
132132 csrf_token = request .headers .get ("X-CSRF-Token" )
133133 if not csrf_token :
134134 raise HTTPException (
135135 status_code = status .HTTP_403_FORBIDDEN ,
136136 detail = "CSRF token missing"
137137 )
138-
138+
139139 # Validate CSRF token
140140 session_id = security_service .get_session_id_from_request (request )
141141 if not security_service .validate_csrf_token (csrf_token , session_id ):
142142 raise HTTPException (
143143 status_code = status .HTTP_403_FORBIDDEN ,
144144 detail = "CSRF token invalid"
145145 )
146-
146+
147147 return csrf_token
0 commit comments