1+ import contextlib
12from typing import Any , Callable , Dict , List , Type
23
34from fastapi import Depends , HTTPException
1415 PageSchema ,
1516)
1617from fastapi_amis_admin .amis .constants import DisplayModeEnum , LevelEnum
18+ from fastapi_amis_admin .crud .base import SchemaUpdateT
1719from fastapi_amis_admin .crud .schema import BaseApiOut
1820from fastapi_amis_admin .utils .translation import i18n as _
1921from pydantic import BaseModel
20- from sqlalchemy import insert , select , update
22+ from sqlalchemy import select
2123from starlette import status
2224from starlette .requests import Request
2325from starlette .responses import Response
@@ -46,14 +48,12 @@ class UserLoginFormAdmin(FormAdmin):
4648 page = Page (title = _ ("User Login" ))
4749 page_path = "/login"
4850 page_parser_mode = "html"
49- schema : Type [BaseModel ] = None
51+ schema : Type [SchemaUpdateT ] = None
5052 schema_submit_out : Type [UserLoginOut ] = None
5153 page_schema = None
5254 page_route_kwargs = {"name" : "login" }
5355
54- async def handle (
55- self , request : Request , data : BaseModel , ** kwargs # self.schema
56- ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
56+ async def handle (self , request : Request , data : SchemaUpdateT , ** kwargs ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
5757 if request .user :
5858 return BaseApiOut (code = 1 , msg = _ ("User logged in!" ), data = self .schema_submit_out .parse_obj (request .user ))
5959 user = await request .auth .authenticate_user (username = data .username , password = data .password ) # type:ignore
@@ -79,16 +79,14 @@ async def route(response: Response, result: BaseApiOut = Depends(super().route_s
7979 async def get_form (self , request : Request ) -> Form :
8080 form = await super ().get_form (request )
8181 buttons = []
82- try :
82+ with contextlib . suppress ( NoMatchFound ) :
8383 buttons .append (
8484 ActionType .Link (
8585 actionType = "link" ,
86- link = f"{ self .router_path } { self .router .url_path_for ('reg' )} " ,
86+ link = f"{ self .site . router_path } { self . app .router .url_path_for ('reg' )} " ,
8787 label = _ ("Sign up" ),
8888 )
8989 )
90- except NoMatchFound :
91- pass
9290 buttons .append (Action (actionType = "submit" , label = _ ("Sign in" ), level = LevelEnum .primary ))
9391 form .body .sort (key = lambda form_item : form_item .type , reverse = True )
9492 form .update_from_kwargs (
@@ -130,27 +128,25 @@ class UserRegFormAdmin(FormAdmin):
130128 page = Page (title = _ ("User Register" ))
131129 page_path = "/reg"
132130 page_parser_mode = "html"
133- schema : Type [BaseModel ] = None
131+ schema : Type [SchemaUpdateT ] = None
134132 schema_submit_out : Type [UserLoginOut ] = None
135133 page_schema = None
136134 page_route_kwargs = {"name" : "reg" }
137135
138- async def handle (
139- self , request : Request , data : BaseModel , ** kwargs # self.schema
140- ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
136+ async def handle (self , request : Request , data : SchemaUpdateT , ** kwargs ) -> BaseApiOut [BaseModel ]: # self.schema_submit_out
141137 auth : Auth = request .auth
142- user = await auth .db .scalar (select (self .user_model ).where (self .user_model .username == data .username ))
138+ user = await auth .db .async_scalar (select (self .user_model ).where (self .user_model .username == data .username ))
143139 if user :
144140 return BaseApiOut (status = - 1 , msg = _ ("Username has been registered!" ), data = None )
145- user = await auth .db .scalar (select (self .user_model ).where (self .user_model .email == data .email ))
141+ user = await auth .db .async_scalar (select (self .user_model ).where (self .user_model .email == data .email ))
146142 if user :
147143 return BaseApiOut (status = - 2 , msg = _ ("Email has been registered!" ), data = None )
148- user = self .user_model .parse_obj (data )
149- values = user .dict (exclude = {"id" , "password" })
150- values ["password" ] = auth .pwd_context .hash (user .password .get_secret_value ()) # 密码hash保存
151- stmt = insert (self .user_model ).values (values )
144+ values = data .dict (exclude = {"id" , "password" })
145+ values ["password" ] = auth .pwd_context .hash (data .password .get_secret_value ()) # 密码hash保存
146+ user = self .user_model .parse_obj (values )
152147 try :
153- user .id = await auth .db .async_execute (stmt , on_close_pre = lambda r : getattr (r , "lastrowid" , None ))
148+ auth .db .add (user )
149+ await auth .db .async_flush ()
154150 except Exception as e :
155151 raise HTTPException (
156152 status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
@@ -211,7 +207,7 @@ class UserInfoFormAdmin(FormAdmin):
211207 user_model : Type [BaseUser ] = User
212208 page = Page (title = _ ("User Profile" ))
213209 page_path = "/userinfo"
214- schema : Type [BaseModel ] = None
210+ schema : Type [SchemaUpdateT ] = None
215211 schema_submit_out : Type [BaseUser ] = None
216212 form_init = True
217213 form = Form (mode = DisplayModeEnum .horizontal )
@@ -230,10 +226,9 @@ async def get_form(self, request: Request) -> Form:
230226 form .body .extend (formitem .update_from_kwargs (disabled = True ) for formitem in formitems if formitem )
231227 return form
232228
233- async def handle (self , request : Request , data : BaseModel , ** kwargs ) -> BaseApiOut [Any ]:
234- stmt = update (self .user_model ).where (self .user_model .username == request .user .username ).values (data .dict ())
235- await self .site .db .async_execute (stmt )
236- await self .site .db .async_refresh (request .user )
229+ async def handle (self , request : Request , data : SchemaUpdateT , ** kwargs ) -> BaseApiOut [Any ]:
230+ for k , v in data .dict ().items ():
231+ setattr (request .user , k , v )
237232 return BaseApiOut (data = self .schema_submit_out .parse_obj (request .user ))
238233
239234 async def has_page_permission (self , request : Request ) -> bool :
0 commit comments