11"""FastAPI Users database adapter for SQLAlchemy."""
22import uuid
3- from typing import Any , Dict , Generic , Optional , Type , TypeVar
3+ from typing import Any , Dict , Generic , Optional , Type
44
55from fastapi_users .db .base import BaseUserDatabase
6- from fastapi_users .models import ID , OAP
6+ from fastapi_users .models import ID , OAP , UP
77from sqlalchemy import Boolean , Column , ForeignKey , Integer , String , func , select
88from sqlalchemy .ext .asyncio import AsyncSession
99from sqlalchemy .ext .declarative import declared_attr
@@ -29,9 +29,6 @@ class SQLAlchemyBaseUserTable(Generic[ID]):
2929 is_verified : bool = Column (Boolean , default = False , nullable = False )
3030
3131
32- UP_SQLALCHEMY = TypeVar ("UP_SQLALCHEMY" , bound = SQLAlchemyBaseUserTable )
33-
34-
3532class SQLAlchemyBaseUserTableUUID (SQLAlchemyBaseUserTable [UUID_ID ]):
3633 id : UUID_ID = Column (GUID , primary_key = True , default = uuid .uuid4 )
3734
@@ -58,9 +55,7 @@ def user_id(cls):
5855 return Column (GUID , ForeignKey ("user.id" , ondelete = "cascade" ), nullable = False )
5956
6057
61- class SQLAlchemyUserDatabase (
62- Generic [UP_SQLALCHEMY , ID ], BaseUserDatabase [UP_SQLALCHEMY , ID ]
63- ):
58+ class SQLAlchemyUserDatabase (Generic [UP , ID ], BaseUserDatabase [UP , ID ]):
6459 """
6560 Database adapter for SQLAlchemy.
6661
@@ -70,32 +65,30 @@ class SQLAlchemyUserDatabase(
7065 """
7166
7267 session : AsyncSession
73- user_table : Type [UP_SQLALCHEMY ]
68+ user_table : Type [UP ]
7469 oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]]
7570
7671 def __init__ (
7772 self ,
7873 session : AsyncSession ,
79- user_table : Type [UP_SQLALCHEMY ],
74+ user_table : Type [UP ],
8075 oauth_account_table : Optional [Type [SQLAlchemyBaseOAuthAccountTable ]] = None ,
8176 ):
8277 self .session = session
8378 self .user_table = user_table
8479 self .oauth_account_table = oauth_account_table
8580
86- async def get (self , id : ID ) -> Optional [UP_SQLALCHEMY ]:
81+ async def get (self , id : ID ) -> Optional [UP ]:
8782 statement = select (self .user_table ).where (self .user_table .id == id )
8883 return await self ._get_user (statement )
8984
90- async def get_by_email (self , email : str ) -> Optional [UP_SQLALCHEMY ]:
85+ async def get_by_email (self , email : str ) -> Optional [UP ]:
9186 statement = select (self .user_table ).where (
9287 func .lower (self .user_table .email ) == func .lower (email )
9388 )
9489 return await self ._get_user (statement )
9590
96- async def get_by_oauth_account (
97- self , oauth : str , account_id : str
98- ) -> Optional [UP_SQLALCHEMY ]:
91+ async def get_by_oauth_account (self , oauth : str , account_id : str ) -> Optional [UP ]:
9992 if self .oauth_account_table is None :
10093 raise NotImplementedError ()
10194
@@ -107,30 +100,26 @@ async def get_by_oauth_account(
107100 )
108101 return await self ._get_user (statement )
109102
110- async def create (self , create_dict : Dict [str , Any ]) -> UP_SQLALCHEMY :
103+ async def create (self , create_dict : Dict [str , Any ]) -> UP :
111104 user = self .user_table (** create_dict )
112105 self .session .add (user )
113106 await self .session .commit ()
114107 await self .session .refresh (user )
115108 return user
116109
117- async def update (
118- self , user : UP_SQLALCHEMY , update_dict : Dict [str , Any ]
119- ) -> UP_SQLALCHEMY :
110+ async def update (self , user : UP , update_dict : Dict [str , Any ]) -> UP :
120111 for key , value in update_dict .items ():
121112 setattr (user , key , value )
122113 self .session .add (user )
123114 await self .session .commit ()
124115 await self .session .refresh (user )
125116 return user
126117
127- async def delete (self , user : UP_SQLALCHEMY ) -> None :
118+ async def delete (self , user : UP ) -> None :
128119 await self .session .delete (user )
129120 await self .session .commit ()
130121
131- async def add_oauth_account (
132- self , user : UP_SQLALCHEMY , create_dict : Dict [str , Any ]
133- ) -> UP_SQLALCHEMY :
122+ async def add_oauth_account (self , user : UP , create_dict : Dict [str , Any ]) -> UP :
134123 if self .oauth_account_table is None :
135124 raise NotImplementedError ()
136125
@@ -145,8 +134,8 @@ async def add_oauth_account(
145134 return user
146135
147136 async def update_oauth_account (
148- self , user : UP_SQLALCHEMY , oauth_account : OAP , update_dict : Dict [str , Any ]
149- ) -> UP_SQLALCHEMY :
137+ self , user : UP , oauth_account : OAP , update_dict : Dict [str , Any ]
138+ ) -> UP :
150139 if self .oauth_account_table is None :
151140 raise NotImplementedError ()
152141
@@ -157,7 +146,7 @@ async def update_oauth_account(
157146 await self .session .refresh (user )
158147 return user
159148
160- async def _get_user (self , statement : Select ) -> Optional [UP_SQLALCHEMY ]:
149+ async def _get_user (self , statement : Select ) -> Optional [UP ]:
161150 results = await self .session .execute (statement )
162151 user = results .first ()
163152 if user is None :
0 commit comments