Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions backend/app/admin/crud/crud_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ async def update(self, db: AsyncSession, input_user: User, obj: UpdateUserParam)
input_user.roles = roles.scalars().all()
return count

async def update_avatar(self, db: AsyncSession, user_id: int, avatar: str) -> int:
"""
更新用户头像

:param db: 数据库会话
:param user_id: 用户 ID
:param avatar: 头像地址
:return:
"""
return await self.update_model(db, user_id, {'avatar': avatar})

async def delete(self, db: AsyncSession, user_id: int) -> int:
"""
删除用户
Expand Down
3 changes: 2 additions & 1 deletion backend/app/admin/schema/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class AddOAuth2UserParam(AuthSchemaBase):

nickname: str | None = Field(None, description='昵称')
email: EmailStr = Field(description='邮箱')
avatar: HttpUrl | None = Field(None, description='头像地址')


class ResetPasswordParam(SchemaBase):
Expand All @@ -54,7 +55,7 @@ class UserInfoSchemaBase(SchemaBase):
dept_id: int | None = Field(None, description='部门 ID')
username: str = Field(description='用户名')
nickname: str = Field(description='昵称')
avatar: HttpUrl | None = Field(None, description='头像')
avatar: HttpUrl | None = Field(None, description='头像地址')


class UpdateUserParam(UserInfoSchemaBase):
Expand Down
15 changes: 13 additions & 2 deletions backend/plugin/oauth2/crud/crud_user_social.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
class CRUDUserSocial(CRUDPlus[UserSocial]):
"""用户社交账号数据库操作类"""

async def get(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None:
async def check_binding(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None:
"""
获取用户社交账号绑定详情
检查系统用户社交账号绑定

:param db: 数据库会话
:param pk: 用户 ID
Expand All @@ -21,6 +21,17 @@ async def get(self, db: AsyncSession, pk: int, source: str) -> UserSocial | None
"""
return await self.select_model_by_column(db, user_id=pk, source=source)

async def get_by_sid(self, db: AsyncSession, sid: str, source: str) -> UserSocial | None:
"""
通过 UUID 获取社交用户

:param db: 数据库会话
:param sid: 第三方 UUID
:param source: 社交账号类型
:return:
"""
return await self.select_model_by_column(db, sid=sid, source=source)

async def create(self, db: AsyncSession, obj: CreateUserSocialParam) -> None:
"""
创建用户社交账号绑定
Expand Down
10 changes: 2 additions & 8 deletions backend/plugin/oauth2/model/user_social.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,9 @@ class UserSocial(Base):
__tablename__ = 'sys_user_social'

id: Mapped[id_key] = mapped_column(init=False)
sid: Mapped[str] = mapped_column(String(20), comment='第三方用户 ID')
source: Mapped[str] = mapped_column(String(20), comment='第三方用户来源')
open_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 open id')
sid: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 ID')
union_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户 union id')
scope: Mapped[str | None] = mapped_column(String(120), default=None, comment='第三方用户授予的权限')
code: Mapped[str | None] = mapped_column(String(50), default=None, comment='用户的授权 code')

# 用户社交信息一对多
user_id: Mapped[int | None] = mapped_column(
ForeignKey('sys_user.id', ondelete='SET NULL'), default=None, comment='用户关联ID'
)
user_id: Mapped[int] = mapped_column(ForeignKey('sys_user.id', ondelete='CASCADE'), comment='用户关联ID')
user: Mapped[User | None] = relationship(init=False, backref='socials')
2 changes: 1 addition & 1 deletion backend/plugin/oauth2/plugin.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[plugin]
summary = 'OAuth 2.0'
version = '0.0.1'
version = '0.0.2'
description = '通过 OAuth 2.0 的方式登录系统'
author = 'wu-clan'

Expand Down
6 changes: 1 addition & 5 deletions backend/plugin/oauth2/schema/user_social.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@
class UserSocialSchemaBase(SchemaBase):
"""用户社交基础模型"""

sid: str = Field(description='第三方用户 ID')
source: UserSocialType = Field(description='社交平台')
open_id: str | None = Field(None, description='开放平台 ID')
sid: str | None = Field(None, description='第三方用户 ID')
union_id: str | None = Field(None, description='开放平台唯一 ID')
scope: str | None = Field(None, description='授权范围')
code: str | None = Field(None, description='授权码')


class CreateUserSocialParam(UserSocialSchemaBase):
Expand Down
79 changes: 44 additions & 35 deletions backend/plugin/oauth2/service/oauth2_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

from backend.app.admin.crud.crud_user import user_dao
from backend.app.admin.schema.token import GetLoginToken
from backend.app.admin.schema.user import RegisterUserParam
from backend.app.admin.schema.user import AddOAuth2UserParam
from backend.app.admin.service.login_log_service import login_log_service
from backend.common.enums import LoginLogStatusType, UserSocialType
from backend.common.exception.errors import AuthorizationError
from backend.common.security import jwt
from backend.core.conf import settings
from backend.database.db import async_db_session
Expand Down Expand Up @@ -43,55 +42,65 @@ async def create_with_login(
:return:
"""
async with async_db_session.begin() as db:
# 获取 OAuth2 平台用户信息
social_id = user.get('id')
social_nickname = user.get('name')
sid = user.get('uuid')
username = user.get('username')
nickname = user.get('nickname')
email = user.get('email')
avatar = user.get('avatar_url')

social_username = user.get('username')
if social == UserSocialType.github:
social_username = user.get('login')
sid = user.get('id')
username = user.get('login')
nickname = user.get('name')

social_email = user.get('email')
if social == UserSocialType.linux_do:
social_email = f'{social_username}@linux.do'
if not social_email:
raise AuthorizationError(msg=f'授权失败,{social.value} 账户未绑定邮箱')
sid = user.get('id')
nickname = user.get('name')

# 创建系统用户
sys_user = await user_dao.check_email(db, social_email)
if not sys_user:
sys_user = await user_dao.get_by_username(db, social_username)
if sys_user:
social_username = f'{social_username}#{text_captcha(5)}'
sys_user = await user_dao.get_by_nickname(db, social_nickname)
if sys_user:
social_username = f'{social_nickname}#{text_captcha(5)}'
new_sys_user = RegisterUserParam(
username=social_username, password=None, nickname=social_username, email=social_email
)
await user_dao.create(db, new_sys_user, social=True)
await db.flush()
sys_user = await user_dao.check_email(db, social_email)
# 绑定社交用户
sys_user_id = sys_user.id
user_social = await user_social_dao.get(db, sys_user_id, social.value)
sys_user = None
user_social = await user_social_dao.get_by_sid(db, str(sid), str(social.value))
if not user_social:
new_user_social = CreateUserSocialParam(source=social.value, sid=str(social_id), user_id=sys_user_id)
await user_social_dao.create(db, new_user_social)
if email:
sys_user = await user_dao.check_email(db, email)

# 创建系统用户
if not sys_user:
while await user_dao.get_by_username(db, username):
username = f'{username}_{text_captcha(5)}'
new_sys_user = AddOAuth2UserParam(
username=username,
password='123456', # 默认密码,可修改系统用户表进行默认密码检测并配合前端进行修改密码提示
nickname=nickname,
email=email,
avatar=avatar,
)
await user_dao.add_by_oauth2(db, new_sys_user)
await db.flush()
sys_user = await user_dao.get_by_username(db, username)

# 绑定社交用户
new_user_social = CreateUserSocialParam(sid=str(sid), source=social.value, user_id=sys_user.id)
await user_social_dao.create(db, new_user_social)

if not sys_user:
sys_user = await user_dao.get(db, user_social.user_id)
if avatar:
await user_dao.update_avatar(db, sys_user.id, avatar)

# 创建 token
access_token = await jwt.create_access_token(
str(sys_user_id),
str(sys_user.id),
sys_user.is_multi_login,
# extra info
username=sys_user.username,
nickname=sys_user.nickname,
nickname=sys_user.nickname or f'#{text_captcha(5)}',
last_login_time=timezone.t_str(timezone.now()),
ip=request.state.ip,
os=request.state.os,
browser=request.state.browser,
device=request.state.device,
)
refresh_token = await jwt.create_refresh_token(str(sys_user_id), multi_login=sys_user.is_multi_login)
refresh_token = await jwt.create_refresh_token(str(sys_user.id), multi_login=sys_user.is_multi_login)
await user_dao.update_login_time(db, sys_user.username)
await db.refresh(sys_user)
login_log = dict(
Expand All @@ -115,8 +124,8 @@ async def create_with_login(
data = GetLoginToken(
access_token=access_token.access_token,
access_token_expire_time=access_token.access_token_expire_time,
user=sys_user, # type: ignore
session_uuid=access_token.session_uuid,
user=sys_user, # type: ignore
)
return data

Expand Down