3939from .models import BaseUser , Role , User , UserRoleLink
4040from .schemas import UserLoginOut
4141
42- _UserModelT = TypeVar ("_UserModelT " , bound = BaseUser )
42+ UserModelT = TypeVar ("UserModelT " , bound = BaseUser )
4343
4444
45- class AuthBackend (AuthenticationBackend , Generic [_UserModelT ]):
45+ class AuthBackend (AuthenticationBackend , Generic [UserModelT ]):
4646 def __init__ (self , auth : "Auth" , token_store : BaseTokenStore ):
4747 self .auth = auth
4848 self .token_store = token_store
@@ -53,23 +53,23 @@ def get_user_token(request: Request) -> Optional[str]:
5353 scheme , token = get_authorization_scheme_param (authorization )
5454 return None if not authorization or scheme .lower () != "bearer" else token
5555
56- async def authenticate (self , request : Request ) -> Tuple ["Auth" , Optional [_UserModelT ]]:
56+ async def authenticate (self , request : Request ) -> Tuple ["Auth" , Optional [UserModelT ]]:
5757 return self .auth , await self .auth .get_current_user (request )
5858
5959 def attach_middleware (self , app : FastAPI ):
6060 app .add_middleware (AuthenticationMiddleware , backend = self ) # 添加auth中间件
6161
6262
63- class Auth (Generic [_UserModelT ]):
64- user_model : Type [_UserModelT ] = None
63+ class Auth (Generic [UserModelT ]):
64+ user_model : Type [UserModelT ] = None
6565 db : Union [AsyncDatabase , Database ] = None
66- backend : AuthBackend [_UserModelT ] = None
66+ backend : AuthBackend [UserModelT ] = None
6767
6868 def __init__ (
6969 self ,
7070 db : Union [AsyncDatabase , Database ],
7171 token_store : BaseTokenStore = None ,
72- user_model : Type [_UserModelT ] = User ,
72+ user_model : Type [UserModelT ] = User ,
7373 pwd_context : CryptContext = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" ),
7474 ):
7575 self .user_model = user_model or self .user_model
@@ -78,7 +78,7 @@ def __init__(
7878 self .backend = self .backend or AuthBackend (self , token_store or DbTokenStore (self .db ))
7979 self .pwd_context = pwd_context
8080
81- async def authenticate_user (self , username : str , password : Union [str , SecretStr ]) -> Optional [_UserModelT ]:
81+ async def authenticate_user (self , username : str , password : Union [str , SecretStr ]) -> Optional [UserModelT ]:
8282 user = await self .db .async_scalar (select (self .user_model ).where (self .user_model .username == username ))
8383 if user :
8484 pwd = password .get_secret_value () if isinstance (password , SecretStr ) else password
@@ -87,21 +87,17 @@ async def authenticate_user(self, username: str, password: Union[str, SecretStr]
8787 return user
8888 return None
8989
90- @cached_property
91- def get_current_user (self ):
92- async def _get_current_user (request : Request ) -> Optional [_UserModelT ]:
93- if request .scope .get ("auth" ): # 防止重复授权
94- return request .scope .get ("user" )
95- request .scope ["auth" ], request .scope ["user" ] = self , None
96- token = self .backend .get_user_token (request )
97- if not token :
98- return None
99- token_data = await self .backend .token_store .read_token (token )
100- if token_data is not None :
101- request .scope ["user" ]: _UserModelT = await self .db .async_get (self .user_model , token_data .id )
102- return request .user
103-
104- return _get_current_user
90+ async def get_current_user (self , request : Request ) -> Optional [UserModelT ]:
91+ if request .scope .get ("auth" ): # 防止重复授权
92+ return request .scope .get ("user" )
93+ request .scope ["auth" ], request .scope ["user" ] = self , None
94+ token = self .backend .get_user_token (request )
95+ if not token :
96+ return None
97+ token_data = await self .backend .token_store .read_token (token )
98+ if token_data is not None :
99+ request .scope ["user" ]: UserModelT = await self .db .async_get (self .user_model , token_data .id )
100+ return request .user
105101
106102 def requires (
107103 self ,
@@ -116,12 +112,12 @@ def requires(
116112 roles_ = (roles ,) if not roles or isinstance (roles , str ) else tuple (roles )
117113 permissions_ = (permissions ,) if not permissions or isinstance (permissions , str ) else tuple (permissions )
118114
119- async def has_requires (user : _UserModelT ) -> bool :
115+ async def has_requires (user : UserModelT ) -> bool :
120116 return user and await self .db .async_run_sync (user .has_requires , roles = roles , groups = groups , permissions = permissions )
121117
122118 async def depend (
123119 request : Request ,
124- user : _UserModelT = Depends (self .get_current_user ),
120+ user : UserModelT = Depends (self .get_current_user ),
125121 ) -> Union [bool , Response ]:
126122 user_auth = request .scope .get ("__user_auth__" , None )
127123 if user_auth is None :
0 commit comments