diff --git a/backend/app/admin/api/v1/sys/user.py b/backend/app/admin/api/v1/sys/user.py index a27ea8dd9..e131fe116 100644 --- a/backend/app/admin/api/v1/sys/user.py +++ b/backend/app/admin/api/v1/sys/user.py @@ -9,7 +9,6 @@ AddUserParam, GetCurrentUserInfoWithRelationDetail, GetUserInfoWithRelationDetail, - RegisterUserParam, ResetPasswordParam, UpdateUserParam, ) @@ -24,12 +23,6 @@ router = APIRouter() -@router.post('/register', summary='注册用户') -async def register_user(obj: RegisterUserParam) -> ResponseModel: - await user_service.register(obj=obj) - return response_base.success() - - @router.post('/add', summary='添加用户', dependencies=[DependsRBAC]) async def add_user(request: Request, obj: AddUserParam) -> ResponseSchemaModel[GetUserInfoWithRelationDetail]: await user_service.add(request=request, obj=obj) diff --git a/backend/app/admin/crud/crud_user.py b/backend/app/admin/crud/crud_user.py index fe0c709a4..7bd5f812e 100644 --- a/backend/app/admin/crud/crud_user.py +++ b/backend/app/admin/crud/crud_user.py @@ -10,8 +10,8 @@ from backend.app.admin.model import Dept, Role, User from backend.app.admin.schema.user import ( + AddOAuth2UserParam, AddUserParam, - RegisterUserParam, UpdateUserParam, ) from backend.common.security.jwt import get_hash_password @@ -61,26 +61,6 @@ async def update_login_time(self, db: AsyncSession, username: str) -> int: """ return await self.update_model_by_column(db, {'last_login_time': timezone.now()}, username=username) - async def create(self, db: AsyncSession, obj: RegisterUserParam, *, social: bool = False) -> None: - """ - 创建用户 - - :param db: 数据库会话 - :param obj: 注册用户参数 - :param social: 是否社交用户 - :return: - """ - if not social: - salt = bcrypt.gensalt() - obj.password = get_hash_password(obj.password, salt) - dict_obj = obj.model_dump() - dict_obj.update({'is_staff': True, 'salt': salt}) - else: - dict_obj = obj.model_dump() - dict_obj.update({'is_staff': True, 'salt': None}) - new_user = self.model(**dict_obj) - db.add(new_user) - async def add(self, db: AsyncSession, obj: AddUserParam) -> None: """ 添加用户 @@ -101,6 +81,21 @@ async def add(self, db: AsyncSession, obj: AddUserParam) -> None: db.add(new_user) + async def add_by_oauth2(self, db: AsyncSession, obj: AddOAuth2UserParam) -> None: + """ + 通过 OAuth2 添加用户 + + :param db: 数据库会话 + :param obj: 注册用户参数 + :return: + """ + salt = bcrypt.gensalt() + obj.password = get_hash_password(obj.password, salt) + dict_obj = obj.model_dump() + dict_obj.update({'is_staff': True, 'salt': salt}) + new_user = self.model(**dict_obj) + db.add(new_user) + async def update(self, db: AsyncSession, input_user: User, obj: UpdateUserParam) -> int: """ 更新用户信息 @@ -131,7 +126,7 @@ async def delete(self, db: AsyncSession, user_id: int) -> int: async def check_email(self, db: AsyncSession, email: str) -> User | None: """ - 检查邮箱是否已被注册 + 检查邮箱是否已被绑定 :param db: 数据库会话 :param email: 电子邮箱 diff --git a/backend/app/admin/model/user.py b/backend/app/admin/model/user.py index 871c1b6e8..7e03ade72 100644 --- a/backend/app/admin/model/user.py +++ b/backend/app/admin/model/user.py @@ -26,22 +26,22 @@ class User(Base): id: Mapped[id_key] = mapped_column(init=False) uuid: Mapped[str] = mapped_column(String(50), init=False, default_factory=uuid4_str, unique=True) username: Mapped[str] = mapped_column(String(20), unique=True, index=True, comment='用户名') - nickname: Mapped[str] = mapped_column(String(20), unique=True, comment='昵称') - password: Mapped[str | None] = mapped_column(String(255), comment='密码') - salt: Mapped[bytes | None] = mapped_column(VARBINARY(255).with_variant(BYTEA(255), 'postgresql'), comment='加密盐') - email: Mapped[str] = mapped_column(String(50), unique=True, index=True, comment='邮箱') + nickname: Mapped[str] = mapped_column(String(20), comment='昵称') + password: Mapped[str] = mapped_column(String(255), comment='密码') + salt: Mapped[bytes] = mapped_column(VARBINARY(255).with_variant(BYTEA(255), 'postgresql'), comment='加密盐') + email: Mapped[str | None] = mapped_column(String(50), default=None, unique=True, index=True, comment='邮箱') + phone: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机号') + avatar: Mapped[str | None] = mapped_column(String(255), default=None, comment='头像') + status: Mapped[int] = mapped_column(default=1, index=True, comment='用户账号状态(0停用 1正常)') is_superuser: Mapped[bool] = mapped_column( Boolean().with_variant(INTEGER, 'postgresql'), default=False, comment='超级权限(0否 1是)' ) is_staff: Mapped[bool] = mapped_column( Boolean().with_variant(INTEGER, 'postgresql'), default=False, comment='后台管理登陆(0否 1是)' ) - status: Mapped[int] = mapped_column(default=1, index=True, comment='用户账号状态(0停用 1正常)') is_multi_login: Mapped[bool] = mapped_column( Boolean().with_variant(INTEGER, 'postgresql'), default=False, comment='是否重复登陆(0否 1是)' ) - avatar: Mapped[str | None] = mapped_column(String(255), default=None, comment='头像') - phone: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机号') join_time: Mapped[datetime] = mapped_column( DateTime(timezone=True), init=False, default_factory=timezone.now, comment='注册时间' ) diff --git a/backend/app/admin/schema/user.py b/backend/app/admin/schema/user.py index 3dfb4a6bc..397fc860f 100644 --- a/backend/app/admin/schema/user.py +++ b/backend/app/admin/schema/user.py @@ -9,7 +9,7 @@ from backend.app.admin.schema.dept import GetDeptDetail from backend.app.admin.schema.role import GetRoleWithRelationDetail from backend.common.enums import StatusType -from backend.common.schema import CustomPhoneNumber, SchemaBase +from backend.common.schema import CustomEmailStr, CustomPhoneNumber, SchemaBase class AuthSchemaBase(SchemaBase): @@ -25,20 +25,19 @@ class AuthLoginParam(AuthSchemaBase): captcha: str = Field(description='验证码') -class RegisterUserParam(AuthSchemaBase): - """用户注册参数""" - - nickname: str | None = Field(None, description='昵称') - email: EmailStr = Field(examples=['user@example.com'], description='邮箱') - - class AddUserParam(AuthSchemaBase): """添加用户参数""" dept_id: int = Field(description='部门 ID') roles: list[int] = Field(description='角色 ID 列表') nickname: str | None = Field(None, description='昵称') - email: EmailStr = Field(examples=['user@example.com'], description='邮箱') + + +class AddOAuth2UserParam(AuthSchemaBase): + """添加 OAuth2 用户参数""" + + nickname: str | None = Field(None, description='昵称') + email: EmailStr = Field(description='邮箱') class ResetPasswordParam(SchemaBase): @@ -56,8 +55,6 @@ class UserInfoSchemaBase(SchemaBase): username: str = Field(description='用户名') nickname: str = Field(description='昵称') avatar: HttpUrl | None = Field(None, description='头像') - email: EmailStr = Field(examples=['user@example.com'], description='邮箱') - phone: CustomPhoneNumber | None = Field(None, description='手机号') class UpdateUserParam(UserInfoSchemaBase): @@ -74,7 +71,8 @@ class GetUserInfoDetail(UserInfoSchemaBase): dept_id: int | None = Field(None, description='部门 ID') id: int = Field(description='用户 ID') uuid: str = Field(description='用户 UUID') - avatar: str | None = Field(None, description='头像') + email: CustomEmailStr | None = Field(None, description='邮箱') + phone: CustomPhoneNumber | None = Field(None, description='手机号') status: StatusType = Field(StatusType.enable, description='状态') is_superuser: bool = Field(description='是否超级管理员') is_staff: bool = Field(description='是否管理员') diff --git a/backend/app/admin/service/user_service.py b/backend/app/admin/service/user_service.py index 7730f68ef..4c0f58113 100644 --- a/backend/app/admin/service/user_service.py +++ b/backend/app/admin/service/user_service.py @@ -13,7 +13,6 @@ from backend.app.admin.model import Role, User from backend.app.admin.schema.user import ( AddUserParam, - RegisterUserParam, ResetPasswordParam, UpdateUserParam, ) @@ -27,29 +26,6 @@ class UserService: """用户服务类""" - @staticmethod - async def register(*, obj: RegisterUserParam) -> None: - """ - 注册新用户 - - :param obj: 用户注册参数 - :return: - """ - async with async_db_session.begin() as db: - if not obj.password: - raise errors.ForbiddenError(msg='密码为空') - username = await user_dao.get_by_username(db, obj.username) - if username: - raise errors.ForbiddenError(msg='用户已注册') - obj.nickname = obj.nickname if obj.nickname else f'#{random.randrange(10000, 88888)}' - nickname = await user_dao.get_by_nickname(db, obj.nickname) - if nickname: - raise errors.ForbiddenError(msg='昵称已注册') - email = await user_dao.check_email(db, obj.email) - if email: - raise errors.ForbiddenError(msg='邮箱已注册') - await user_dao.create(db, obj) - @staticmethod async def add(*, request: Request, obj: AddUserParam) -> None: """ @@ -70,9 +46,6 @@ async def add(*, request: Request, obj: AddUserParam) -> None: raise errors.ForbiddenError(msg='昵称已注册') if not obj.password: raise errors.ForbiddenError(msg='密码为空') - email = await user_dao.check_email(db, obj.email) - if email: - raise errors.ForbiddenError(msg='邮箱已注册') dept = await dept_dao.get(db, obj.dept_id) if not dept: raise errors.NotFoundError(msg='部门不存在') @@ -175,10 +148,6 @@ async def update(*, request: Request, username: str, obj: UpdateUserParam) -> in nickname = await user_dao.get_by_nickname(db, obj.nickname) if nickname: raise errors.ForbiddenError(msg='昵称已注册') - if user.email != obj.email: - email = await user_dao.check_email(db, obj.email) - if email: - raise errors.ForbiddenError(msg='邮箱已注册') for role_id in obj.roles: role = await role_dao.get(db, role_id) if not role: