22from typing import Annotated , Any
33
44import jwt
5- from fastapi import Depends , HTTPException , status
5+ from fastapi import Depends , HTTPException , status , Request
66from fastapi .security import OAuth2PasswordBearer
77from jwt .exceptions import InvalidTokenError
88from passlib .context import CryptContext
99from pydantic import ValidationError
10- from sqlmodel import Session
10+ from sqlmodel import Session , select
1111
1212from src .auth .schemas import TokenPayload
1313from src .core .config import settings
1414from src .core .db import get_db
1515from src .users .models import User
16+ from src .users .schemas import UserPublic
1617
1718ALGORITHM = "HS256"
1819
2324TokenDep = Annotated [str , Depends (reusable_oauth2 )]
2425
2526
26- def get_current_user (session : SessionDep , token : TokenDep ) -> User :
27+ def get_user_from_session (request : Request , session : SessionDep ) -> User :
28+ session_user = request .session .get ("user" )
29+ if not session_user :
30+ raise HTTPException (status_code = 401 , detail = "Not authenticated (no session)" )
31+
32+ user = session .exec (select (User ).where (User .email == session_user ["email" ])).first ()
33+ if not user or not user .is_active :
34+ raise HTTPException (status_code = 401 , detail = "Invalid session user" )
35+ return UserPublic .model_validate (user )
36+
37+
38+ def get_user_from_token (
39+ session : SessionDep ,
40+ token : Annotated [str , Depends (reusable_oauth2 )],
41+ ) -> User :
2742 try :
2843 payload = jwt .decode (token , settings .SECRET_KEY , algorithms = [ALGORITHM ])
2944 token_data = TokenPayload (** payload )
45+ user = session .get (User , token_data .sub )
46+ if not user or not user .is_active :
47+ raise HTTPException (status_code = 401 , detail = "Invalid user" )
48+ return user
3049 except (InvalidTokenError , ValidationError ):
31- raise HTTPException (
32- status_code = status .HTTP_403_FORBIDDEN ,
33- detail = "Could not validate credentials" ,
34- )
35- user = session .get (User , token_data .sub )
36- if not user :
37- raise HTTPException (status_code = 404 , detail = "User not found" )
38- if not user .is_active :
39- raise HTTPException (status_code = 400 , detail = "Inactive user" )
40- return user
50+ raise HTTPException (status_code = 403 , detail = "Invalid token" )
51+
52+
53+ def get_current_user (
54+ request : Request ,
55+ session : SessionDep ,
56+ token : Annotated [str | None , Depends (OAuth2PasswordBearer (tokenUrl = f"{ settings .API_V1_STR } /tokens" , auto_error = False ))] = None ,
57+ ) -> User :
58+ print ("in get current user" )
59+ # Prefer session (Auth0 flow)
60+ session_user = request .session .get ("user" )
61+ if session_user :
62+ print ("Session user found:" , session_user ["email" ])
63+ res = get_user_from_session (request , session )
64+ print ("User from session:" , res )
65+ return res
66+ # Fallback to token (JWT flow)
67+ if token :
68+ return get_user_from_token (session , token )
69+
70+ raise HTTPException (status_code = 401 , detail = "Not authenticated" )
4171
4272
4373CurrentUser = Annotated [User , Depends (get_current_user )]
@@ -53,8 +83,15 @@ def authenticate(*, session: Session, email: str, password: str) -> User | None:
5383 db_user = get_user_by_email (session = session , email = email )
5484 if not db_user :
5585 return None
86+
87+ # Auth0 users may not have a password
88+ if not db_user .hashed_password :
89+ # Return None for users without a password when using password authentication
90+ return None
91+
5692 if not verify_password (password , db_user .hashed_password ):
5793 return None
94+
5895 return db_user
5996
6097
@@ -67,3 +104,15 @@ def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
67104
68105def get_password_hash (password : str ) -> str :
69106 return pwd_context .hash (password )
107+
108+
109+ def get_or_create_user_by_email (session : Session , email : str , defaults : dict = {}) -> User :
110+ user = session .exec (select (User ).where (User .email == email )).first ()
111+ if user :
112+ return user
113+ user = User (email = email , ** defaults )
114+ session .add (user )
115+ session .commit ()
116+ session .refresh (user )
117+ return user
118+
0 commit comments