From 59746bc811915d69a3ec42b3b84feaf799f2d55b Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Sun, 8 Jun 2025 16:47:06 +0800 Subject: [PATCH] Simplify OAuth2 model and optimize auth service --- backend/app/admin/crud/crud_user.py | 11 +++ backend/app/admin/schema/user.py | 3 +- .../plugin/oauth2/crud/crud_user_social.py | 15 +++- backend/plugin/oauth2/model/user_social.py | 10 +-- backend/plugin/oauth2/plugin.toml | 2 +- backend/plugin/oauth2/schema/user_social.py | 6 +- .../plugin/oauth2/service/oauth2_service.py | 79 +++++++++++-------- 7 files changed, 74 insertions(+), 52 deletions(-) diff --git a/backend/app/admin/crud/crud_user.py b/backend/app/admin/crud/crud_user.py index 7bd5f812e..25f03074c 100644 --- a/backend/app/admin/crud/crud_user.py +++ b/backend/app/admin/crud/crud_user.py @@ -114,6 +114,17 @@ async def update(self, db: AsyncSession, input_user: User, obj: UpdateUserParam) input_user.roles = roles.scalars().all() return count + async def update_avatar(self, db: AsyncSession, user_id: int, avatar: str) -> int: + """ + 更新用户头像 + + :param db: 数据库会话 + :param user_id: 用户 ID + :param avatar: 头像地址 + :return: + """ + return await self.update_model(db, user_id, {'avatar': avatar}) + async def delete(self, db: AsyncSession, user_id: int) -> int: """ 删除用户 diff --git a/backend/app/admin/schema/user.py b/backend/app/admin/schema/user.py index 397fc860f..4f8f818a6 100644 --- a/backend/app/admin/schema/user.py +++ b/backend/app/admin/schema/user.py @@ -38,6 +38,7 @@ class AddOAuth2UserParam(AuthSchemaBase): nickname: str | None = Field(None, description='昵称') email: EmailStr = Field(description='邮箱') + avatar: HttpUrl | None = Field(None, description='头像地址') class ResetPasswordParam(SchemaBase): @@ -54,7 +55,7 @@ class UserInfoSchemaBase(SchemaBase): dept_id: int | None = Field(None, description='部门 ID') username: str = Field(description='用户名') nickname: str = Field(description='昵称') - avatar: HttpUrl | None = Field(None, description='头像') + avatar: HttpUrl | None = Field(None, description='头像地址') class UpdateUserParam(UserInfoSchemaBase): diff --git a/backend/plugin/oauth2/crud/crud_user_social.py b/backend/plugin/oauth2/crud/crud_user_social.py index 53852f37f..e098c1fac 100644 --- a/backend/plugin/oauth2/crud/crud_user_social.py +++ b/backend/plugin/oauth2/crud/crud_user_social.py @@ -10,9 +10,9 @@ class CRUDUserSocial(CRUDPlus[UserSocial]): """用户社交账号数据库操作类""" - async def get(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None: + async def check_binding(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None: """ - 获取用户社交账号绑定详情 + 检查系统用户社交账号绑定 :param db: 数据库会话 :param pk: 用户 ID @@ -21,6 +21,17 @@ async def get(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None """ return await self.select_model_by_column(db, user_id=pk, source=source) + async def get_by_sid(self, db: AsyncSession, sid: str, source: str) -> UserSocial | None: + """ + 通过 UUID 获取社交用户 + + :param db: 数据库会话 + :param sid: 第三方 UUID + :param source: 社交账号类型 + :return: + """ + return await self.select_model_by_column(db, sid=sid, source=source) + async def create(self, db: AsyncSession, obj: CreateUserSocialParam) -> None: """ 创建用户社交账号绑定 diff --git a/backend/plugin/oauth2/model/user_social.py b/backend/plugin/oauth2/model/user_social.py index 980608b40..1e18e10a9 100644 --- a/backend/plugin/oauth2/model/user_social.py +++ b/backend/plugin/oauth2/model/user_social.py @@ -19,15 +19,9 @@ class UserSocial(Base): __tablename__ = 'sys_user_social' id: Mapped[id_key] = mapped_column(init=False) + sid: Mapped[str] = mapped_column(String(20), comment='第三方用户 ID') source: Mapped[str] = mapped_column(String(20), comment='第三方用户来源') - open_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 open id') - sid: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 ID') - union_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 union id') - scope: Mapped[str | None] = mapped_column(String(120), default=None, comment='第三方用户授予的权限') - code: Mapped[str | None] = mapped_column(String(50), default=None, comment='用户的授权 code') # 用户社交信息一对多 - user_id: Mapped[int | None] = mapped_column( - ForeignKey('sys_user.id', ondelete='SET NULL'), default=None, comment='用户关联ID' - ) + user_id: Mapped[int] = mapped_column(ForeignKey('sys_user.id', ondelete='CASCADE'), comment='用户关联ID') user: Mapped[User | None] = relationship(init=False, backref='socials') diff --git a/backend/plugin/oauth2/plugin.toml b/backend/plugin/oauth2/plugin.toml index 734af64ee..6e5686137 100644 --- a/backend/plugin/oauth2/plugin.toml +++ b/backend/plugin/oauth2/plugin.toml @@ -1,6 +1,6 @@ [plugin] summary = 'OAuth 2.0' -version = '0.0.1' +version = '0.0.2' description = '通过 OAuth 2.0 的方式登录系统' author = 'wu-clan' diff --git a/backend/plugin/oauth2/schema/user_social.py b/backend/plugin/oauth2/schema/user_social.py index 104f735e5..75c24ee5f 100644 --- a/backend/plugin/oauth2/schema/user_social.py +++ b/backend/plugin/oauth2/schema/user_social.py @@ -9,12 +9,8 @@ class UserSocialSchemaBase(SchemaBase): """用户社交基础模型""" + sid: str = Field(description='第三方用户 ID') source: UserSocialType = Field(description='社交平台') - open_id: str | None = Field(None, description='开放平台 ID') - sid: str | None = Field(None, description='第三方用户 ID') - union_id: str | None = Field(None, description='开放平台唯一 ID') - scope: str | None = Field(None, description='授权范围') - code: str | None = Field(None, description='授权码') class CreateUserSocialParam(UserSocialSchemaBase): diff --git a/backend/plugin/oauth2/service/oauth2_service.py b/backend/plugin/oauth2/service/oauth2_service.py index 3489faa77..461dedcb6 100644 --- a/backend/plugin/oauth2/service/oauth2_service.py +++ b/backend/plugin/oauth2/service/oauth2_service.py @@ -7,10 +7,9 @@ from backend.app.admin.crud.crud_user import user_dao from backend.app.admin.schema.token import GetLoginToken -from backend.app.admin.schema.user import RegisterUserParam +from backend.app.admin.schema.user import AddOAuth2UserParam from backend.app.admin.service.login_log_service import login_log_service from backend.common.enums import LoginLogStatusType, UserSocialType -from backend.common.exception.errors import AuthorizationError from backend.common.security import jwt from backend.core.conf import settings from backend.database.db import async_db_session @@ -43,55 +42,65 @@ async def create_with_login( :return: """ async with async_db_session.begin() as db: - # 获取 OAuth2 平台用户信息 - social_id = user.get('id') - social_nickname = user.get('name') + sid = user.get('uuid') + username = user.get('username') + nickname = user.get('nickname') + email = user.get('email') + avatar = user.get('avatar_url') - social_username = user.get('username') if social == UserSocialType.github: - social_username = user.get('login') + sid = user.get('id') + username = user.get('login') + nickname = user.get('name') - social_email = user.get('email') if social == UserSocialType.linux_do: - social_email = f'{social_username}@linux.do' - if not social_email: - raise AuthorizationError(msg=f'授权失败,{social.value} 账户未绑定邮箱') + sid = user.get('id') + nickname = user.get('name') - # 创建系统用户 - sys_user = await user_dao.check_email(db, social_email) - if not sys_user: - sys_user = await user_dao.get_by_username(db, social_username) - if sys_user: - social_username = f'{social_username}#{text_captcha(5)}' - sys_user = await user_dao.get_by_nickname(db, social_nickname) - if sys_user: - social_username = f'{social_nickname}#{text_captcha(5)}' - new_sys_user = RegisterUserParam( - username=social_username, password=None, nickname=social_username, email=social_email - ) - await user_dao.create(db, new_sys_user, social=True) - await db.flush() - sys_user = await user_dao.check_email(db, social_email) - # 绑定社交用户 - sys_user_id = sys_user.id - user_social = await user_social_dao.get(db, sys_user_id, social.value) + sys_user = None + user_social = await user_social_dao.get_by_sid(db, str(sid), str(social.value)) if not user_social: - new_user_social = CreateUserSocialParam(source=social.value, sid=str(social_id), user_id=sys_user_id) - await user_social_dao.create(db, new_user_social) + if email: + sys_user = await user_dao.check_email(db, email) + + # 创建系统用户 + if not sys_user: + while await user_dao.get_by_username(db, username): + username = f'{username}_{text_captcha(5)}' + new_sys_user = AddOAuth2UserParam( + username=username, + password='123456', # 默认密码,可修改系统用户表进行默认密码检测并配合前端进行修改密码提示 + nickname=nickname, + email=email, + avatar=avatar, + ) + await user_dao.add_by_oauth2(db, new_sys_user) + await db.flush() + sys_user = await user_dao.get_by_username(db, username) + + # 绑定社交用户 + new_user_social = CreateUserSocialParam(sid=str(sid), source=social.value, user_id=sys_user.id) + await user_social_dao.create(db, new_user_social) + + if not sys_user: + sys_user = await user_dao.get(db, user_social.user_id) + if avatar: + await user_dao.update_avatar(db, sys_user.id, avatar) + # 创建 token access_token = await jwt.create_access_token( - str(sys_user_id), + str(sys_user.id), sys_user.is_multi_login, # extra info username=sys_user.username, - nickname=sys_user.nickname, + nickname=sys_user.nickname or f'#{text_captcha(5)}', last_login_time=timezone.t_str(timezone.now()), ip=request.state.ip, os=request.state.os, browser=request.state.browser, device=request.state.device, ) - refresh_token = await jwt.create_refresh_token(str(sys_user_id), multi_login=sys_user.is_multi_login) + refresh_token = await jwt.create_refresh_token(str(sys_user.id), multi_login=sys_user.is_multi_login) await user_dao.update_login_time(db, sys_user.username) await db.refresh(sys_user) login_log = dict( @@ -115,8 +124,8 @@ async def create_with_login( data = GetLoginToken( access_token=access_token.access_token, access_token_expire_time=access_token.access_token_expire_time, - user=sys_user, # type: ignore session_uuid=access_token.session_uuid, + user=sys_user, # type: ignore ) return data