diff --git a/.gitignore b/.gitignore index d210b77b2..779ee46ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ .idea/ .vscode/ +.cursor/ +.DS_Store venv/ .venv/ .python-version diff --git a/README.md b/README.md index e7aa1c4cf..8cbcbefe8 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,7 @@ English | [简体中文](./README.zh-CN.md) -A backend and frontend separation solution based on the FastAPI framework, following -the [pseudo 3-tier architecture](#pseudo-3-tier-architecture) design, supporting **Python 3.10** and above -versions +Enterprise-level backend architecture solution **🔥Continuously updated and maintained🔥** @@ -49,37 +47,39 @@ pattern, use templates to transform it to your heart's content! ## Features -- [x] Design with FastAPI PEP 593 Annotated Parameters -- [x] Global asynchronous design with async/await + asgiref -- [x] Follows Restful API specification -- [x] Global SQLAlchemy 2.0 syntax -- [x] Pydantic v1 and v2 (different branches) -- [x] Casbin RBAC access control model -- [x] Role menu RBAC access control model -- [x] Celery asynchronous tasks -- [x] JWT middleware whitelist authentication -- [x] Global customizable time zone time -- [x] Docker / Docker-compose deployment -- [x] Pytest Unit Testing - -## Built-in features - -- [x] User management: System User Role Management, Permission Allocation -- [x] Department management: Configure system organization (company, department, team...) -- [x] Menu management: Configure system menu, user menu, button permission tags -- [x] Role management: role menu permission allocation, role route permission allocation -- [x] Dictionary management: Maintain commonly used fixed data or parameters within the system -- [x] Token management: System user online status detection, supports kicking users offline -- [x] Login authentication: backend-based graphical captcha background authentication login -- [x] Multipoint login: One-click modification of multipoint login through user information -- [x] OAuth 2.0: Built-in self-developed OAuth 2.0 login integration -- [x] Code generation: automatic backend code generation, supports preview, writing, and download -- [x] Scheduled task: Automated task, asynchronous task, supports function calls -- [x] Plugin system: Say goodbye to high coupling integration through hot-pluggable plugin mode -- [x] Operation log: Record and query of system normal and abnormal operations -- [x] Login log: Record and query of normal and abnormal user login -- [x] Service monitoring: Server hardware device information and status -- [x] API documentation: Automatically generate online interactive API documentation +- [x] Global FastAPI PEP 593 Annotated parameter style +- [x] Comprehensive async/await + asgiref asynchronous design +- [x] Adheres to RESTful API specifications +- [x] Uses SQLAlchemy 2.0 with new syntax +- [x] Uses Pydantic v2 version +- [x] Implements role-menu RBAC access control +- [x] Integrates Casbin RBAC access control +- [x] Supports Celery asynchronous tasks +- [x] Custom-developed JWT authentication middleware +- [x] Supports global custom time zones +- [x] Supports Docker / Docker-compose deployment +- [x] Integrates Pytest unit testing + +## Built-in Functions + +- [x] User Management: Assign roles and permissions +- [x] Department Management: Configure organizational structure (company, department, team, etc.) +- [x] Menu Management: Set up menus and button-level permissions +- [x] Role Management: Configure roles, assign menus and permissions +- [x] Dictionary Management: Maintain common parameters and configurations +- [x] Parameter Management: Dynamically configure commonly used system parameters +- [x] Notification Announcements: Publish and maintain system notification and announcement information +- [x] Token Management: Detect online status, support forced logout +- [x] Multi-device Login: Support one-click switching between multi-device login modes +- [x] OAuth 2.0: Built-in custom-developed OAuth 2.0 authorization login +- [x] Plugin System: Hot-swappable plugin design to reduce coupling +- [x] Scheduled Tasks: Support scheduled, asynchronous tasks, and function calls +- [x] Code Generation: Automatically generate code with preview, write, and download support +- [x] Operation Logs: Record and query normal and abnormal operations +- [x] Login Logs: Record and query normal and abnormal logins +- [x] Cache Monitoring: Query system cache information and command statistics +- [x] Service Monitoring: View server hardware information and status +- [x] API Documentation: Automatically generate online interactive API documentation ## Development and deployment @@ -103,7 +103,7 @@ the [official documentation](https://fastapi-practices.github.io/fastapi_best_ar ## Interactivity -[TG / Discord](https://wu-clan.github.io/homepage/) +[Discord](https://wu-clan.github.io/homepage/) ## Sponsor us diff --git a/README.zh-CN.md b/README.zh-CN.md index 209e62c7f..49ffe3b20 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -6,7 +6,7 @@ 简体中文 | [English](./README.md) -基于 FastAPI 框架的前后端分离中后台解决方案,遵循[伪三层架构](#伪三层架构)设计, 支持 **python3.10** 及以上版本 +企业级后端架构解决方案 **🔥持续更新维护中🔥** @@ -43,38 +43,40 @@ mvc 架构作为常规设计模式,在 python web 中也很常见,但是三 | 数据访问 | dao / mapper | crud | | 模型 | model / entity | model | -## 特征 +## 特性 - [x] 全局 FastAPI PEP 593 Annotated 参数风格 -- [x] async/await + asgiref 的全局异步设计 -- [x] 遵循 Restful API 规范 -- [x] 全局 SQLAlchemy 2.0 语法 -- [x] Pydantic v1 和 v2 (不同分支) -- [x] Casbin RBAC 访问控制模型 -- [x] 角色菜单 RBAC 访问控制模型 -- [x] Celery 异步任务 -- [x] JWT 中间件白名单认证 -- [x] 全局自定义时区时间 -- [x] Docker / Docker-compose 部署 -- [x] Pytest 单元测试 +- [x] 全面 async/await + asgiref 异步设计 +- [x] 遵循 RESTful API 规范 +- [x] 使用 SQLAlchemy 2.0 全新语法 +- [x] 使用 Pydantic v2 版本 +- [x] 实现角色菜单 RBAC 访问控制 +- [x] 集成 Casbin RBAC 访问控制 +- [x] 支持 Celery 异步任务 +- [x] 自研 JWT 认证中间件 +- [x] 支持全局自定义时间时区 +- [x] 支持 Docker / Docker-compose 部署 +- [x] 集成 Pytest 单元测试 ## 内置功能 -- [x] 用户管理:系统用户角色管理,权限分配 -- [x] 部门管理:配置系统组织机构(公司、部门、小组...) -- [x] 菜单管理:配置系统菜单,用户菜单,按钮权限标识 -- [x] 角色管理:角色菜单权限分配,角色路由权限分配 -- [x] 字典管理:维护系统内部常用固定数据或参数 -- [x] 令牌管理:系统用户在线状态检测,支持踢人下线 -- [x] 登录认证:基于后端的图形验证码后台认证登录 -- [x] 多点登录:通过用户信息一键修改多点登录支持 -- [x] OAuth20:内置自研 OAuth 2.0 登录集成 -- [x] 代码生成:后端代码自动生成,支持预览,写入及下载 -- [x] 定时任务:自动化任务,异步任务,支持函数调用 -- [x] 插件系统:通过热插拔插件模式告别高耦合集成 -- [x] 操作日志:系统正常和异常操作的日志记录与查询 -- [x] 登录日志:用户正常和异常登录的日志记录与查询 -- [x] 服务监控:服务器硬件设备信息与状态 +- [x] 用户管理:分配角色和权限 +- [x] 部门管理:配置组织架构(公司、部门、小组等) +- [x] 菜单管理:设置菜单及按钮级权限 +- [x] 角色管理:配置角色、分配菜单和权限 +- [x] 字典管理:维护常用参数和配置 +- [x] 参数管理:系统常用参数动态配置 +- [x] 通知公告:发布和维护系统通知公告信息 +- [x] 令牌管理:检测在线状态,支持强制下线 +- [x] 多端登录:支持一键切换多端登录模式 +- [x] OAuth 2.0:内置自研 OAuth 2.0 授权登录 +- [x] 插件系统:热插拔插件设计,降低耦合 +- [x] 定时任务:支持定时,异步任务及函数调用 +- [x] 代码生成:自动生成代码,支持预览、写入和下载 +- [x] 操作日志:记录和查询正常和异常操作 +- [x] 登录日志:记录和查询正常和异常登录 +- [x] 缓存监控:查询系统缓存信息和命令统计 +- [x] 服务监控:查看服务器硬件信息和状态 - [x] 接口文档:自动生成在线交互式 API 文档 ## 开发部署 @@ -98,7 +100,7 @@ mvc 架构作为常规设计模式,在 python web 中也很常见,但是三 ## 互动 -[TG / Discord](https://wu-clan.github.io/homepage/) +[Discord](https://wu-clan.github.io/homepage/) ## 赞助我们 diff --git a/backend/app/admin/api/v1/auth/auth.py b/backend/app/admin/api/v1/auth/auth.py index 0cfdba7a7..c2c588aea 100644 --- a/backend/app/admin/api/v1/auth/auth.py +++ b/backend/app/admin/api/v1/auth/auth.py @@ -36,7 +36,7 @@ async def user_login( @router.post('/token/new', summary='创建新 token') -async def create_new_token(request: Request, response: Response) -> ResponseSchemaModel[GetNewToken]: +async def create_new_token(request: Request) -> ResponseSchemaModel[GetNewToken]: data = await auth_service.new_token(request=request) return response_base.success(data=data) diff --git a/backend/app/admin/api/v1/log/login_log.py b/backend/app/admin/api/v1/log/login_log.py index 13e246b7a..c2c86b641 100644 --- a/backend/app/admin/api/v1/log/login_log.py +++ b/backend/app/admin/api/v1/log/login_log.py @@ -18,7 +18,7 @@ @router.get( '', - summary='(模糊条件)分页获取登录日志', + summary='分页获取登录日志', dependencies=[ DependsJwtAuth, DependsPagination, @@ -26,9 +26,9 @@ ) async def get_pagination_login_logs( db: CurrentSession, - username: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, - ip: Annotated[str | None, Query()] = None, + username: Annotated[str | None, Query(description='用户名')] = None, + status: Annotated[int | None, Query(description='状态')] = None, + ip: Annotated[str | None, Query(description='IP 地址')] = None, ) -> ResponseSchemaModel[PageData[GetLoginLogDetail]]: log_select = await login_log_service.get_select(username=username, status=status, ip=ip) page_data = await paging_data(db, log_select) @@ -37,13 +37,13 @@ async def get_pagination_login_logs( @router.delete( '', - summary='(批量)删除登录日志', + summary='批量删除登录日志', dependencies=[ Depends(RequestPermission('log:login:del')), DependsRBAC, ], ) -async def delete_login_log(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_login_log(pk: Annotated[list[int], Query(description='登录日志 ID 列表')]) -> ResponseModel: count = await login_log_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/log/opera_log.py b/backend/app/admin/api/v1/log/opera_log.py index 4b7b8c41e..e1d7c6ac1 100644 --- a/backend/app/admin/api/v1/log/opera_log.py +++ b/backend/app/admin/api/v1/log/opera_log.py @@ -18,7 +18,7 @@ @router.get( '', - summary='(模糊条件)分页获取操作日志', + summary='分页获取操作日志', dependencies=[ DependsJwtAuth, DependsPagination, @@ -26,9 +26,9 @@ ) async def get_pagination_opera_logs( db: CurrentSession, - username: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, - ip: Annotated[str | None, Query()] = None, + username: Annotated[str | None, Query(description='用户名')] = None, + status: Annotated[int | None, Query(description='状态')] = None, + ip: Annotated[str | None, Query(description='IP 地址')] = None, ) -> ResponseSchemaModel[PageData[GetOperaLogDetail]]: log_select = await opera_log_service.get_select(username=username, status=status, ip=ip) page_data = await paging_data(db, log_select) @@ -37,13 +37,13 @@ async def get_pagination_opera_logs( @router.delete( '', - summary='(批量)删除操作日志', + summary='批量删除操作日志', dependencies=[ Depends(RequestPermission('log:opera:del')), DependsRBAC, ], ) -async def delete_opera_log(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_opera_log(pk: Annotated[list[int], Query(description='操作日志 ID 列表')]) -> ResponseModel: count = await opera_log_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/monitor/redis.py b/backend/app/admin/api/v1/monitor/redis.py index 23a885882..8f93641ae 100644 --- a/backend/app/admin/api/v1/monitor/redis.py +++ b/backend/app/admin/api/v1/monitor/redis.py @@ -19,5 +19,8 @@ ], ) async def get_redis_info() -> ResponseModel: - data = {'info': await redis_info.get_info(), 'stats': await redis_info.get_stats()} + data = { + 'info': await redis_info.get_info(), + 'stats': await redis_info.get_stats(), + } return response_base.success(data=data) diff --git a/backend/app/admin/api/v1/oauth2/linux_do.py b/backend/app/admin/api/v1/oauth2/linux_do.py index 41257fc65..1b1ee7010 100644 --- a/backend/app/admin/api/v1/oauth2/linux_do.py +++ b/backend/app/admin/api/v1/oauth2/linux_do.py @@ -19,7 +19,7 @@ _linux_do_oauth2 = FastAPIOAuth20(_linux_do_client, admin_settings.OAUTH2_LINUX_DO_REDIRECT_URI) -@router.get('', summary='获取 Linux Do 授权链接') +@router.get('', summary='获取 LinuxDo 授权链接') async def linux_do_auth2() -> ResponseSchemaModel[str]: auth_url = await _linux_do_client.get_authorization_url(redirect_uri=admin_settings.OAUTH2_LINUX_DO_REDIRECT_URI) return response_base.success(data=auth_url) @@ -27,8 +27,8 @@ async def linux_do_auth2() -> ResponseSchemaModel[str]: @router.get( '/callback', - summary='Linux Do 授权自动重定向', - description='Linux Do 授权后,自动重定向到当前地址并获取用户信息,通过用户信息自动创建系统用户', + summary='LinuxDo 授权自动重定向', + description='LinuxDo 授权后,自动重定向到当前地址并获取用户信息,通过用户信息自动创建系统用户', dependencies=[Depends(RateLimiter(times=5, minutes=1))], ) async def linux_do_login( @@ -45,6 +45,6 @@ async def linux_do_login( response=response, background_tasks=background_tasks, user=user, - social=UserSocialType.linuxdo, + social=UserSocialType.linux_do, ) return RedirectResponse(url=f'{admin_settings.OAUTH2_FRONTEND_REDIRECT_URI}?access_token={data.access_token}') diff --git a/backend/app/admin/api/v1/sys/__init__.py b/backend/app/admin/api/v1/sys/__init__.py index 234479dd9..0263bcec2 100644 --- a/backend/app/admin/api/v1/sys/__init__.py +++ b/backend/app/admin/api/v1/sys/__init__.py @@ -16,7 +16,7 @@ router = APIRouter(prefix='/sys') -router.include_router(config_router, prefix='/configs', tags=['系统配置']) +router.include_router(config_router, prefix='/configs', tags=['系统参数配置']) router.include_router(dept_router, prefix='/depts', tags=['系统部门']) router.include_router(dict_data_router, prefix='/dict-datas', tags=['系统字典数据']) router.include_router(dict_type_router, prefix='/dict-types', tags=['系统字典类型']) diff --git a/backend/app/admin/api/v1/sys/config.py b/backend/app/admin/api/v1/sys/config.py index 7a65447b1..a3b6b0543 100644 --- a/backend/app/admin/api/v1/sys/config.py +++ b/backend/app/admin/api/v1/sys/config.py @@ -21,7 +21,7 @@ router = APIRouter() -@router.get('/website', summary='获取网站配置信息', dependencies=[DependsJwtAuth]) +@router.get('/website', summary='获取网站参数配置', dependencies=[DependsJwtAuth]) async def get_website_config() -> ResponseSchemaModel[list[GetConfigDetail]]: config = await config_service.get_built_in_config('website') return response_base.success(data=config) @@ -29,7 +29,7 @@ async def get_website_config() -> ResponseSchemaModel[list[GetConfigDetail]]: @router.post( '/website', - summary='保存网站配置信息', + summary='保存网站参数配置', dependencies=[ Depends(RequestPermission('sys:config:website:add')), DependsRBAC, @@ -78,23 +78,23 @@ async def save_policy_config(objs: list[SaveBuiltInConfigParam]) -> ResponseMode return response_base.success() -@router.get('/{pk}', summary='获取系统参数配置详情', dependencies=[DependsJwtAuth]) -async def get_config(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetConfigDetail]: +@router.get('/{pk}', summary='获取参数配置详情', dependencies=[DependsJwtAuth]) +async def get_config(pk: Annotated[int, Path(description='参数配置 ID')]) -> ResponseSchemaModel[GetConfigDetail]: config = await config_service.get(pk) return response_base.success(data=config) @router.get( '', - summary='(模糊条件)分页获取所有系统参数配置', + summary='分页获取所有参数配置', dependencies=[ DependsJwtAuth, DependsPagination, ], ) -async def get_pagination_config( +async def get_pagination_configs( db: CurrentSession, - name: Annotated[str | None, Query()] = None, + name: Annotated[str | None, Query(description='参数配置名称')] = None, type: Annotated[str | None, Query()] = None, ) -> ResponseSchemaModel[PageData[GetConfigDetail]]: config_select = await config_service.get_select(name=name, type=type) @@ -104,7 +104,7 @@ async def get_pagination_config( @router.post( '', - summary='创建系统参数配置', + summary='创建参数配置', dependencies=[ Depends(RequestPermission('sys:config:add')), DependsRBAC, @@ -117,13 +117,13 @@ async def create_config(obj: CreateConfigParam) -> ResponseModel: @router.put( '/{pk}', - summary='更新系统参数配置', + summary='更新参数配置', dependencies=[ Depends(RequestPermission('sys:config:edit')), DependsRBAC, ], ) -async def update_config(pk: Annotated[int, Path(...)], obj: UpdateConfigParam) -> ResponseModel: +async def update_config(pk: Annotated[int, Path(description='参数配置 ID')], obj: UpdateConfigParam) -> ResponseModel: count = await config_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -132,13 +132,13 @@ async def update_config(pk: Annotated[int, Path(...)], obj: UpdateConfigParam) - @router.delete( '', - summary='(批量)删除系统参数配置', + summary='批量删除参数配置', dependencies=[ Depends(RequestPermission('sys:config:del')), DependsRBAC, ], ) -async def delete_config(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_config(pk: Annotated[list[int], Query(description='参数配置 ID 列表')]) -> ResponseModel: count = await config_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/data_rule.py b/backend/app/admin/api/v1/sys/data_rule.py index 25c72c9df..a4c487c8e 100644 --- a/backend/app/admin/api/v1/sys/data_rule.py +++ b/backend/app/admin/api/v1/sys/data_rule.py @@ -23,33 +23,37 @@ async def get_data_rule_models() -> ResponseSchemaModel[list[str]]: @router.get('/model/{model}/columns', summary='获取支持过滤的数据库模型列', dependencies=[DependsJwtAuth]) -async def get_data_rule_model_columns(model: Annotated[str, Path()]) -> ResponseSchemaModel[list[str]]: +async def get_data_rule_model_columns( + model: Annotated[str, Path(description='模型名称')], +) -> ResponseSchemaModel[list[str]]: models = await data_rule_service.get_columns(model=model) return response_base.success(data=models) @router.get('/all', summary='获取所有数据规则', dependencies=[DependsJwtAuth]) -async def get_all_data_rule() -> ResponseSchemaModel[list[GetDataRuleDetail]]: +async def get_all_data_rules() -> ResponseSchemaModel[list[GetDataRuleDetail]]: data = await data_rule_service.get_all() return response_base.success(data=data) @router.get('/{pk}', summary='获取数据权限规则详情', dependencies=[DependsJwtAuth]) -async def get_data_rule(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetDataRuleDetail]: +async def get_data_rule( + pk: Annotated[int, Path(description='数据规则 ID')], +) -> ResponseSchemaModel[GetDataRuleDetail]: data = await data_rule_service.get(pk=pk) return response_base.success(data=data) @router.get( '', - summary='(模糊条件)分页获取所有数据权限规则', + summary='分页获取所有数据权限规则', dependencies=[ DependsJwtAuth, DependsPagination, ], ) -async def get_pagination_data_rule( - db: CurrentSession, name: Annotated[str | None, Query()] = None +async def get_pagination_data_rules( + db: CurrentSession, name: Annotated[str | None, Query(description='规则名称')] = None ) -> ResponseSchemaModel[PageData[GetDataRuleDetail]]: data_rule_select = await data_rule_service.get_select(name=name) page_data = await paging_data(db, data_rule_select) @@ -77,7 +81,9 @@ async def create_data_rule(obj: CreateDataRuleParam) -> ResponseModel: DependsRBAC, ], ) -async def update_data_rule(pk: Annotated[int, Path(...)], obj: UpdateDataRuleParam) -> ResponseModel: +async def update_data_rule( + pk: Annotated[int, Path(description='数据规则 ID')], obj: UpdateDataRuleParam +) -> ResponseModel: count = await data_rule_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -86,13 +92,15 @@ async def update_data_rule(pk: Annotated[int, Path(...)], obj: UpdateDataRulePar @router.delete( '', - summary='(批量)删除数据权限规则', + summary='批量删除数据权限规则', dependencies=[ Depends(RequestPermission('data:rule:del')), DependsRBAC, ], ) -async def delete_data_rule(request: Request, pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_data_rule( + request: Request, pk: Annotated[list[int], Query(description='数据规则 ID 列表')] +) -> ResponseModel: count = await data_rule_service.delete(request=request, pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/dept.py b/backend/app/admin/api/v1/sys/dept.py index 5e99c6c4a..03267c20e 100644 --- a/backend/app/admin/api/v1/sys/dept.py +++ b/backend/app/admin/api/v1/sys/dept.py @@ -15,17 +15,17 @@ @router.get('/{pk}', summary='获取部门详情', dependencies=[DependsJwtAuth]) -async def get_dept(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetDeptDetail]: +async def get_dept(pk: Annotated[int, Path(description='部门 ID')]) -> ResponseSchemaModel[GetDeptDetail]: data = await dept_service.get(pk=pk) return response_base.success(data=data) @router.get('', summary='获取所有部门展示树', dependencies=[DependsJwtAuth]) -async def get_all_depts_tree( - name: Annotated[str | None, Query()] = None, - leader: Annotated[str | None, Query()] = None, - phone: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, +async def get_all_depts( + name: Annotated[str | None, Query(description='部门名称')] = None, + leader: Annotated[str | None, Query(description='部门负责人')] = None, + phone: Annotated[str | None, Query(description='联系电话')] = None, + status: Annotated[int | None, Query(description='状态')] = None, ) -> ResponseSchemaModel[list[dict[str, Any]]]: dept = await dept_service.get_dept_tree(name=name, leader=leader, phone=phone, status=status) return response_base.success(data=dept) @@ -52,7 +52,7 @@ async def create_dept(obj: CreateDeptParam) -> ResponseModel: DependsRBAC, ], ) -async def update_dept(pk: Annotated[int, Path(...)], obj: UpdateDeptParam) -> ResponseModel: +async def update_dept(pk: Annotated[int, Path(description='部门 ID')], obj: UpdateDeptParam) -> ResponseModel: count = await dept_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -67,7 +67,7 @@ async def update_dept(pk: Annotated[int, Path(...)], obj: UpdateDeptParam) -> Re DependsRBAC, ], ) -async def delete_dept(request: Request, pk: Annotated[int, Path(...)]) -> ResponseModel: +async def delete_dept(request: Request, pk: Annotated[int, Path(description='部门 ID')]) -> ResponseModel: count = await dept_service.delete(request=request, pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/dict_data.py b/backend/app/admin/api/v1/sys/dict_data.py index afb50b893..6fd1787c7 100644 --- a/backend/app/admin/api/v1/sys/dict_data.py +++ b/backend/app/admin/api/v1/sys/dict_data.py @@ -22,14 +22,16 @@ @router.get('/{pk}', summary='获取字典详情', dependencies=[DependsJwtAuth]) -async def get_dict_data(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetDictDataWithRelation]: +async def get_dict_data( + pk: Annotated[int, Path(description='字典数据 ID')], +) -> ResponseSchemaModel[GetDictDataWithRelation]: data = await dict_data_service.get(pk=pk) return response_base.success(data=data) @router.get( '', - summary='(模糊条件)分页获取所有字典', + summary='分页获取所有字典', dependencies=[ DependsJwtAuth, DependsPagination, @@ -37,9 +39,9 @@ async def get_dict_data(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[Ge ) async def get_pagination_dict_datas( db: CurrentSession, - label: Annotated[str | None, Query()] = None, - value: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, + label: Annotated[str | None, Query(description='字典数据标签')] = None, + value: Annotated[str | None, Query(description='字典数据键值')] = None, + status: Annotated[int | None, Query(description='状态')] = None, ) -> ResponseSchemaModel[PageData[GetDictDataDetail]]: dict_data_select = await dict_data_service.get_select(label=label, value=value, status=status) page_data = await paging_data(db, dict_data_select) @@ -67,7 +69,9 @@ async def create_dict_data(obj: CreateDictDataParam) -> ResponseModel: DependsRBAC, ], ) -async def update_dict_data(pk: Annotated[int, Path(...)], obj: UpdateDictDataParam) -> ResponseModel: +async def update_dict_data( + pk: Annotated[int, Path(description='字典数据 ID')], obj: UpdateDictDataParam +) -> ResponseModel: count = await dict_data_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -76,13 +80,13 @@ async def update_dict_data(pk: Annotated[int, Path(...)], obj: UpdateDictDataPar @router.delete( '', - summary='(批量)删除字典', + summary='批量删除字典', dependencies=[ Depends(RequestPermission('sys:dict:data:del')), DependsRBAC, ], ) -async def delete_dict_data(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_dict_data(pk: Annotated[list[int], Query(description='字典数据 ID 列表')]) -> ResponseModel: count = await dict_data_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/dict_type.py b/backend/app/admin/api/v1/sys/dict_type.py index 30ec3fddf..44ba7f8fc 100644 --- a/backend/app/admin/api/v1/sys/dict_type.py +++ b/backend/app/admin/api/v1/sys/dict_type.py @@ -18,7 +18,7 @@ @router.get( '', - summary='(模糊条件)分页获取所有字典类型', + summary='分页获取所有字典类型', dependencies=[ DependsJwtAuth, DependsPagination, @@ -26,9 +26,9 @@ ) async def get_pagination_dict_types( db: CurrentSession, - name: Annotated[str | None, Query()] = None, - code: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, + name: Annotated[str | None, Query(description='字典类型名称')] = None, + code: Annotated[str | None, Query(description='字典类型编码')] = None, + status: Annotated[int | None, Query(description='状态')] = None, ) -> ResponseSchemaModel[PageData[GetDictTypeDetail]]: dict_type_select = await dict_type_service.get_select(name=name, code=code, status=status) page_data = await paging_data(db, dict_type_select) @@ -56,7 +56,9 @@ async def create_dict_type(obj: CreateDictTypeParam) -> ResponseModel: DependsRBAC, ], ) -async def update_dict_type(pk: Annotated[int, Path(...)], obj: UpdateDictTypeParam) -> ResponseModel: +async def update_dict_type( + pk: Annotated[int, Path(description='字典类型 ID')], obj: UpdateDictTypeParam +) -> ResponseModel: count = await dict_type_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -65,13 +67,13 @@ async def update_dict_type(pk: Annotated[int, Path(...)], obj: UpdateDictTypePar @router.delete( '', - summary='(批量)删除字典类型', + summary='批量删除字典类型', dependencies=[ Depends(RequestPermission('sys:dict:type:del')), DependsRBAC, ], ) -async def delete_dict_type(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_dict_type(pk: Annotated[list[int], Query(description='字典类型 ID 列表')]) -> ResponseModel: count = await dict_type_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/menu.py b/backend/app/admin/api/v1/sys/menu.py index d29bef13d..6e1aaaa39 100644 --- a/backend/app/admin/api/v1/sys/menu.py +++ b/backend/app/admin/api/v1/sys/menu.py @@ -14,21 +14,22 @@ router = APIRouter() -@router.get('/sidebar', summary='获取用户菜单展示树', dependencies=[DependsJwtAuth]) -async def get_user_sidebar_tree(request: Request) -> ResponseSchemaModel[list[dict[str, Any]]]: +@router.get('/sidebar', summary='获取用户侧边栏', dependencies=[DependsJwtAuth]) +async def get_user_sidebar(request: Request) -> ResponseSchemaModel[list[dict[str, Any]]]: menu = await menu_service.get_user_menu_tree(request=request) return response_base.success(data=menu) @router.get('/{pk}', summary='获取菜单详情', dependencies=[DependsJwtAuth]) -async def get_menu(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetMenuDetail]: +async def get_menu(pk: Annotated[int, Path(description='菜单 ID')]) -> ResponseSchemaModel[GetMenuDetail]: data = await menu_service.get(pk=pk) return response_base.success(data=data) @router.get('', summary='获取所有菜单展示树', dependencies=[DependsJwtAuth]) async def get_all_menus( - title: Annotated[str | None, Query()] = None, status: Annotated[int | None, Query()] = None + title: Annotated[str | None, Query(description='菜单标题')] = None, + status: Annotated[int | None, Query(description='状体')] = None, ) -> ResponseSchemaModel[list[dict[str, Any]]]: menu = await menu_service.get_menu_tree(title=title, status=status) return response_base.success(data=menu) @@ -55,7 +56,7 @@ async def create_menu(obj: CreateMenuParam) -> ResponseModel: DependsRBAC, ], ) -async def update_menu(pk: Annotated[int, Path(...)], obj: UpdateMenuParam) -> ResponseModel: +async def update_menu(pk: Annotated[int, Path(description='菜单 ID')], obj: UpdateMenuParam) -> ResponseModel: count = await menu_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -70,7 +71,7 @@ async def update_menu(pk: Annotated[int, Path(...)], obj: UpdateMenuParam) -> Re DependsRBAC, ], ) -async def delete_menu(request: Request, pk: Annotated[int, Path(...)]) -> ResponseModel: +async def delete_menu(request: Request, pk: Annotated[int, Path(description='菜单 ID 列表')]) -> ResponseModel: count = await menu_service.delete(request=request, pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/plugin.py b/backend/app/admin/api/v1/sys/plugin.py index 2cfb0aa82..58c92389a 100644 --- a/backend/app/admin/api/v1/sys/plugin.py +++ b/backend/app/admin/api/v1/sys/plugin.py @@ -52,16 +52,17 @@ async def install_plugin(file: Annotated[UploadFile, File()]) -> ResponseModel: full_plugin_path = os.path.join(PLUGIN_DIR, plugin_name) if os.path.exists(full_plugin_path): raise errors.ForbiddenError(msg='此插件已安装') - os.makedirs(full_plugin_path) + else: + os.makedirs(full_plugin_path, exist_ok=True) - # 解压安装 + # 解压(安装) members = [] for member in zf.infolist(): if member.filename.startswith(plugin_dir_in_zip): - member.filename = member.filename.replace(plugin_dir_in_zip, '') - if not member.filename: - continue - members.append(member) + new_filename = member.filename.replace(plugin_dir_in_zip, '') + if new_filename: + member.filename = new_filename + members.append(member) zf.extractall(PLUGIN_DIR, members) if os.path.exists(os.path.join(full_plugin_path, 'requirements.txt')): await install_requirements_async() @@ -77,10 +78,11 @@ async def install_plugin(file: Annotated[UploadFile, File()]) -> ResponseModel: DependsRBAC, ], ) -async def build_plugin_zip(plugin: Annotated[str, Query()]): +async def build_plugin(plugin: Annotated[str, Query(description='插件名称')]) -> StreamingResponse: plugin_dir = os.path.join(PLUGIN_DIR, plugin) if not os.path.exists(plugin_dir): raise errors.ForbiddenError(msg='插件不存在') + bio = io.BytesIO() with zipfile.ZipFile(bio, 'w') as zf: for root, dirs, files in os.walk(plugin_dir): @@ -89,6 +91,7 @@ async def build_plugin_zip(plugin: Annotated[str, Query()]): file_path = os.path.join(root, file) arcname = os.path.relpath(file_path, start=plugin_dir) zf.write(file_path, arcname) + bio.seek(0) return StreamingResponse( bio, diff --git a/backend/app/admin/api/v1/sys/role.py b/backend/app/admin/api/v1/sys/role.py index f883363e6..bfd72f3c5 100644 --- a/backend/app/admin/api/v1/sys/role.py +++ b/backend/app/admin/api/v1/sys/role.py @@ -32,32 +32,38 @@ async def get_all_roles() -> ResponseSchemaModel[list[GetRoleDetail]]: @router.get('/{pk}/all', summary='获取用户所有角色', dependencies=[DependsJwtAuth]) -async def get_user_all_roles(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[list[GetRoleDetail]]: +async def get_user_all_roles( + pk: Annotated[int, Path(description='用户 ID')], +) -> ResponseSchemaModel[list[GetRoleDetail]]: data = await role_service.get_by_user(pk=pk) return response_base.success(data=data) @router.get('/{pk}/menus', summary='获取角色所有菜单', dependencies=[DependsJwtAuth]) -async def get_role_all_menus(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[list[dict[str, Any]]]: +async def get_role_all_menus( + pk: Annotated[int, Path(description='角色 ID')], +) -> ResponseSchemaModel[list[dict[str, Any]]]: menu = await menu_service.get_role_menu_tree(pk=pk) return response_base.success(data=menu) @router.get('/{pk}/rules', summary='获取角色所有数据规则', dependencies=[DependsJwtAuth]) -async def get_role_all_rules(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[list[int]]: +async def get_role_all_rules(pk: Annotated[int, Path(description='角色 ID')]) -> ResponseSchemaModel[list[int]]: rule = await data_rule_service.get_role_rules(pk=pk) return response_base.success(data=rule) @router.get('/{pk}', summary='获取角色详情', dependencies=[DependsJwtAuth]) -async def get_role(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetRoleWithRelationDetail]: +async def get_role( + pk: Annotated[int, Path(description='角色 ID')], +) -> ResponseSchemaModel[GetRoleWithRelationDetail]: data = await role_service.get(pk=pk) return response_base.success(data=data) @router.get( '', - summary='(模糊条件)分页获取所有角色', + summary='分页获取所有角色', dependencies=[ DependsJwtAuth, DependsPagination, @@ -65,8 +71,8 @@ async def get_role(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetRole ) async def get_pagination_roles( db: CurrentSession, - name: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, + name: Annotated[str | None, Query(description='角色名称')] = None, + status: Annotated[int | None, Query(description='状态')] = None, ) -> ResponseSchemaModel[PageData[GetRoleDetail]]: role_select = await role_service.get_select(name=name, status=status) page_data = await paging_data(db, role_select) @@ -94,7 +100,7 @@ async def create_role(obj: CreateRoleParam) -> ResponseModel: DependsRBAC, ], ) -async def update_role(pk: Annotated[int, Path(...)], obj: UpdateRoleParam) -> ResponseModel: +async def update_role(pk: Annotated[int, Path(description='角色 ID')], obj: UpdateRoleParam) -> ResponseModel: count = await role_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -110,7 +116,7 @@ async def update_role(pk: Annotated[int, Path(...)], obj: UpdateRoleParam) -> Re ], ) async def update_role_menus( - request: Request, pk: Annotated[int, Path(...)], menu_ids: UpdateRoleMenuParam + request: Request, pk: Annotated[int, Path(description='角色 ID')], menu_ids: UpdateRoleMenuParam ) -> ResponseModel: count = await role_service.update_role_menu(request=request, pk=pk, menu_ids=menu_ids) if count > 0: @@ -127,7 +133,7 @@ async def update_role_menus( ], ) async def update_role_rules( - request: Request, pk: Annotated[int, Path(...)], rule_ids: UpdateRoleRuleParam + request: Request, pk: Annotated[int, Path(description='角色 ID')], rule_ids: UpdateRoleRuleParam ) -> ResponseModel: count = await role_service.update_role_rule(request=request, pk=pk, rule_ids=rule_ids) if count > 0: @@ -137,13 +143,13 @@ async def update_role_rules( @router.delete( '', - summary='(批量)删除角色', + summary='批量删除角色', dependencies=[ Depends(RequestPermission('sys:role:del')), DependsRBAC, ], ) -async def delete_role(request: Request, pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_role(request: Request, pk: Annotated[list[int], Query(description='角色 ID 列表')]) -> ResponseModel: count = await role_service.delete(request=request, pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/admin/api/v1/sys/token.py b/backend/app/admin/api/v1/sys/token.py index d9207a6ff..38810afe8 100644 --- a/backend/app/admin/api/v1/sys/token.py +++ b/backend/app/admin/api/v1/sys/token.py @@ -9,7 +9,7 @@ from backend.app.admin.schema.token import GetTokenDetail, KickOutToken from backend.common.enums import StatusType from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base -from backend.common.security.jwt import DependsJwtAuth, jwt_decode, superuser_verify +from backend.common.security.jwt import DependsJwtAuth, jwt_decode, revoke_token, superuser_verify from backend.common.security.permission import RequestPermission from backend.common.security.rbac import DependsRBAC from backend.core.conf import settings @@ -19,10 +19,28 @@ @router.get('', summary='获取令牌列表', dependencies=[DependsJwtAuth]) -async def get_tokens(username: Annotated[str | None, Query()] = None) -> ResponseSchemaModel[list[GetTokenDetail]]: +async def get_tokens( + username: Annotated[str | None, Query(description='用户名')] = None, +) -> ResponseSchemaModel[list[GetTokenDetail]]: token_keys = await redis_client.keys(f'{settings.TOKEN_REDIS_PREFIX}:*') - token_online = await redis_client.smembers(settings.TOKEN_ONLINE_REDIS_PREFIX) - data = [] + online_clients = await redis_client.smembers(settings.TOKEN_ONLINE_REDIS_PREFIX) + data: list[GetTokenDetail] = [] + + def append_token_detail() -> None: + data.append( + token_detail.model_copy( + update={ + 'username': extra_info.get('username', '未知'), + 'nickname': extra_info.get('nickname', '未知'), + 'ip': extra_info.get('ip', '未知'), + 'os': extra_info.get('os', '未知'), + 'browser': extra_info.get('browser', '未知'), + 'device': extra_info.get('device', '未知'), + 'last_login_time': extra_info.get('last_login_time', '未知'), + } + ) + ) + for key in token_keys: token = await redis_client.get(key) token_payload = jwt_decode(token) @@ -36,31 +54,15 @@ async def get_tokens(username: Annotated[str | None, Query()] = None) -> Respons os='未知', browser='未知', device='未知', - status=StatusType.disable if session_uuid not in token_online else StatusType.enable, + status=StatusType.enable if session_uuid in online_clients else StatusType.disable, last_login_time='未知', expire_time=token_payload.expire_time, ) extra_info = await redis_client.get(f'{settings.TOKEN_EXTRA_INFO_REDIS_PREFIX}:{session_uuid}') if extra_info: - - def append_token_detail(): - data.append( - token_detail.model_copy( - update={ - 'username': extra_info.get('username'), - 'nickname': extra_info.get('nickname'), - 'ip': extra_info.get('ip'), - 'os': extra_info.get('os'), - 'browser': extra_info.get('browser'), - 'device': extra_info.get('device'), - 'last_login_time': extra_info.get('last_login_time'), - } - ) - ) - extra_info = json.loads(extra_info) if extra_info.get('login_type') != 'swagger': - if username: + if username is not None: if username == extra_info.get('username'): append_token_detail() else: @@ -78,7 +80,9 @@ def append_token_detail(): DependsRBAC, ], ) -async def kick_out(request: Request, pk: Annotated[int, Path(...)], session_uuid: KickOutToken) -> ResponseModel: +async def kick_out( + request: Request, pk: Annotated[int, Path(description='用户 ID')], obj: KickOutToken +) -> ResponseModel: superuser_verify(request) - await redis_client.delete(f'{settings.TOKEN_REDIS_PREFIX}:{pk}:{session_uuid}') + await revoke_token(str(pk), obj.session_uuid) return response_base.success() diff --git a/backend/app/admin/api/v1/sys/user.py b/backend/app/admin/api/v1/sys/user.py index 00fd54648..e136fbd03 100644 --- a/backend/app/admin/api/v1/sys/user.py +++ b/backend/app/admin/api/v1/sys/user.py @@ -46,20 +46,24 @@ async def password_reset(request: Request, obj: ResetPasswordParam) -> ResponseM return response_base.fail() -@router.get('/me', summary='获取当前用户信息', dependencies=[DependsJwtAuth], response_model_exclude={'password'}) +@router.get('/me', summary='获取当前用户信息', dependencies=[DependsJwtAuth]) async def get_current_user(request: Request) -> ResponseSchemaModel[GetCurrentUserInfoWithRelationDetail]: data = request.user.model_dump() return response_base.success(data=data) @router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth]) -async def get_user(username: Annotated[str, Path(...)]) -> ResponseSchemaModel[GetUserInfoWithRelationDetail]: +async def get_user( + username: Annotated[str, Path(description='用户名')], +) -> ResponseSchemaModel[GetUserInfoWithRelationDetail]: data = await user_service.get_userinfo(username=username) return response_base.success(data=data) @router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth]) -async def update_user(request: Request, username: Annotated[str, Path(...)], obj: UpdateUserParam) -> ResponseModel: +async def update_user( + request: Request, username: Annotated[str, Path(description='用户名')], obj: UpdateUserParam +) -> ResponseModel: count = await user_service.update(request=request, username=username, obj=obj) if count > 0: return response_base.success() @@ -75,14 +79,16 @@ async def update_user(request: Request, username: Annotated[str, Path(...)], obj ], ) async def update_user_role( - request: Request, username: Annotated[str, Path(...)], obj: UpdateUserRoleParam + request: Request, username: Annotated[str, Path(description='用户名')], obj: UpdateUserRoleParam ) -> ResponseModel: await user_service.update_roles(request=request, username=username, obj=obj) return response_base.success() @router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth]) -async def update_avatar(request: Request, username: Annotated[str, Path(...)], avatar: AvatarParam) -> ResponseModel: +async def update_avatar( + request: Request, username: Annotated[str, Path(description='用户名')], avatar: AvatarParam +) -> ResponseModel: count = await user_service.update_avatar(request=request, username=username, avatar=avatar) if count > 0: return response_base.success() @@ -91,7 +97,7 @@ async def update_avatar(request: Request, username: Annotated[str, Path(...)], a @router.get( '', - summary='(模糊条件)分页获取所有用户', + summary='分页获取所有用户', dependencies=[ DependsJwtAuth, DependsPagination, @@ -99,10 +105,10 @@ async def update_avatar(request: Request, username: Annotated[str, Path(...)], a ) async def get_pagination_users( db: CurrentSession, - dept: Annotated[int | None, Query()] = None, - username: Annotated[str | None, Query()] = None, - phone: Annotated[str | None, Query()] = None, - status: Annotated[int | None, Query()] = None, + dept: Annotated[int | None, Query(description='部门 ID')] = None, + username: Annotated[str | None, Query(description='用户名')] = None, + phone: Annotated[str | None, Query(description='手机号')] = None, + status: Annotated[int | None, Query(description='状态')] = None, ) -> ResponseSchemaModel[PageData[GetUserInfoWithRelationDetail]]: user_select = await user_service.get_select(dept=dept, username=username, phone=phone, status=status) page_data = await paging_data(db, user_select) @@ -110,7 +116,7 @@ async def get_pagination_users( @router.put('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsRBAC]) -async def super_set(request: Request, pk: Annotated[int, Path(...)]) -> ResponseModel: +async def super_set(request: Request, pk: Annotated[int, Path(description='用户 ID')]) -> ResponseModel: count = await user_service.update_permission(request=request, pk=pk) if count > 0: return response_base.success() @@ -118,7 +124,7 @@ async def super_set(request: Request, pk: Annotated[int, Path(...)]) -> Response @router.put('/{pk}/staff', summary='修改用户后台登录权限', dependencies=[DependsRBAC]) -async def staff_set(request: Request, pk: Annotated[int, Path(...)]) -> ResponseModel: +async def staff_set(request: Request, pk: Annotated[int, Path(description='用户 ID')]) -> ResponseModel: count = await user_service.update_staff(request=request, pk=pk) if count > 0: return response_base.success() @@ -126,15 +132,15 @@ async def staff_set(request: Request, pk: Annotated[int, Path(...)]) -> Response @router.put('/{pk}/status', summary='修改用户状态', dependencies=[DependsRBAC]) -async def status_set(request: Request, pk: Annotated[int, Path(...)]) -> ResponseModel: +async def status_set(request: Request, pk: Annotated[int, Path(description='用户 ID')]) -> ResponseModel: count = await user_service.update_status(request=request, pk=pk) if count > 0: return response_base.success() return response_base.fail() -@router.put('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsRBAC]) -async def multi_set(request: Request, pk: Annotated[int, Path(...)]) -> ResponseModel: +@router.put('/{pk}/multi', summary='修改用户多端登录状态', dependencies=[DependsRBAC]) +async def multi_set(request: Request, pk: Annotated[int, Path(description='用户 ID')]) -> ResponseModel: count = await user_service.update_multi_login(request=request, pk=pk) if count > 0: return response_base.success() @@ -150,7 +156,7 @@ async def multi_set(request: Request, pk: Annotated[int, Path(...)]) -> Response DependsRBAC, ], ) -async def delete_user(username: Annotated[str, Path(...)]) -> ResponseModel: +async def delete_user(username: Annotated[str, Path(description='用户名')]) -> ResponseModel: count = await user_service.delete(username=username) if count > 0: return response_base.success() diff --git a/backend/app/admin/conf.py b/backend/app/admin/conf.py index b653a45ac..20fa6c08e 100644 --- a/backend/app/admin/conf.py +++ b/backend/app/admin/conf.py @@ -4,39 +4,36 @@ from pydantic_settings import BaseSettings, SettingsConfigDict -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH class AdminSettings(BaseSettings): - """Admin Settings""" + """Admin 配置""" - model_config = SettingsConfigDict(env_file=f'{BasePath}/.env', env_file_encoding='utf-8', extra='ignore') + model_config = SettingsConfigDict(env_file=f'{BASE_PATH}/.env', env_file_encoding='utf-8', extra='ignore') - # OAuth2:https://github.com/fastapi-practices/fastapi_oauth20 - # GitHub + # .env OAuth2 OAUTH2_GITHUB_CLIENT_ID: str OAUTH2_GITHUB_CLIENT_SECRET: str - OAUTH2_GITHUB_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/github/callback' - - # Linux Do OAUTH2_LINUX_DO_CLIENT_ID: str OAUTH2_LINUX_DO_CLIENT_SECRET: str - OAUTH2_LINUX_DO_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/linux-do/callback' - # Front-end redirect address + # OAuth2 + OAUTH2_GITHUB_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/github/callback' + OAUTH2_LINUX_DO_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/linux-do/callback' OAUTH2_FRONTEND_REDIRECT_URI: str = 'http://localhost:5173/oauth2/callback' - # Captcha + # 验证码 CAPTCHA_LOGIN_REDIS_PREFIX: str = 'fba:login:captcha' - CAPTCHA_LOGIN_EXPIRE_SECONDS: int = 60 * 5 # 过期时间,单位:秒 + CAPTCHA_LOGIN_EXPIRE_SECONDS: int = 60 * 5 # 3 分钟 - # Config - CONFIG_BUILT_IN_TYPES: list = ['website', 'protocol', 'policy'] + # 参数配置 + CONFIG_BUILT_IN_TYPES: list[str] = ['website', 'protocol', 'policy'] @lru_cache def get_admin_settings() -> AdminSettings: - """获取 admin 配置""" + """获取 admin 参数配置""" return AdminSettings() diff --git a/backend/app/admin/crud/crud_config.py b/backend/app/admin/crud/crud_config.py index db0b66d5a..872174b37 100644 --- a/backend/app/admin/crud/crud_config.py +++ b/backend/app/admin/crud/crud_config.py @@ -12,54 +12,55 @@ class CRUDConfig(CRUDPlus[Config]): + """系统参数参数配置数据库操作类""" + async def get(self, db: AsyncSession, pk: int) -> Config | None: """ - 获取系统参数配置 + 获取参数配置详情 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 参数配置 ID :return: """ return await self.select_model_by_column(db, id=pk, type__not_in=admin_settings.CONFIG_BUILT_IN_TYPES) async def get_by_type(self, db: AsyncSession, type: str) -> Sequence[Config]: """ - 通过 type 获取内置系统配置 + 通过类型获取参数配置 - :param db: - :param type: + :param db: 数据库会话 + :param type: 参数配置类型 :return: """ return await self.select_models(db, type=type) async def get_by_key_and_type(self, db: AsyncSession, key: str, type: str) -> Config | None: """ - 通过 name 和 type 获取内置系统配置 + 通过键名和类型获取参数配置 - :param db: - :param key: - :param type: + :param db: 数据库会话 + :param key: 参数配置键名 + :param type: 参数配置类型 :return: """ return await self.select_model_by_column(db, key=key, type=type) async def get_by_key(self, db: AsyncSession, key: str) -> Config | None: """ - 通过 key 获取系统配置参数 + 通过键名获取参数配置 - :param db: - :param key: - :param built_in: + :param db: 数据库会话 + :param key: 参数配置键名 :return: """ return await self.select_model_by_column(db, key=key) - async def get_list(self, name: str = None, type: str = None) -> Select: + async def get_list(self, name: str | None = None, type: str | None = None) -> Select: """ - 获取系统参数配置列表 + 获取参数配置列表 - :param name: - :param type: + :param name: 参数配置名称 + :param type: 参数配置类型 :return: """ filters = {'type__not_in': admin_settings.CONFIG_BUILT_IN_TYPES} @@ -69,33 +70,33 @@ async def get_list(self, name: str = None, type: str = None) -> Select: filters.update(type__like=f'%{type}%') return await self.select_order('created_time', 'desc', **filters) - async def create(self, db: AsyncSession, obj_in: CreateConfigParam) -> None: + async def create(self, db: AsyncSession, obj: CreateConfigParam) -> None: """ - 创建 Config + 创建参数配置 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建参数配置参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateConfigParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateConfigParam) -> int: """ - 更新 Config + 更新参数配置 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: 参数配置 ID + :param obj: 更新参数配置参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ - 删除 Config + 删除参数配置 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 参数配置 ID 列表 :return: """ return await self.delete_model_by_column( diff --git a/backend/app/admin/crud/crud_data_rule.py b/backend/app/admin/crud/crud_data_rule.py index 2a65800a9..573271c51 100644 --- a/backend/app/admin/crud/crud_data_rule.py +++ b/backend/app/admin/crud/crud_data_rule.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Sequence -from sqlalchemy import Select, desc, select +from sqlalchemy import Select, and_, desc, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload from sqlalchemy_crud_plus import CRUDPlus @@ -12,76 +12,82 @@ class CRUDDataRule(CRUDPlus[DataRule]): + """数据权限规则数据库操作类""" + async def get(self, db: AsyncSession, pk: int) -> DataRule | None: """ - 获取数据权限规则 + 获取规则详情 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 规则 ID :return: """ return await self.select_model(db, pk) - async def get_list(self, name: str = None) -> Select: + async def get_list(self, name: str | None = None) -> Select: """ - 获取数据权限规则列表 + 获取规则列表 + :param name: 规则名称 :return: """ stmt = select(self.model).options(noload(self.model.roles)).order_by(desc(self.model.created_time)) - where_list = [] + + filters = [] if name is not None: - where_list.append(self.model.name.like(f'%{name}%')) - if where_list: - stmt = stmt.where(*where_list) + filters.append(self.model.name.like(f'%{name}%')) + + if filters: + stmt = stmt.where(and_(*filters)) + return stmt - async def get_by_name(self, db: AsyncSession, name: str): + async def get_by_name(self, db: AsyncSession, name: str) -> DataRule | None: """ - 通过 name 获取数据权限规则 + 通过名称获取规则 - :param db: - :param name: + :param db: 数据库会话 + :param name: 规则名称 :return: """ return await self.select_model_by_column(db, name=name) async def get_all(self, db: AsyncSession) -> Sequence[DataRule]: """ - 获取所有数据权限规则 + 获取所有规则 - :param db: + :param db: 数据库会话 :return: """ return await self.select_models(db) - async def create(self, db: AsyncSession, obj_in: CreateDataRuleParam) -> None: + async def create(self, db: AsyncSession, obj: CreateDataRuleParam) -> None: """ - 创建数据权限规则 + 创建规则 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建规则参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDataRuleParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateDataRuleParam) -> int: """ - 更新数据权限规则 + 更新规则 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: 规则 ID + :param obj: 更新规则参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ - 删除数据权限规则 + 删除规则 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 规则 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) diff --git a/backend/app/admin/crud/crud_dept.py b/backend/app/admin/crud/crud_dept.py index 80bf286c3..53d889cc0 100644 --- a/backend/app/admin/crud/crud_dept.py +++ b/backend/app/admin/crud/crud_dept.py @@ -7,42 +7,49 @@ from sqlalchemy.orm import selectinload from sqlalchemy_crud_plus import CRUDPlus -from backend.app.admin.model import Dept, User +from backend.app.admin.model import Dept from backend.app.admin.schema.dept import CreateDeptParam, UpdateDeptParam class CRUDDept(CRUDPlus[Dept]): + """部门数据库操作类""" + async def get(self, db: AsyncSession, dept_id: int) -> Dept | None: """ - 获取部门 + 获取部门详情 - :param db: - :param dept_id: + :param db: 数据库会话 + :param dept_id: 部门 ID :return: """ return await self.select_model_by_column(db, id=dept_id, del_flag=0) async def get_by_name(self, db: AsyncSession, name: str) -> Dept | None: """ - 通过 name 获取 API + 通过名称获取部门 - :param db: - :param name: + :param db: 数据库会话 + :param name: 部门名称 :return: """ return await self.select_model_by_column(db, name=name, del_flag=0) async def get_all( - self, db: AsyncSession, name: str = None, leader: str = None, phone: str = None, status: int = None + self, + db: AsyncSession, + name: str | None = None, + leader: str | None = None, + phone: str | None = None, + status: int | None = None, ) -> Sequence[Dept]: """ 获取所有部门 - :param db: - :param name: - :param leader: - :param phone: - :param status: + :param db: 数据库会话 + :param name: 部门名称 + :param leader: 负责人 + :param phone: 联系电话 + :param status: 部门状态 :return: """ filters = {'del_flag__eq': 0} @@ -56,56 +63,55 @@ async def get_all( filters.update(status=status) return await self.select_models_order(db, sort_columns='sort', **filters) - async def create(self, db: AsyncSession, obj_in: CreateDeptParam) -> None: + async def create(self, db: AsyncSession, obj: CreateDeptParam) -> None: """ 创建部门 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建部门参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, dept_id: int, obj_in: UpdateDeptParam) -> int: + async def update(self, db: AsyncSession, dept_id: int, obj: UpdateDeptParam) -> int: """ 更新部门 - :param db: - :param dept_id: - :param obj_in: + :param db: 数据库会话 + :param dept_id: 部门 ID + :param obj: 更新部门参数 :return: """ - return await self.update_model(db, dept_id, obj_in) + return await self.update_model(db, dept_id, obj) async def delete(self, db: AsyncSession, dept_id: int) -> int: """ 删除部门 - :param db: - :param dept_id: + :param db: 数据库会话 + :param dept_id: 部门 ID :return: """ return await self.delete_model_by_column(db, id=dept_id, logical_deletion=True, deleted_flag_column='del_flag') - async def get_with_relation(self, db: AsyncSession, dept_id: int) -> list[User]: + async def get_with_relation(self, db: AsyncSession, dept_id: int) -> Dept | None: """ - 获取关联 + 获取部门及关联数据 - :param db: - :param dept_id: + :param db: 数据库会话 + :param dept_id: 部门 ID :return: """ stmt = select(self.model).options(selectinload(self.model.users)).where(self.model.id == dept_id) result = await db.execute(stmt) - user_relation = result.scalars().first() - return user_relation.users + return result.scalars().first() - async def get_children(self, db: AsyncSession, dept_id: int) -> Sequence[Dept]: + async def get_children(self, db: AsyncSession, dept_id: int) -> Sequence[Dept | None]: """ - 获取子部门 + 获取子部门列表 - :param db: - :param dept_id: + :param db: 数据库会话 + :param dept_id: 部门 ID :return: """ stmt = select(self.model).where(self.model.parent_id == dept_id, self.model.del_flag == 0) diff --git a/backend/app/admin/crud/crud_dict_data.py b/backend/app/admin/crud/crud_dict_data.py index aa6dfafe6..fbb96c179 100644 --- a/backend/app/admin/crud/crud_dict_data.py +++ b/backend/app/admin/crud/crud_dict_data.py @@ -10,84 +10,89 @@ class CRUDDictData(CRUDPlus[DictData]): + """字典数据数据库操作类""" + async def get(self, db: AsyncSession, pk: int) -> DictData | None: """ - 获取字典数据 + 获取字典数据详情 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 字典数据 ID :return: """ return await self.select_model(db, pk) - async def get_list(self, label: str = None, value: str = None, status: int = None) -> Select: + async def get_list(self, label: str | None = None, value: str | None = None, status: int | None = None) -> Select: """ - 获取所有字典数据 + 获取字典数据列表 - :param label: - :param value: - :param status: + :param label: 字典数据标签 + :param value: 字典数据键值 + :param status: 字典状态 :return: """ stmt = select(self.model).options(noload(self.model.type)).order_by(desc(self.model.sort)) - where_list = [] + + filters = [] if label is not None: - where_list.append(self.model.label.like(f'%{label}%')) + filters.append(self.model.label.like(f'%{label}%')) if value is not None: - where_list.append(self.model.value.like(f'%{value}%')) + filters.append(self.model.value.like(f'%{value}%')) if status is not None: - where_list.append(self.model.status == status) - if where_list: - stmt = stmt.where(and_(*where_list)) + filters.append(self.model.status == status) + + if filters: + stmt = stmt.where(and_(*filters)) + return stmt async def get_by_label(self, db: AsyncSession, label: str) -> DictData | None: """ - 通过 label 获取字典数据 + 通过标签获取字典数据 - :param db: - :param label: + :param db: 数据库会话 + :param label: 字典标签 :return: """ return await self.select_model_by_column(db, label=label) - async def create(self, db: AsyncSession, obj_in: CreateDictDataParam) -> None: + async def create(self, db: AsyncSession, obj: CreateDictDataParam) -> None: """ - 创建数据字典 + 创建字典数据 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建字典数据参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDictDataParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateDictDataParam) -> int: """ - 更新数据字典 + 更新字典数据 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: 字典数据 ID + :param obj: 更新字典数据参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ 删除字典数据 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 字典数据 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) async def get_with_relation(self, db: AsyncSession, pk: int) -> DictData | None: """ - 获取字典数据和类型 + 获取字典数据及关联数据 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 字典数据 ID :return: """ stmt = select(self.model).options(selectinload(self.model.type)).where(self.model.id == pk) diff --git a/backend/app/admin/crud/crud_dict_type.py b/backend/app/admin/crud/crud_dict_type.py index 52d74803d..c44e14e83 100644 --- a/backend/app/admin/crud/crud_dict_type.py +++ b/backend/app/admin/crud/crud_dict_type.py @@ -9,23 +9,25 @@ class CRUDDictType(CRUDPlus[DictType]): + """字典类型数据库操作类""" + async def get(self, db: AsyncSession, pk: int) -> DictType | None: """ - 获取字典类型 + 获取字典类型详情 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 字典类型 ID :return: """ return await self.select_model(db, pk) - async def get_list(self, *, name: str = None, code: str = None, status: int = None) -> Select: + async def get_list(self, *, name: str | None = None, code: str | None = None, status: int | None = None) -> Select: """ - 获取所有字典类型 + 获取字典类型列表 - :param name: - :param code: - :param status: + :param name: 字典类型名称 + :param code: 字典类型编码 + :param status: 字典状态 :return: """ filters = {} @@ -39,41 +41,41 @@ async def get_list(self, *, name: str = None, code: str = None, status: int = No async def get_by_code(self, db: AsyncSession, code: str) -> DictType | None: """ - 通过 code 获取字典类型 + 通过编码获取字典类型 - :param db: - :param code: + :param db: 数据库会话 + :param code: 字典编码 :return: """ return await self.select_model_by_column(db, code=code) - async def create(self, db: AsyncSession, obj_in: CreateDictTypeParam) -> None: + async def create(self, db: AsyncSession, obj: CreateDictTypeParam) -> None: """ 创建字典类型 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建字典类型参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDictTypeParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateDictTypeParam) -> int: """ 更新字典类型 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: 字典类型 ID + :param obj: 更新字典类型参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ 删除字典类型 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 字典类型 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) diff --git a/backend/app/admin/crud/crud_login_log.py b/backend/app/admin/crud/crud_login_log.py index 48c44da6d..f7af92f6b 100644 --- a/backend/app/admin/crud/crud_login_log.py +++ b/backend/app/admin/crud/crud_login_log.py @@ -9,13 +9,15 @@ class CRUDLoginLog(CRUDPlus[LoginLog]): + """登录日志数据库操作类""" + async def get_list(self, username: str | None = None, status: int | None = None, ip: str | None = None) -> Select: """ 获取登录日志列表 - :param username: - :param status: - :param ip: + :param username: 用户名 + :param status: 登录状态 + :param ip: IP 地址 :return: """ filters = {} @@ -27,31 +29,31 @@ async def get_list(self, username: str | None = None, status: int | None = None, filters.update(ip__like=f'%{ip}%') return await self.select_order('created_time', 'desc', **filters) - async def create(self, db: AsyncSession, obj_in: CreateLoginLogParam) -> None: + async def create(self, db: AsyncSession, obj: CreateLoginLogParam) -> None: """ 创建登录日志 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建登录日志参数 :return: """ - await self.create_model(db, obj_in, commit=True) + await self.create_model(db, obj, commit=True) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ 删除登录日志 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 登录日志 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) async def delete_all(self, db: AsyncSession) -> int: """ - 删除所有登录日志 + 删除所有日志 - :param db: + :param db: 数据库会话 :return: """ return await self.delete_model_by_column(db, allow_multiple=True) diff --git a/backend/app/admin/crud/crud_menu.py b/backend/app/admin/crud/crud_menu.py index 43cdb4a19..c66d0b004 100644 --- a/backend/app/admin/crud/crud_menu.py +++ b/backend/app/admin/crud/crud_menu.py @@ -3,6 +3,7 @@ from typing import Sequence from sqlalchemy import and_, asc, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from sqlalchemy_crud_plus import CRUDPlus @@ -11,33 +12,35 @@ class CRUDMenu(CRUDPlus[Menu]): - async def get(self, db, menu_id: int) -> Menu | None: + """菜单数据库操作类""" + + async def get(self, db: AsyncSession, menu_id: int) -> Menu | None: """ - 获取菜单 + 获取菜单详情 - :param db: - :param menu_id: + :param db: 数据库会话 + :param menu_id: 菜单 ID :return: """ return await self.select_model(db, menu_id) - async def get_by_title(self, db, title: str) -> Menu | None: + async def get_by_title(self, db: AsyncSession, title: str) -> Menu | None: """ - 通过 title 获取菜单 + 通过标题获取菜单 - :param db: - :param title: + :param db: 数据库会话 + :param title: 菜单标题 :return: """ return await self.select_model_by_column(db, title=title, menu_type__ne=2) - async def get_all(self, db, title: str | None = None, status: int | None = None) -> Sequence[Menu]: + async def get_all(self, db: AsyncSession, title: str | None = None, status: int | None = None) -> Sequence[Menu]: """ - 获取所有菜单 + 获取菜单列表 - :param db: - :param title: - :param status: + :param db: 数据库会话 + :param title: 菜单标题 + :param status: 菜单状态 :return: """ filters = {} @@ -47,60 +50,60 @@ async def get_all(self, db, title: str | None = None, status: int | None = None) filters.update(status=status) return await self.select_models_order(db, 'sort', **filters) - async def get_role_menus(self, db, superuser: bool, menu_ids: list[int]) -> Sequence[Menu]: + async def get_role_menus(self, db: AsyncSession, superuser: bool, menu_ids: list[int]) -> Sequence[Menu]: """ - 获取角色菜单 + 获取角色菜单列表 - :param db: - :param superuser: - :param menu_ids: + :param db: 数据库会话 + :param superuser: 是否超级管理员 + :param menu_ids: 菜单 ID 列表 :return: """ stmt = select(self.model).order_by(asc(self.model.sort)) - where_list = [self.model.menu_type.in_([0, 1])] + filters = [self.model.menu_type.in_([0, 1])] if not superuser: - where_list.append(self.model.id.in_(menu_ids)) - stmt = stmt.where(and_(*where_list)) + filters.append(self.model.id.in_(menu_ids)) + stmt = stmt.where(and_(*filters)) menu = await db.execute(stmt) return menu.scalars().all() - async def create(self, db, obj_in: CreateMenuParam) -> None: + async def create(self, db: AsyncSession, obj: CreateMenuParam) -> None: """ 创建菜单 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建菜单参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db, menu_id: int, obj_in: UpdateMenuParam) -> int: + async def update(self, db: AsyncSession, menu_id: int, obj: UpdateMenuParam) -> int: """ 更新菜单 - :param db: - :param menu_id: - :param obj_in: + :param db: 数据库会话 + :param menu_id: 菜单 ID + :param obj: 更新菜单参数 :return: """ - return await self.update_model(db, menu_id, obj_in) + return await self.update_model(db, menu_id, obj) - async def delete(self, db, menu_id: int) -> int: + async def delete(self, db: AsyncSession, menu_id: int) -> int: """ 删除菜单 - :param db: - :param menu_id: + :param db: 数据库会话 + :param menu_id: 菜单 ID :return: """ return await self.delete_model(db, menu_id) - async def get_children(self, db, menu_id: int) -> list[Menu]: + async def get_children(self, db: AsyncSession, menu_id: int) -> list[Menu | None]: """ - 获取子菜单 + 获取子菜单列表 - :param db: - :param menu_id: + :param db: 数据库会话 + :param menu_id: 菜单 ID :return: """ stmt = select(self.model).options(selectinload(self.model.children)).where(self.model.id == menu_id) diff --git a/backend/app/admin/crud/crud_opera_log.py b/backend/app/admin/crud/crud_opera_log.py index 027b4bbe1..3d06d7bfe 100644 --- a/backend/app/admin/crud/crud_opera_log.py +++ b/backend/app/admin/crud/crud_opera_log.py @@ -9,13 +9,15 @@ class CRUDOperaLogDao(CRUDPlus[OperaLog]): + """操作日志数据库操作类""" + async def get_list(self, username: str | None = None, status: int | None = None, ip: str | None = None) -> Select: """ 获取操作日志列表 - :param username: - :param status: - :param ip: + :param username: 用户名 + :param status: 操作状态 + :param ip: IP 地址 :return: """ filters = {} @@ -27,31 +29,31 @@ async def get_list(self, username: str | None = None, status: int | None = None, filters.update(ip__like=f'%{ip}%') return await self.select_order('created_time', 'desc', **filters) - async def create(self, db: AsyncSession, obj_in: CreateOperaLogParam) -> None: + async def create(self, db: AsyncSession, obj: CreateOperaLogParam) -> None: """ 创建操作日志 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建操作日志参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ 删除操作日志 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 操作日志 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) async def delete_all(self, db: AsyncSession) -> int: """ - 删除所有操作日志 + 删除所有日志 - :param db: + :param db: 数据库会话 :return: """ return await self.delete_model_by_column(db, allow_multiple=True) diff --git a/backend/app/admin/crud/crud_role.py b/backend/app/admin/crud/crud_role.py index a46d0776c..aef49a376 100644 --- a/backend/app/admin/crud/crud_role.py +++ b/backend/app/admin/crud/crud_role.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- from typing import Sequence -from sqlalchemy import Select, desc, select +from sqlalchemy import Select, and_, desc, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload, selectinload from sqlalchemy_crud_plus import CRUDPlus @@ -16,22 +17,24 @@ class CRUDRole(CRUDPlus[Role]): - async def get(self, db, role_id: int) -> Role | None: + """角色数据库操作类""" + + async def get(self, db: AsyncSession, role_id: int) -> Role | None: """ - 获取角色 + 获取角色详情 - :param db: - :param role_id: + :param db: 数据库会话 + :param role_id: 角色 ID :return: """ return await self.select_model(db, role_id) - async def get_with_relation(self, db, role_id: int) -> Role | None: + async def get_with_relation(self, db: AsyncSession, role_id: int) -> Role | None: """ - 获取角色和菜单 + 获取角色及关联数据 - :param db: - :param role_id: + :param db: 数据库会话 + :param role_id: 角色 ID :return: """ stmt = ( @@ -42,33 +45,33 @@ async def get_with_relation(self, db, role_id: int) -> Role | None: role = await db.execute(stmt) return role.scalars().first() - async def get_all(self, db) -> Sequence[Role]: + async def get_all(self, db: AsyncSession) -> Sequence[Role]: """ 获取所有角色 - :param db: + :param db: 数据库会话 :return: """ return await self.select_models(db) - async def get_by_user(self, db, user_id: int) -> Sequence[Role]: + async def get_by_user(self, db: AsyncSession, user_id: int) -> Sequence[Role]: """ - 获取用户所有角色 + 获取用户角色列表 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ stmt = select(self.model).join(self.model.users).where(User.id == user_id) roles = await db.execute(stmt) return roles.scalars().all() - async def get_list(self, name: str = None, status: int = None) -> Select: + async def get_list(self, name: str | None = None, status: int | None = None) -> Select: """ 获取角色列表 - :param name: - :param status: + :param name: 角色名称 + :param status: 角色状态 :return: """ stmt = ( @@ -76,84 +79,85 @@ async def get_list(self, name: str = None, status: int = None) -> Select: .options(noload(self.model.users), noload(self.model.menus), noload(self.model.rules)) .order_by(desc(self.model.created_time)) ) - where_list = [] - if name: - where_list.append(self.model.name.like(f'%{name}%')) + + filters = [] + if name is not None: + filters.append(self.model.name.like(f'%{name}%')) if status is not None: - where_list.append(self.model.status == status) - if where_list: - stmt = stmt.where(*where_list) + filters.append(self.model.status == status) + + if filters: + stmt = stmt.where(and_(*filters)) + return stmt - async def get_by_name(self, db, name: str) -> Role | None: + async def get_by_name(self, db: AsyncSession, name: str) -> Role | None: """ - 通过 name 获取角色 + 通过名称获取角色 - :param db: - :param name: + :param db: 数据库会话 + :param name: 角色名称 :return: """ return await self.select_model_by_column(db, name=name) - async def create(self, db, obj_in: CreateRoleParam) -> None: + async def create(self, db: AsyncSession, obj: CreateRoleParam) -> None: """ 创建角色 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建角色参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db, role_id: int, obj_in: UpdateRoleParam) -> int: + async def update(self, db: AsyncSession, role_id: int, obj: UpdateRoleParam) -> int: """ 更新角色 - :param db: - :param role_id: - :param obj_in: + :param db: 数据库会话 + :param role_id: 角色 ID + :param obj: 更新角色参数 :return: """ - return await self.update_model(db, role_id, obj_in) + return await self.update_model(db, role_id, obj) - async def update_menus(self, db, role_id: int, menu_ids: UpdateRoleMenuParam) -> int: + async def update_menus(self, db: AsyncSession, role_id: int, menu_ids: UpdateRoleMenuParam) -> int: """ 更新角色菜单 - :param db: - :param role_id: - :param menu_ids: + :param db: 数据库会话 + :param role_id: 角色 ID + :param menu_ids: 菜单 ID 列表 :return: """ current_role = await self.get_with_relation(db, role_id) - # 更新菜单 stmt = select(Menu).where(Menu.id.in_(menu_ids.menus)) menus = await db.execute(stmt) current_role.menus = menus.scalars().all() return len(current_role.menus) - async def update_rules(self, db, role_id: int, rule_ids: UpdateRoleRuleParam) -> int: + async def update_rules(self, db: AsyncSession, role_id: int, rule_ids: UpdateRoleRuleParam) -> int: """ - 更新角色数据权限 + 更新角色数据规则 - :param db: - :param role_id: - :param rule_ids: + :param db: 数据库会话 + :param role_id: 角色 ID + :param rule_ids: 权限规则 ID 列表 :return: """ current_role = await self.get_with_relation(db, role_id) - # 更新数据权限 stmt = select(DataRule).where(DataRule.id.in_(rule_ids.rules)) rules = await db.execute(stmt) current_role.rules = rules.scalars().all() return len(current_role.rules) - async def delete(self, db, role_id: list[int]) -> int: + async def delete(self, db: AsyncSession, role_id: list[int]) -> int: """ 删除角色 - :param db: - :param role_id: + :param db: 数据库会话 + :param role_id: 角色 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=role_id) diff --git a/backend/app/admin/crud/crud_user.py b/backend/app/admin/crud/crud_user.py index 1c0a06f12..d3c3c9e5f 100644 --- a/backend/app/admin/crud/crud_user.py +++ b/backend/app/admin/crud/crud_user.py @@ -21,42 +21,44 @@ class CRUDUser(CRUDPlus[User]): + """用户数据库操作类""" + async def get(self, db: AsyncSession, user_id: int) -> User | None: """ - 获取用户 + 获取用户详情 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ return await self.select_model(db, user_id) async def get_by_username(self, db: AsyncSession, username: str) -> User | None: """ - 通过 username 获取用户 + 通过用户名获取用户 - :param db: - :param username: + :param db: 数据库会话 + :param username: 用户名 :return: """ return await self.select_model_by_column(db, username=username) async def get_by_nickname(self, db: AsyncSession, nickname: str) -> User | None: """ - 通过 nickname 获取用户 + 通过昵称获取用户 - :param db: - :param nickname: + :param db: 数据库会话 + :param nickname: 用户昵称 :return: """ return await self.select_model_by_column(db, nickname=nickname) async def update_login_time(self, db: AsyncSession, username: str) -> int: """ - 更新用户登录时间 + 更新用户最后登录时间 - :param db: - :param username: + :param db: 数据库会话 + :param username: 用户名 :return: """ return await self.update_model_by_column(db, {'last_login_time': timezone.now()}, username=username) @@ -65,9 +67,9 @@ async def create(self, db: AsyncSession, obj: RegisterUserParam, *, social: bool """ 创建用户 - :param db: - :param obj: - :param social: 社交用户,适配 oauth 2.0 + :param db: 数据库会话 + :param obj: 注册用户参数 + :param social: 是否社交用户 :return: """ if not social: @@ -83,10 +85,10 @@ async def create(self, db: AsyncSession, obj: RegisterUserParam, *, social: bool async def add(self, db: AsyncSession, obj: AddUserParam) -> None: """ - 后台添加用户 + 添加用户 - :param db: - :param obj: + :param db: 数据库会话 + :param obj: 添加用户参数 :return: """ salt = bcrypt.gensalt() @@ -94,19 +96,21 @@ async def add(self, db: AsyncSession, obj: AddUserParam) -> None: dict_obj = obj.model_dump(exclude={'roles'}) dict_obj.update({'salt': salt}) new_user = self.model(**dict_obj) + role_list = [] for role_id in obj.roles: role_list.append(await db.get(Role, role_id)) new_user.roles.extend(role_list) + db.add(new_user) async def update_userinfo(self, db: AsyncSession, input_user: int, obj: UpdateUserParam) -> int: """ 更新用户信息 - :param db: - :param input_user: - :param obj: + :param db: 数据库会话 + :param input_user: 用户 ID + :param obj: 更新用户参数 :return: """ return await self.update_model(db, input_user, obj) @@ -116,15 +120,14 @@ async def update_role(db: AsyncSession, input_user: User, obj: UpdateUserRolePar """ 更新用户角色 - :param db: - :param input_user: - :param obj: + :param db: 数据库会话 + :param input_user: 用户对象 + :param obj: 更新角色参数 :return: """ - # 删除用户所有角色 for i in list(input_user.roles): input_user.roles.remove(i) - # 添加用户角色 + role_list = [] for role_id in obj.roles: role_list.append(await db.get(Role, role_id)) @@ -134,9 +137,9 @@ async def update_avatar(self, db: AsyncSession, input_user: int, avatar: AvatarP """ 更新用户头像 - :param db: - :param input_user: - :param avatar: + :param db: 数据库会话 + :param input_user: 用户 ID + :param avatar: 头像地址 :return: """ return await self.update_model(db, input_user, {'avatar': avatar.url}) @@ -145,18 +148,18 @@ async def delete(self, db: AsyncSession, user_id: int) -> int: """ 删除用户 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ return await self.delete_model(db, user_id) async def check_email(self, db: AsyncSession, email: str) -> User | None: """ - 检查邮箱是否存在 + 检查邮箱是否已被注册 - :param db: - :param email: + :param db: 数据库会话 + :param email: 电子邮箱 :return: """ return await self.select_model_by_column(db, email=email) @@ -165,21 +168,23 @@ async def reset_password(self, db: AsyncSession, pk: int, new_pwd: str) -> int: """ 重置用户密码 - :param db: - :param pk: - :param new_pwd: + :param db: 数据库会话 + :param pk: 用户 ID + :param new_pwd: 新密码(已加密) :return: """ return await self.update_model(db, pk, {'password': new_pwd}) - async def get_list(self, dept: int = None, username: str = None, phone: str = None, status: int = None) -> Select: + async def get_list( + self, dept: int | None = None, username: str | None = None, phone: str | None = None, status: int | None = None + ) -> Select: """ 获取用户列表 - :param dept: - :param username: - :param phone: - :param status: + :param dept: 部门 ID + :param username: 用户名 + :param phone: 电话号码 + :param status: 用户状态 :return: """ stmt = ( @@ -191,25 +196,28 @@ async def get_list(self, dept: int = None, username: str = None, phone: str = No ) .order_by(desc(self.model.join_time)) ) - where_list = [] + + filters = [] if dept: - where_list.append(self.model.dept_id == dept) + filters.append(self.model.dept_id == dept) if username: - where_list.append(self.model.username.like(f'%{username}%')) + filters.append(self.model.username.like(f'%{username}%')) if phone: - where_list.append(self.model.phone.like(f'%{phone}%')) + filters.append(self.model.phone.like(f'%{phone}%')) if status is not None: - where_list.append(self.model.status == status) - if where_list: - stmt = stmt.where(and_(*where_list)) + filters.append(self.model.status == status) + + if filters: + stmt = stmt.where(and_(*filters)) + return stmt async def get_super(self, db: AsyncSession, user_id: int) -> bool: """ - 获取用户超级管理员状态 + 获取用户是否为超级管理员 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ user = await self.get(db, user_id) @@ -217,10 +225,10 @@ async def get_super(self, db: AsyncSession, user_id: int) -> bool: async def get_staff(self, db: AsyncSession, user_id: int) -> bool: """ - 获取用户后台登录状态 + 获取用户是否可以登录后台 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ user = await self.get(db, user_id) @@ -230,8 +238,8 @@ async def get_status(self, db: AsyncSession, user_id: int) -> int: """ 获取用户状态 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ user = await self.get(db, user_id) @@ -239,81 +247,85 @@ async def get_status(self, db: AsyncSession, user_id: int) -> int: async def get_multi_login(self, db: AsyncSession, user_id: int) -> bool: """ - 获取用户多点登录状态 + 获取用户是否允许多端登录 - :param db: - :param user_id: + :param db: 数据库会话 + :param user_id: 用户 ID :return: """ user = await self.get(db, user_id) return user.is_multi_login - async def set_super(self, db: AsyncSession, user_id: int, _super: bool) -> int: + async def set_super(self, db: AsyncSession, user_id: int, is_super: bool) -> int: """ - 设置用户超级管理员 + 设置用户超级管理员状态 - :param db: - :param user_id: - :param _super: + :param db: 数据库会话 + :param user_id: 用户 ID + :param is_super: 是否超级管理员 :return: """ - return await self.update_model(db, user_id, {'is_superuser': _super}) + return await self.update_model(db, user_id, {'is_superuser': is_super}) - async def set_staff(self, db: AsyncSession, user_id: int, staff: bool) -> int: + async def set_staff(self, db: AsyncSession, user_id: int, is_staff: bool) -> int: """ - 设置用户后台登录 + 设置用户后台登录状态 - :param db: - :param user_id: - :param staff: + :param db: 数据库会话 + :param user_id: 用户 ID + :param is_staff: 是否可登录后台 :return: """ - return await self.update_model(db, user_id, {'is_staff': staff}) + return await self.update_model(db, user_id, {'is_staff': is_staff}) - async def set_status(self, db: AsyncSession, user_id: int, status: bool) -> int: + async def set_status(self, db: AsyncSession, user_id: int, status: int) -> int: """ 设置用户状态 - :param db: - :param user_id: - :param status: + :param db: 数据库会话 + :param user_id: 用户 ID + :param status: 状态 :return: """ return await self.update_model(db, user_id, {'status': status}) async def set_multi_login(self, db: AsyncSession, user_id: int, multi_login: bool) -> int: """ - 设置用户多点登录 + 设置用户多端登录状态 - :param db: - :param user_id: - :param multi_login: + :param db: 数据库会话 + :param user_id: 用户 ID + :param multi_login: 是否允许多端登录 :return: """ return await self.update_model(db, user_id, {'is_multi_login': multi_login}) - async def get_with_relation(self, db: AsyncSession, *, user_id: int = None, username: str = None) -> User | None: + async def get_with_relation( + self, db: AsyncSession, *, user_id: int | None = None, username: str | None = None + ) -> User | None: """ - 获取用户和(部门,角色,菜单,规则) + 获取用户关联信息 - :param db: - :param user_id: - :param username: + :param db: 数据库会话 + :param user_id: 用户 ID + :param username: 用户名 :return: """ stmt = select(self.model).options( selectinload(self.model.dept), - selectinload(self.model.roles).options( - selectinload(Role.menus), - selectinload(Role.rules), - ), + selectinload(self.model.roles).options(selectinload(Role.menus), selectinload(Role.rules)), ) + filters = [] if user_id: filters.append(self.model.id == user_id) if username: filters.append(self.model.username == username) - user = await db.execute(stmt.where(*filters)) + + if filters: + stmt = stmt.where(and_(*filters)) + + user = await db.execute(stmt) return user.scalars().first() diff --git a/backend/app/admin/crud/crud_user_social.py b/backend/app/admin/crud/crud_user_social.py index 71509c4d7..d356ada60 100644 --- a/backend/app/admin/crud/crud_user_social.py +++ b/backend/app/admin/crud/crud_user_social.py @@ -8,37 +8,39 @@ from backend.common.enums import UserSocialType -class CRUDOUserSocial(CRUDPlus[UserSocial]): +class CRUDUserSocial(CRUDPlus[UserSocial]): + """用户社交账号数据库操作类""" + async def get(self, db: AsyncSession, pk: int, source: UserSocialType) -> UserSocial | None: """ - 获取用户社交账号绑定 + 获取用户社交账号绑定详情 - :param db: - :param pk: - :param source: + :param db: 数据库会话 + :param pk: 用户 ID + :param source: 社交账号类型 :return: """ return await self.select_model_by_column(db, user_id=pk, source=source) - async def create(self, db: AsyncSession, obj_in: CreateUserSocialParam) -> None: + async def create(self, db: AsyncSession, obj: CreateUserSocialParam) -> None: """ 创建用户社交账号绑定 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建用户社交账号绑定参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) async def delete(self, db: AsyncSession, social_id: int) -> int: """ 删除用户社交账号绑定 - :param db: - :param social_id: + :param db: 数据库会话 + :param social_id: 社交账号绑定 ID :return: """ return await self.delete_model(db, social_id) -user_social_dao: CRUDOUserSocial = CRUDOUserSocial(UserSocial) +user_social_dao: CRUDUserSocial = CRUDUserSocial(UserSocial) diff --git a/backend/app/admin/model/config.py b/backend/app/admin/model/config.py index 9465f4385..f5e3df779 100644 --- a/backend/app/admin/model/config.py +++ b/backend/app/admin/model/config.py @@ -9,7 +9,7 @@ class Config(Base): - """系统配置表""" + """参数配置表""" __tablename__ = 'sys_config' diff --git a/backend/app/admin/model/data_rule.py b/backend/app/admin/model/data_rule.py index fb9903e9f..4ec7da564 100644 --- a/backend/app/admin/model/data_rule.py +++ b/backend/app/admin/model/data_rule.py @@ -1,11 +1,18 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import TYPE_CHECKING + from sqlalchemy import String from sqlalchemy.orm import Mapped, mapped_column, relationship from backend.app.admin.model.m2m import sys_role_data_rule from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import Role + class DataRule(Base): """数据权限规则表""" @@ -23,4 +30,4 @@ class DataRule(Base): value: Mapped[str] = mapped_column(String(255), comment='规则值') # 角色规则多对多 - roles: Mapped[list['Role']] = relationship(init=False, secondary=sys_role_data_rule, back_populates='rules') # noqa: F821 + roles: Mapped[list[Role]] = relationship(init=False, secondary=sys_role_data_rule, back_populates='rules') diff --git a/backend/app/admin/model/dept.py b/backend/app/admin/model/dept.py index fe88e412b..f85fa407e 100644 --- a/backend/app/admin/model/dept.py +++ b/backend/app/admin/model/dept.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional from sqlalchemy import Boolean, ForeignKey, String from sqlalchemy.dialects.postgresql import INTEGER @@ -8,6 +10,9 @@ from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import User + class Dept(Base): """部门表""" @@ -29,8 +34,8 @@ class Dept(Base): parent_id: Mapped[int | None] = mapped_column( ForeignKey('sys_dept.id', ondelete='SET NULL'), default=None, index=True, comment='父部门ID' ) - parent: Mapped[Union['Dept', None]] = relationship(init=False, back_populates='children', remote_side=[id]) - children: Mapped[list['Dept'] | None] = relationship(init=False, back_populates='parent') + parent: Mapped[Optional['Dept']] = relationship(init=False, back_populates='children', remote_side=[id]) + children: Mapped[Optional[list['Dept']]] = relationship(init=False, back_populates='parent') # 部门用户一对多 - users: Mapped[list['User']] = relationship(init=False, back_populates='dept') # noqa: F821 + users: Mapped[list[User]] = relationship(init=False, back_populates='dept') diff --git a/backend/app/admin/model/dict_data.py b/backend/app/admin/model/dict_data.py index d90be9182..af11d494a 100644 --- a/backend/app/admin/model/dict_data.py +++ b/backend/app/admin/model/dict_data.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import TYPE_CHECKING + from sqlalchemy import ForeignKey, String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.dialects.postgresql import TEXT @@ -7,9 +11,12 @@ from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import DictType + class DictData(Base): - """字典数据""" + """字典数据表""" __tablename__ = 'sys_dict_data' @@ -26,4 +33,4 @@ class DictData(Base): type_id: Mapped[int] = mapped_column( ForeignKey('sys_dict_type.id', ondelete='CASCADE'), default=0, comment='字典类型关联ID' ) - type: Mapped['DictType'] = relationship(init=False, back_populates='datas') # noqa: F821 + type: Mapped[DictType] = relationship(init=False, back_populates='datas') diff --git a/backend/app/admin/model/dict_type.py b/backend/app/admin/model/dict_type.py index 88fc42464..c901a7840 100644 --- a/backend/app/admin/model/dict_type.py +++ b/backend/app/admin/model/dict_type.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import TYPE_CHECKING + from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.dialects.postgresql import TEXT @@ -7,9 +11,12 @@ from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import DictData + class DictType(Base): - """字典类型""" + """字典类型表""" __tablename__ = 'sys_dict_type' @@ -22,4 +29,4 @@ class DictType(Base): ) # 字典类型一对多 - datas: Mapped[list['DictData']] = relationship(init=False, back_populates='type') # noqa: F821 + datas: Mapped[list[DictData]] = relationship(init=False, back_populates='type') diff --git a/backend/app/admin/model/menu.py b/backend/app/admin/model/menu.py index c0f38a460..6d74ec2e6 100644 --- a/backend/app/admin/model/menu.py +++ b/backend/app/admin/model/menu.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional from sqlalchemy import ForeignKey, String from sqlalchemy.dialects.mysql import LONGTEXT @@ -10,6 +12,9 @@ from backend.app.admin.model.m2m import sys_role_menu from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import Role + class Menu(Base): """菜单表""" @@ -36,8 +41,8 @@ class Menu(Base): parent_id: Mapped[int | None] = mapped_column( ForeignKey('sys_menu.id', ondelete='SET NULL'), default=None, index=True, comment='父菜单ID' ) - parent: Mapped[Union['Menu', None]] = relationship(init=False, back_populates='children', remote_side=[id]) - children: Mapped[list['Menu'] | None] = relationship(init=False, back_populates='parent') + parent: Mapped[Optional['Menu']] = relationship(init=False, back_populates='children', remote_side=[id]) + children: Mapped[Optional[list['Menu']]] = relationship(init=False, back_populates='parent') # 菜单角色多对多 - roles: Mapped[list['Role']] = relationship(init=False, secondary=sys_role_menu, back_populates='menus') # noqa: F821 + roles: Mapped[list[Role]] = relationship(init=False, secondary=sys_role_menu, back_populates='menus') diff --git a/backend/app/admin/model/opera_log.py b/backend/app/admin/model/opera_log.py index 64ffeb871..aac24468c 100644 --- a/backend/app/admin/model/opera_log.py +++ b/backend/app/admin/model/opera_log.py @@ -3,7 +3,8 @@ from datetime import datetime from sqlalchemy import DateTime, String -from sqlalchemy.dialects.mysql import JSON, LONGTEXT, TEXT +from sqlalchemy.dialects.mysql import JSON, LONGTEXT +from sqlalchemy.dialects.postgresql import TEXT from sqlalchemy.orm import Mapped, mapped_column from backend.common.model import DataClassBase, id_key diff --git a/backend/app/admin/model/role.py b/backend/app/admin/model/role.py index aabeb1c30..54227c7e1 100644 --- a/backend/app/admin/model/role.py +++ b/backend/app/admin/model/role.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations + +from typing import TYPE_CHECKING + from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.dialects.postgresql import TEXT @@ -8,6 +12,9 @@ from backend.app.admin.model.m2m import sys_role_data_rule, sys_role_menu, sys_user_role from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import DataRule, Menu, User + class Role(Base): """角色表""" @@ -22,10 +29,10 @@ class Role(Base): ) # 角色用户多对多 - users: Mapped[list['User']] = relationship(init=False, secondary=sys_user_role, back_populates='roles') # noqa: F821 + users: Mapped[list[User]] = relationship(init=False, secondary=sys_user_role, back_populates='roles') # 角色菜单多对多 - menus: Mapped[list['Menu']] = relationship(init=False, secondary=sys_role_menu, back_populates='roles') # noqa: F821 + menus: Mapped[list[Menu]] = relationship(init=False, secondary=sys_role_menu, back_populates='roles') # 角色数据权限规则多对多 - rules: Mapped[list['DataRule']] = relationship(init=False, secondary=sys_role_data_rule, back_populates='roles') # noqa: F821 + rules: Mapped[list[DataRule]] = relationship(init=False, secondary=sys_role_data_rule, back_populates='roles') diff --git a/backend/app/admin/model/user.py b/backend/app/admin/model/user.py index a999e3a0b..4f072a4d7 100644 --- a/backend/app/admin/model/user.py +++ b/backend/app/admin/model/user.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations + from datetime import datetime -from typing import Union +from typing import TYPE_CHECKING from sqlalchemy import VARBINARY, Boolean, DateTime, ForeignKey, String from sqlalchemy.dialects.postgresql import BYTEA, INTEGER @@ -12,6 +14,9 @@ from backend.database.db import uuid4_str from backend.utils.timezone import timezone +if TYPE_CHECKING: + from backend.app.admin.model import Dept, Role, UserSocial + class User(Base): """用户表""" @@ -31,7 +36,7 @@ class User(Base): is_staff: Mapped[bool] = mapped_column( Boolean().with_variant(INTEGER, 'postgresql'), default=False, comment='后台管理登陆(0否 1是)' ) - status: Mapped[int] = mapped_column(default=1, comment='用户账号状态(0停用 1正常)') + status: Mapped[int] = mapped_column(default=1, index=True, comment='用户账号状态(0停用 1正常)') is_multi_login: Mapped[bool] = mapped_column( Boolean().with_variant(INTEGER, 'postgresql'), default=False, comment='是否重复登陆(0否 1是)' ) @@ -48,10 +53,10 @@ class User(Base): dept_id: Mapped[int | None] = mapped_column( ForeignKey('sys_dept.id', ondelete='SET NULL'), default=None, comment='部门关联ID' ) - dept: Mapped[Union['Dept', None]] = relationship(init=False, back_populates='users') # noqa: F821 + dept: Mapped[Dept | None] = relationship(init=False, back_populates='users') # 用户社交信息一对多 - socials: Mapped[list['UserSocial']] = relationship(init=False, back_populates='user') # noqa: F821 + socials: Mapped[list[UserSocial]] = relationship(init=False, back_populates='user') # 用户角色多对多 - roles: Mapped[list['Role']] = relationship(init=False, secondary=sys_user_role, back_populates='users') # noqa: F821 + roles: Mapped[list[Role]] = relationship(init=False, secondary=sys_user_role, back_populates='users') diff --git a/backend/app/admin/model/user_social.py b/backend/app/admin/model/user_social.py index 50439aa61..87cf10e1b 100644 --- a/backend/app/admin/model/user_social.py +++ b/backend/app/admin/model/user_social.py @@ -1,12 +1,17 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Union +from __future__ import annotations + +from typing import TYPE_CHECKING from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column, relationship from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.admin.model import User + class UserSocial(Base): """用户社交表(OAuth2)""" @@ -25,4 +30,4 @@ class UserSocial(Base): user_id: Mapped[int | None] = mapped_column( ForeignKey('sys_user.id', ondelete='SET NULL'), default=None, comment='用户关联ID' ) - user: Mapped[Union['User', None]] = relationship(init=False, back_populates='socials') # noqa: F821 + user: Mapped[User | None] = relationship(init=False, back_populates='socials') diff --git a/backend/app/admin/schema/captcha.py b/backend/app/admin/schema/captcha.py index 71ea24cad..0c1bee48b 100644 --- a/backend/app/admin/schema/captcha.py +++ b/backend/app/admin/schema/captcha.py @@ -6,5 +6,7 @@ class GetCaptchaDetail(SchemaBase): + """验证码详情""" + image_type: str = Field(description='图片类型') image: str = Field(description='图片内容') diff --git a/backend/app/admin/schema/config.py b/backend/app/admin/schema/config.py index ff9a85d93..c0216084d 100644 --- a/backend/app/admin/schema/config.py +++ b/backend/app/admin/schema/config.py @@ -2,37 +2,43 @@ # -*- coding: utf-8 -*- from datetime import datetime -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from backend.common.schema import SchemaBase class SaveBuiltInConfigParam(SchemaBase): - name: str - key: str - value: str + """保存内置参数配置参数""" + + name: str = Field(description='参数配置名称') + key: str = Field(description='参数配置键名') + value: str = Field(description='参数配置值') class ConfigSchemaBase(SchemaBase): - name: str - type: str | None - key: str - value: str - is_frontend: bool - remark: str | None + """参数配置基础模型""" + + name: str = Field(description='参数配置名称') + type: str | None = Field(None, description='参数配置类型') + key: str = Field(description='参数配置键名') + value: str = Field(description='参数配置值') + is_frontend: bool = Field(description='是否前端参数配置') + remark: str | None = Field(None, description='备注') class CreateConfigParam(ConfigSchemaBase): - pass + """创建参数配置参数""" class UpdateConfigParam(ConfigSchemaBase): - pass + """更新参数配置参数""" class GetConfigDetail(ConfigSchemaBase): + """参数配置详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='参数配置 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/app/admin/schema/data_rule.py b/backend/app/admin/schema/data_rule.py index 80f8211ba..17bfe4d16 100644 --- a/backend/app/admin/schema/data_rule.py +++ b/backend/app/admin/schema/data_rule.py @@ -9,28 +9,33 @@ class DataRuleSchemaBase(SchemaBase): - name: str - model: str - column: str - operator: RoleDataRuleOperatorType = Field(RoleDataRuleOperatorType.OR) - expression: RoleDataRuleExpressionType = Field(RoleDataRuleExpressionType.eq) - value: str + """数据规则基础模型""" + + name: str = Field(description='规则名称') + model: str = Field(description='模型名称') + column: str = Field(description='字段名称') + operator: RoleDataRuleOperatorType = Field(RoleDataRuleOperatorType.OR, description='操作符(AND/OR)') + expression: RoleDataRuleExpressionType = Field(RoleDataRuleExpressionType.eq, description='表达式类型') + value: str = Field(description='规则值') class CreateDataRuleParam(DataRuleSchemaBase): - pass + """创建数据规则参数""" class UpdateDataRuleParam(DataRuleSchemaBase): - pass + """更新数据规则参数""" class GetDataRuleDetail(DataRuleSchemaBase): + """数据规则详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='规则 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') - def __hash__(self): + def __hash__(self) -> int: + """计算哈希值""" return hash(self.name) diff --git a/backend/app/admin/schema/dept.py b/backend/app/admin/schema/dept.py index c8186076d..e66c56edf 100644 --- a/backend/app/admin/schema/dept.py +++ b/backend/app/admin/schema/dept.py @@ -9,27 +9,31 @@ class DeptSchemaBase(SchemaBase): - name: str - parent_id: int | None = Field(default=None, description='部门父级ID') - sort: int = Field(default=0, ge=0, description='排序') - leader: str | None = None - phone: CustomPhoneNumber | None = None - email: CustomEmailStr | None = None - status: StatusType = Field(default=StatusType.enable) + """部门基础模型""" + + name: str = Field(description='部门名称') + parent_id: int | None = Field(None, description='部门父级 ID') + sort: int = Field(0, ge=0, description='排序') + leader: str | None = Field(None, description='负责人') + phone: CustomPhoneNumber | None = Field(None, description='联系电话') + email: CustomEmailStr | None = Field(None, description='邮箱') + status: StatusType = Field(StatusType.enable, description='状态') class CreateDeptParam(DeptSchemaBase): - pass + """创建部门参数""" class UpdateDeptParam(DeptSchemaBase): - pass + """更新部门参数""" class GetDeptDetail(DeptSchemaBase): + """部门详情""" + model_config = ConfigDict(from_attributes=True) - id: int - del_flag: bool - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='部门 ID') + del_flag: bool = Field(description='是否删除') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/app/admin/schema/dict_data.py b/backend/app/admin/schema/dict_data.py index 013223ef0..e901bcf0a 100644 --- a/backend/app/admin/schema/dict_data.py +++ b/backend/app/admin/schema/dict_data.py @@ -10,29 +10,35 @@ class DictDataSchemaBase(SchemaBase): - type_id: int - label: str - value: str - sort: int - status: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """字典数据基础模型""" + + type_id: int = Field(description='字典类型 ID') + label: str = Field(description='字典标签') + value: str = Field(description='字典值') + sort: int = Field(description='排序') + status: StatusType = Field(StatusType.enable, description='状态') + remark: str | None = Field(None, description='备注') class CreateDictDataParam(DictDataSchemaBase): - pass + """创建字典数据参数""" class UpdateDictDataParam(DictDataSchemaBase): - pass + """更新字典数据参数""" class GetDictDataDetail(DictDataSchemaBase): + """字典数据详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='字典数据 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') class GetDictDataWithRelation(DictDataSchemaBase): - type: GetDictTypeDetail | None = None + """字典数据关联详情""" + + type: GetDictTypeDetail | None = Field(None, description='字典类型信息') diff --git a/backend/app/admin/schema/dict_type.py b/backend/app/admin/schema/dict_type.py index 42fe4916c..92b91beba 100644 --- a/backend/app/admin/schema/dict_type.py +++ b/backend/app/admin/schema/dict_type.py @@ -9,23 +9,27 @@ class DictTypeSchemaBase(SchemaBase): - name: str - code: str - status: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """字典类型基础模型""" + + name: str = Field(description='字典名称') + code: str = Field(description='字典编码') + status: StatusType = Field(StatusType.enable, description='状态') + remark: str | None = Field(None, description='备注') class CreateDictTypeParam(DictTypeSchemaBase): - pass + """创建字典类型参数""" class UpdateDictTypeParam(DictTypeSchemaBase): - pass + """更新字典类型参数""" class GetDictTypeDetail(DictTypeSchemaBase): + """字典类型详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='字典类型 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/app/admin/schema/login_log.py b/backend/app/admin/schema/login_log.py index 50218e8c1..3d122c43d 100644 --- a/backend/app/admin/schema/login_log.py +++ b/backend/app/admin/schema/login_log.py @@ -2,37 +2,41 @@ # -*- coding: utf-8 -*- from datetime import datetime -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from backend.common.schema import SchemaBase class LoginLogSchemaBase(SchemaBase): - user_uuid: str - username: str - status: int - ip: str - country: str | None - region: str | None - city: str | None - user_agent: str - browser: str | None - os: str | None - device: str | None - msg: str - login_time: datetime + """登录日志基础模型""" + + user_uuid: str = Field(description='用户 UUID') + username: str = Field(description='用户名') + status: int = Field(description='登录状态') + ip: str = Field(description='IP 地址') + country: str | None = Field(None, description='国家') + region: str | None = Field(None, description='地区') + city: str | None = Field(None, description='城市') + user_agent: str = Field(description='用户代理') + browser: str | None = Field(None, description='浏览器') + os: str | None = Field(None, description='操作系统') + device: str | None = Field(None, description='设备') + msg: str = Field(description='消息') + login_time: datetime = Field(description='登录时间') class CreateLoginLogParam(LoginLogSchemaBase): - pass + """创建登录日志参数""" class UpdateLoginLogParam(LoginLogSchemaBase): - pass + """更新登录日志参数""" class GetLoginLogDetail(LoginLogSchemaBase): + """登录日志详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime + id: int = Field(description='日志 ID') + created_time: datetime = Field(description='创建时间') diff --git a/backend/app/admin/schema/menu.py b/backend/app/admin/schema/menu.py index 7961d1e7c..a23930f4a 100644 --- a/backend/app/admin/schema/menu.py +++ b/backend/app/admin/schema/menu.py @@ -9,32 +9,36 @@ class MenuSchemaBase(SchemaBase): - title: str - name: str - parent_id: int | None = Field(default=None, description='菜单父级ID') - sort: int = Field(default=0, ge=0, description='排序') - icon: str | None = None - path: str | None = None - menu_type: MenuType = Field(default=MenuType.directory, description='菜单类型(0目录 1菜单 2按钮)') - component: str | None = None - perms: str | None = None - status: StatusType = Field(default=StatusType.enable) - display: StatusType = Field(default=StatusType.enable) - cache: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """菜单基础模型""" + + title: str = Field(description='菜单标题') + name: str = Field(description='菜单名称') + parent_id: int | None = Field(None, description='菜单父级 ID') + sort: int = Field(0, ge=0, description='排序') + icon: str | None = Field(None, description='图标') + path: str | None = Field(None, description='路由路径') + menu_type: MenuType = Field(MenuType.directory, description='菜单类型(0目录 1菜单 2按钮)') + component: str | None = Field(None, description='组件路径') + perms: str | None = Field(None, description='权限标识') + status: StatusType = Field(StatusType.enable, description='状态') + display: StatusType = Field(StatusType.enable, description='是否显示') + cache: StatusType = Field(StatusType.enable, description='是否缓存') + remark: str | None = Field(None, description='备注') class CreateMenuParam(MenuSchemaBase): - pass + """创建菜单参数""" class UpdateMenuParam(MenuSchemaBase): - pass + """更新菜单参数""" class GetMenuDetail(MenuSchemaBase): + """菜单详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='菜单 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/app/admin/schema/opera_log.py b/backend/app/admin/schema/opera_log.py index 0791c4885..6d93d2691 100644 --- a/backend/app/admin/schema/opera_log.py +++ b/backend/app/admin/schema/opera_log.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from datetime import datetime +from typing import Any from pydantic import ConfigDict, Field @@ -9,37 +10,41 @@ class OperaLogSchemaBase(SchemaBase): - trace_id: str - username: str | None = None - method: str - title: str - path: str - ip: str - country: str | None = None - region: str | None = None - city: str | None = None - user_agent: str - os: str | None = None - browser: str | None = None - device: str | None = None - args: dict | None = None - status: StatusType = Field(default=StatusType.enable) - code: str - msg: str | None = None - cost_time: float - opera_time: datetime + """操作日志基础模型""" + + trace_id: str = Field(description='追踪 ID') + username: str | None = Field(None, description='用户名') + method: str = Field(description='请求方法') + title: str = Field(description='操作标题') + path: str = Field(description='请求路径') + ip: str = Field(description='IP 地址') + country: str | None = Field(None, description='国家') + region: str | None = Field(None, description='地区') + city: str | None = Field(None, description='城市') + user_agent: str = Field(description='用户代理') + os: str | None = Field(None, description='操作系统') + browser: str | None = Field(None, description='浏览器') + device: str | None = Field(None, description='设备') + args: dict[str, Any] | None = Field(None, description='请求参数') + status: StatusType = Field(StatusType.enable, description='状态') + code: str = Field(description='状态码') + msg: str | None = Field(None, description='消息') + cost_time: float = Field(description='耗时') + opera_time: datetime = Field(description='操作时间') class CreateOperaLogParam(OperaLogSchemaBase): - pass + """创建操作日志参数""" class UpdateOperaLogParam(OperaLogSchemaBase): - pass + """更新操作日志参数""" class GetOperaLogDetail(OperaLogSchemaBase): + """操作日志详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime + id: int = Field(description='日志 ID') + created_time: datetime = Field(description='创建时间') diff --git a/backend/app/admin/schema/role.py b/backend/app/admin/schema/role.py index 32c4e21ca..70f95b810 100644 --- a/backend/app/admin/schema/role.py +++ b/backend/app/admin/schema/role.py @@ -11,35 +11,45 @@ class RoleSchemaBase(SchemaBase): - name: str - status: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """角色基础模型""" + + name: str = Field(description='角色名称') + status: StatusType = Field(StatusType.enable, description='状态') + remark: str | None = Field(None, description='备注') class CreateRoleParam(RoleSchemaBase): - pass + """创建角色参数""" class UpdateRoleParam(RoleSchemaBase): - pass + """更新角色参数""" class UpdateRoleMenuParam(SchemaBase): - menus: list[int] + """更新角色菜单参数""" + + menus: list[int] = Field(description='菜单 ID 列表') class UpdateRoleRuleParam(SchemaBase): - rules: list[int] + """更新角色规则参数""" + + rules: list[int] = Field(description='数据规则 ID 列表') class GetRoleDetail(RoleSchemaBase): + """角色详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='角色 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') class GetRoleWithRelationDetail(GetRoleDetail): - menus: list[GetMenuDetail | None] = [] - rules: list[GetDataRuleDetail | None] = [] + """角色关联详情""" + + menus: list[GetMenuDetail | None] = Field([], description='菜单详情列表') + rules: list[GetDataRuleDetail | None] = Field([], description='数据规则详情列表') diff --git a/backend/app/admin/schema/token.py b/backend/app/admin/schema/token.py index 9e0b6301f..5c33e07ef 100644 --- a/backend/app/admin/schema/token.py +++ b/backend/app/admin/schema/token.py @@ -2,44 +2,56 @@ # -*- coding: utf-8 -*- from datetime import datetime +from pydantic import Field + from backend.app.admin.schema.user import GetUserInfoDetail from backend.common.enums import StatusType from backend.common.schema import SchemaBase class GetSwaggerToken(SchemaBase): - access_token: str - token_type: str = 'Bearer' - user: GetUserInfoDetail + """Swagger 认证令牌""" + + access_token: str = Field(description='访问令牌') + token_type: str = Field('Bearer', description='令牌类型') + user: GetUserInfoDetail = Field(description='用户信息') class AccessTokenBase(SchemaBase): - access_token: str - access_token_expire_time: datetime - session_uuid: str + """访问令牌基础模型""" + + access_token: str = Field(description='访问令牌') + access_token_expire_time: datetime = Field(description='令牌过期时间') + session_uuid: str = Field(description='会话 UUID') class GetNewToken(AccessTokenBase): - pass + """获取新令牌""" class GetLoginToken(AccessTokenBase): - user: GetUserInfoDetail + """获取登录令牌""" + + user: GetUserInfoDetail = Field(description='用户信息') class KickOutToken(SchemaBase): - session_uuid: str + """踢出令牌""" + + session_uuid: str = Field(description='会话 UUID') class GetTokenDetail(SchemaBase): - id: int - session_uuid: str - username: str - nickname: str - ip: str - os: str - browser: str - device: str - status: StatusType - last_login_time: str - expire_time: datetime + """令牌详情""" + + id: int = Field(description='用户 ID') + session_uuid: str = Field(description='会话 UUID') + username: str = Field(description='用户名') + nickname: str = Field(description='昵称') + ip: str = Field(description='IP 地址') + os: str = Field(description='操作系统') + browser: str = Field(description='浏览器') + device: str = Field(description='设备') + status: StatusType = Field(description='状态') + last_login_time: str = Field(description='最后登录时间') + expire_time: datetime = Field(description='过期时间') diff --git a/backend/app/admin/schema/user.py b/backend/app/admin/schema/user.py index 284b127b2..762d0d67c 100644 --- a/backend/app/admin/schema/user.py +++ b/backend/app/admin/schema/user.py @@ -13,84 +13,106 @@ class AuthSchemaBase(SchemaBase): - username: str - password: str | None + """用户认证基础模型""" + + username: str = Field(description='用户名') + password: str | None = Field(description='密码') class AuthLoginParam(AuthSchemaBase): - captcha: str + """用户登录参数""" + + captcha: str = Field(description='验证码') class RegisterUserParam(AuthSchemaBase): - nickname: str | None = None - email: EmailStr = Field(examples=['user@example.com']) + """用户注册参数""" + + nickname: str | None = Field(None, description='昵称') + email: EmailStr = Field(examples=['user@example.com'], description='邮箱') class AddUserParam(AuthSchemaBase): - dept_id: int - roles: list[int] - nickname: str | None = None - email: EmailStr = Field(examples=['user@example.com']) + """添加用户参数""" + + dept_id: int = Field(description='部门 ID') + roles: list[int] = Field(description='角色 ID 列表') + nickname: str | None = Field(None, description='昵称') + email: EmailStr = Field(examples=['user@example.com'], description='邮箱') class ResetPasswordParam(SchemaBase): - old_password: str - new_password: str - confirm_password: str + """重置密码参数""" + + old_password: str = Field(description='旧密码') + new_password: str = Field(description='新密码') + confirm_password: str = Field(description='确认密码') class UserInfoSchemaBase(SchemaBase): - dept_id: int | None = None - username: str - nickname: str - email: EmailStr = Field(examples=['user@example.com']) - phone: CustomPhoneNumber | None = None + """用户信息基础模型""" + + dept_id: int | None = Field(None, description='部门 ID') + username: str = Field(description='用户名') + nickname: str = Field(description='昵称') + email: EmailStr = Field(examples=['user@example.com'], description='邮箱') + phone: CustomPhoneNumber | None = Field(None, description='手机号') class UpdateUserParam(UserInfoSchemaBase): - pass + """更新用户参数""" class UpdateUserRoleParam(SchemaBase): - roles: list[int] + """更新用户角色参数""" + + roles: list[int] = Field(description='角色 ID 列表') class AvatarParam(SchemaBase): + """更新头像参数""" + url: HttpUrl = Field(description='头像 http 地址') class GetUserInfoDetail(UserInfoSchemaBase): + """用户信息详情""" + model_config = ConfigDict(from_attributes=True) - dept_id: int | None = None - id: int - uuid: str - avatar: str | None = None - status: StatusType = Field(default=StatusType.enable) - is_superuser: bool - is_staff: bool - is_multi_login: bool - join_time: datetime = None - last_login_time: datetime | None = None + dept_id: int | None = Field(None, description='部门 ID') + id: int = Field(description='用户 ID') + uuid: str = Field(description='用户 UUID') + avatar: str | None = Field(None, description='头像') + status: StatusType = Field(StatusType.enable, description='状态') + is_superuser: bool = Field(description='是否超级管理员') + is_staff: bool = Field(description='是否管理员') + is_multi_login: bool = Field(description='是否允许多端登录') + join_time: datetime = Field(description='加入时间') + last_login_time: datetime | None = Field(None, description='最后登录时间') class GetUserInfoWithRelationDetail(GetUserInfoDetail): + """用户信息关联详情""" + model_config = ConfigDict(from_attributes=True) - dept: GetDeptDetail | None = None - roles: list[GetRoleWithRelationDetail] + dept: GetDeptDetail | None = Field(None, description='部门信息') + roles: list[GetRoleWithRelationDetail] = Field(description='角色列表') class GetCurrentUserInfoWithRelationDetail(GetUserInfoWithRelationDetail): + """当前用户信息关联详情""" + model_config = ConfigDict(from_attributes=True) - dept: str | None = None - roles: list[str] + dept: str | None = Field(None, description='部门名称') + roles: list[str] = Field(description='角色名称列表') @model_validator(mode='before') @classmethod def handel(cls, data: Any) -> Self: - """处理部门和角色""" + """处理部门和角色数据""" dept = data['dept'] if dept: data['dept'] = dept['name'] diff --git a/backend/app/admin/schema/user_social.py b/backend/app/admin/schema/user_social.py index 464d150fa..70d6d1b64 100644 --- a/backend/app/admin/schema/user_social.py +++ b/backend/app/admin/schema/user_social.py @@ -1,21 +1,27 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from pydantic import Field + from backend.common.enums import UserSocialType from backend.common.schema import SchemaBase class UserSocialSchemaBase(SchemaBase): - source: UserSocialType - open_id: str | None = None - uid: str | None = None - union_id: str | None = None - scope: str | None = None - code: str | None = None + """用户社交基础模型""" + + source: UserSocialType = Field(description='社交平台') + open_id: str | None = Field(None, description='开放平台 ID') + uid: str | None = Field(None, description='用户 ID') + union_id: str | None = Field(None, description='开放平台唯一 ID') + scope: str | None = Field(None, description='授权范围') + code: str | None = Field(None, description='授权码') class CreateUserSocialParam(UserSocialSchemaBase): - user_id: int + """创建用户社交参数""" + + user_id: int = Field(description='用户 ID') class UpdateUserSocialParam(SchemaBase): - pass + """更新用户社交参数""" diff --git a/backend/app/admin/service/auth_service.py b/backend/app/admin/service/auth_service.py index 382c29d78..f232ba321 100644 --- a/backend/app/admin/service/auth_service.py +++ b/backend/app/admin/service/auth_service.py @@ -30,8 +30,18 @@ class AuthService: + """认证服务类""" + @staticmethod async def user_verify(db: AsyncSession, username: str, password: str) -> User: + """ + 验证用户名和密码 + + :param db: 数据库会话 + :param username: 用户名 + :param password: 密码 + :return: + """ user = await user_dao.get_by_username(db, username) if not user: raise errors.NotFoundError(msg='用户名或密码有误') @@ -42,6 +52,12 @@ async def user_verify(db: AsyncSession, username: str, password: str) -> User: return user async def swagger_login(self, *, obj: HTTPBasicCredentials) -> tuple[str, User]: + """ + Swagger 文档登录 + + :param obj: 登录凭证 + :return: + """ async with async_db_session.begin() as db: user = await self.user_verify(db, obj.username, obj.password) await user_dao.update_login_time(db, obj.username) @@ -56,6 +72,15 @@ async def swagger_login(self, *, obj: HTTPBasicCredentials) -> tuple[str, User]: async def login( self, *, request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks ) -> GetLoginToken: + """ + 用户登录 + + :param request: 请求对象 + :param response: 响应对象 + :param obj: 登录参数 + :param background_tasks: 后台任务 + :return: + """ async with async_db_session.begin() as db: user = None try: @@ -133,6 +158,12 @@ async def login( @staticmethod async def new_token(*, request: Request) -> GetNewToken: + """ + 获取新的访问令牌 + + :param request: FastAPI 请求对象 + :return: + """ refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) if not refresh_token: raise errors.TokenError(msg='Refresh Token 已过期,请重新登录') @@ -168,6 +199,13 @@ async def new_token(*, request: Request) -> GetNewToken: @staticmethod async def logout(*, request: Request, response: Response) -> None: + """ + 用户登出 + + :param request: FastAPI 请求对象 + :param response: FastAPI 响应对象 + :return: + """ token = get_token(request) token_payload = jwt_decode(token) user_id = token_payload.id diff --git a/backend/app/admin/service/config_service.py b/backend/app/admin/service/config_service.py index 6cf6d001a..618c88cd8 100644 --- a/backend/app/admin/service/config_service.py +++ b/backend/app/admin/service/config_service.py @@ -17,13 +17,28 @@ class ConfigService: + """参数配置服务类""" + @staticmethod async def get_built_in_config(type: str) -> Sequence[Config]: + """ + 获取内置参数配置 + + :param type: 参数配置类型 + :return: + """ async with async_db_session() as db: return await config_dao.get_by_type(db, type) @staticmethod async def save_built_in_config(objs: list[SaveBuiltInConfigParam], type: str) -> None: + """ + 保存内置参数配置 + + :param objs: 参数配置参数列表 + :param type: 参数配置类型 + :return: + """ async with async_db_session.begin() as db: for obj in objs: config = await config_dao.get_by_key_and_type(db, obj.key, type) @@ -35,7 +50,13 @@ async def save_built_in_config(objs: list[SaveBuiltInConfigParam], type: str) -> await config_dao.update_model(db, config.id, obj, type=type) @staticmethod - async def get(pk) -> Config | dict: + async def get(pk: int) -> Config: + """ + 获取参数配置详情 + + :param pk: 参数配置 ID + :return: + """ async with async_db_session() as db: config = await config_dao.get(db, pk) if not config: @@ -43,11 +64,24 @@ async def get(pk) -> Config | dict: return config @staticmethod - async def get_select(*, name: str = None, type: str = None) -> Select: + async def get_select(*, name: str | None = None, type: str | None = None) -> Select: + """ + 获取参数配置列表查询条件 + + :param name: 参数配置名称 + :param type: 参数配置类型 + :return: + """ return await config_dao.get_list(name=name, type=type) @staticmethod async def create(*, obj: CreateConfigParam) -> None: + """ + 创建参数配置 + + :param obj: 参数配置创建参数 + :return: + """ async with async_db_session.begin() as db: if obj.type in admin_settings.CONFIG_BUILT_IN_TYPES: raise errors.ForbiddenError(msg='非法类型参数') @@ -58,6 +92,13 @@ async def create(*, obj: CreateConfigParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateConfigParam) -> int: + """ + 更新参数配置 + + :param pk: 参数配置 ID + :param obj: 参数配置更新参数 + :return: + """ async with async_db_session.begin() as db: config = await config_dao.get(db, pk) if not config: @@ -71,6 +112,12 @@ async def update(*, pk: int, obj: UpdateConfigParam) -> int: @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除参数配置 + + :param pk: 参数配置 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await config_dao.delete(db, pk) return count diff --git a/backend/app/admin/service/data_rule_service.py b/backend/app/admin/service/data_rule_service.py index 5a322dc52..d55893de9 100644 --- a/backend/app/admin/service/data_rule_service.py +++ b/backend/app/admin/service/data_rule_service.py @@ -17,8 +17,16 @@ class DataRuleService: + """数据权限规则服务类""" + @staticmethod async def get(*, pk: int) -> DataRule: + """ + 获取数据规则详情 + + :param pk: 规则 ID + :return: + """ async with async_db_session() as db: data_rule = await data_rule_dao.get(db, pk) if not data_rule: @@ -27,6 +35,12 @@ async def get(*, pk: int) -> DataRule: @staticmethod async def get_role_rules(*, pk: int) -> list[int]: + """ + 获取角色的数据规则列表 + + :param pk: 角色 ID + :return: + """ async with async_db_session() as db: role = await role_dao.get_with_relation(db, pk) if not role: @@ -36,10 +50,17 @@ async def get_role_rules(*, pk: int) -> list[int]: @staticmethod async def get_models() -> list[str]: + """获取所有数据模型""" return list(settings.DATA_PERMISSION_MODELS.keys()) @staticmethod async def get_columns(model: str) -> list[str]: + """ + 获取数据模型的字段列表 + + :param model: 模型名称 + :return: + """ if model not in settings.DATA_PERMISSION_MODELS: raise errors.NotFoundError(msg='数据模型不存在') model_ins = dynamic_import_data_model(settings.DATA_PERMISSION_MODELS[model]) @@ -49,17 +70,30 @@ async def get_columns(model: str) -> list[str]: return model_columns @staticmethod - async def get_select(*, name: str = None) -> Select: + async def get_select(*, name: str | None = None) -> Select: + """ + 获取数据规则列表查询条件 + + :param name: 规则名称 + :return: + """ return await data_rule_dao.get_list(name=name) @staticmethod async def get_all() -> Sequence[DataRule]: + """获取所有数据规则""" async with async_db_session() as db: data_rules = await data_rule_dao.get_all(db) return data_rules @staticmethod async def create(*, obj: CreateDataRuleParam) -> None: + """ + 创建数据规则 + + :param obj: 规则创建参数 + :return: + """ async with async_db_session.begin() as db: data_rule = await data_rule_dao.get_by_name(db, obj.name) if data_rule: @@ -68,6 +102,13 @@ async def create(*, obj: CreateDataRuleParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateDataRuleParam) -> int: + """ + 更新数据规则 + + :param pk: 规则 ID + :param obj: 规则更新参数 + :return: + """ async with async_db_session.begin() as db: data_rule = await data_rule_dao.get(db, pk) if not data_rule: @@ -77,6 +118,13 @@ async def update(*, pk: int, obj: UpdateDataRuleParam) -> int: @staticmethod async def delete(*, request: Request, pk: list[int]) -> int: + """ + 删除数据规则 + + :param request: FastAPI 请求对象 + :param pk: 规则 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await data_rule_dao.delete(db, pk) await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{request.user.id}') diff --git a/backend/app/admin/service/dept_service.py b/backend/app/admin/service/dept_service.py index 61f22a997..d1fe974f5 100644 --- a/backend/app/admin/service/dept_service.py +++ b/backend/app/admin/service/dept_service.py @@ -15,8 +15,16 @@ class DeptService: + """部门服务类""" + @staticmethod async def get(*, pk: int) -> Dept: + """ + 获取部门详情 + + :param pk: 部门 ID + :return: + """ async with async_db_session() as db: dept = await dept_dao.get(db, pk) if not dept: @@ -27,6 +35,15 @@ async def get(*, pk: int) -> Dept: async def get_dept_tree( *, name: str | None = None, leader: str | None = None, phone: str | None = None, status: int | None = None ) -> list[dict[str, Any]]: + """ + 获取部门树形结构 + + :param name: 部门名称 + :param leader: 部门负责人 + :param phone: 联系电话 + :param status: 状态 + :return: + """ async with async_db_session() as db: dept_select = await dept_dao.get_all(db=db, name=name, leader=leader, phone=phone, status=status) tree_data = get_tree_data(dept_select) @@ -34,6 +51,12 @@ async def get_dept_tree( @staticmethod async def create(*, obj: CreateDeptParam) -> None: + """ + 创建部门 + + :param obj: 部门创建参数 + :return: + """ async with async_db_session.begin() as db: dept = await dept_dao.get_by_name(db, obj.name) if dept: @@ -46,6 +69,13 @@ async def create(*, obj: CreateDeptParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateDeptParam) -> int: + """ + 更新部门 + + :param pk: 部门 ID + :param obj: 部门更新参数 + :return: + """ async with async_db_session.begin() as db: dept = await dept_dao.get(db, pk) if not dept: @@ -64,8 +94,16 @@ async def update(*, pk: int, obj: UpdateDeptParam) -> int: @staticmethod async def delete(*, request: Request, pk: int) -> int: + """ + 删除部门 + + :param request: FastAPI 请求对象 + :param pk: 部门 ID + :return: + """ async with async_db_session.begin() as db: - dept_user = await dept_dao.get_with_relation(db, pk) + dept = await dept_dao.get_with_relation(db, pk) + dept_user = dept.users if dept_user: raise errors.ForbiddenError(msg='部门下存在用户,无法删除') children = await dept_dao.get_children(db, pk) diff --git a/backend/app/admin/service/dict_data_service.py b/backend/app/admin/service/dict_data_service.py index 6b52ecc88..48fdd2190 100644 --- a/backend/app/admin/service/dict_data_service.py +++ b/backend/app/admin/service/dict_data_service.py @@ -11,8 +11,16 @@ class DictDataService: + """字典数据服务类""" + @staticmethod async def get(*, pk: int) -> DictData: + """ + 获取字典数据详情 + + :param pk: 字典数据 ID + :return: + """ async with async_db_session() as db: dict_data = await dict_data_dao.get_with_relation(db, pk) if not dict_data: @@ -20,11 +28,25 @@ async def get(*, pk: int) -> DictData: return dict_data @staticmethod - async def get_select(*, label: str = None, value: str = None, status: int = None) -> Select: + async def get_select(*, label: str | None = None, value: str | None = None, status: int | None = None) -> Select: + """ + 获取字典数据列表查询条件 + + :param label: 字典数据标签 + :param value: 字典数据键值 + :param status: 状态 + :return: + """ return await dict_data_dao.get_list(label=label, value=value, status=status) @staticmethod async def create(*, obj: CreateDictDataParam) -> None: + """ + 创建字典数据 + + :param obj: 字典数据创建参数 + :return: + """ async with async_db_session.begin() as db: dict_data = await dict_data_dao.get_by_label(db, obj.label) if dict_data: @@ -36,6 +58,13 @@ async def create(*, obj: CreateDictDataParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateDictDataParam) -> int: + """ + 更新字典数据 + + :param pk: 字典数据 ID + :param obj: 字典数据更新参数 + :return: + """ async with async_db_session.begin() as db: dict_data = await dict_data_dao.get(db, pk) if not dict_data: @@ -51,6 +80,12 @@ async def update(*, pk: int, obj: UpdateDictDataParam) -> int: @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除字典数据 + + :param pk: 字典数据 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await dict_data_dao.delete(db, pk) return count diff --git a/backend/app/admin/service/dict_type_service.py b/backend/app/admin/service/dict_type_service.py index 1ba660910..6da2ca88c 100644 --- a/backend/app/admin/service/dict_type_service.py +++ b/backend/app/admin/service/dict_type_service.py @@ -9,12 +9,28 @@ class DictTypeService: + """字典类型服务类""" + @staticmethod - async def get_select(*, name: str = None, code: str = None, status: int = None) -> Select: + async def get_select(*, name: str | None = None, code: str | None = None, status: int | None = None) -> Select: + """ + 获取字典类型列表查询条件 + + :param name: 字典类型名称 + :param code: 字典类型编码 + :param status: 状态 + :return: + """ return await dict_type_dao.get_list(name=name, code=code, status=status) @staticmethod async def create(*, obj: CreateDictTypeParam) -> None: + """ + 创建字典类型 + + :param obj: 字典类型创建参数 + :return: + """ async with async_db_session.begin() as db: dict_type = await dict_type_dao.get_by_code(db, obj.code) if dict_type: @@ -23,6 +39,13 @@ async def create(*, obj: CreateDictTypeParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateDictTypeParam) -> int: + """ + 更新字典类型 + + :param pk: 字典类型 ID + :param obj: 字典类型更新参数 + :return: + """ async with async_db_session.begin() as db: dict_type = await dict_type_dao.get(db, pk) if not dict_type: @@ -35,6 +58,12 @@ async def update(*, pk: int, obj: UpdateDictTypeParam) -> int: @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除字典类型 + + :param pk: 字典类型 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await dict_type_dao.delete(db, pk) return count diff --git a/backend/app/admin/service/login_log_service.py b/backend/app/admin/service/login_log_service.py index 7cd56d856..72b3b907c 100644 --- a/backend/app/admin/service/login_log_service.py +++ b/backend/app/admin/service/login_log_service.py @@ -13,8 +13,18 @@ class LoginLogService: + """登录日志服务类""" + @staticmethod - async def get_select(*, username: str, status: int, ip: str) -> Select: + async def get_select(*, username: str | None = None, status: int | None = None, ip: str | None = None) -> Select: + """ + 获取登录日志列表查询条件 + + :param username: 用户名 + :param status: 状态 + :param ip: IP 地址 + :return: + """ return await login_log_dao.get_list(username=username, status=status, ip=ip) @staticmethod @@ -28,8 +38,20 @@ async def create( status: int, msg: str, ) -> None: + """ + 创建登录日志 + + :param db: 数据库会话 + :param request: FastAPI 请求对象 + :param user_uuid: 用户 UUID + :param username: 用户名 + :param login_time: 登录时间 + :param status: 状态 + :param msg: 消息 + :return: + """ try: - obj_in = CreateLoginLogParam( + obj = CreateLoginLogParam( user_uuid=user_uuid, username=username, status=status, @@ -44,18 +66,25 @@ async def create( msg=msg, login_time=login_time, ) - await login_log_dao.create(db, obj_in) + await login_log_dao.create(db, obj) except Exception as e: log.error(f'登录日志创建失败: {e}') @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除登录日志 + + :param pk: 日志 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await login_log_dao.delete(db, pk) return count @staticmethod async def delete_all() -> int: + """清空所有登录日志""" async with async_db_session.begin() as db: count = await login_log_dao.delete_all(db) return count diff --git a/backend/app/admin/service/menu_service.py b/backend/app/admin/service/menu_service.py index d09e98a04..12dd959be 100644 --- a/backend/app/admin/service/menu_service.py +++ b/backend/app/admin/service/menu_service.py @@ -16,8 +16,16 @@ class MenuService: + """菜单服务类""" + @staticmethod async def get(*, pk: int) -> Menu: + """ + 获取菜单详情 + + :param pk: 菜单 ID + :return: + """ async with async_db_session() as db: menu = await menu_dao.get(db, menu_id=pk) if not menu: @@ -26,6 +34,13 @@ async def get(*, pk: int) -> Menu: @staticmethod async def get_menu_tree(*, title: str | None = None, status: int | None = None) -> list[dict[str, Any]]: + """ + 获取菜单树形结构 + + :param title: 菜单标题 + :param status: 状态 + :return: + """ async with async_db_session() as db: menu_select = await menu_dao.get_all(db, title=title, status=status) menu_tree = get_tree_data(menu_select) @@ -33,6 +48,12 @@ async def get_menu_tree(*, title: str | None = None, status: int | None = None) @staticmethod async def get_role_menu_tree(*, pk: int) -> list[dict[str, Any]]: + """ + 获取角色的菜单树形结构 + + :param pk: 角色 ID + :return: + """ async with async_db_session() as db: role = await role_dao.get_with_relation(db, pk) if not role: @@ -44,6 +65,12 @@ async def get_role_menu_tree(*, pk: int) -> list[dict[str, Any]]: @staticmethod async def get_user_menu_tree(*, request: Request) -> list[dict[str, Any]]: + """ + 获取用户的菜单树形结构 + + :param request: FastAPI 请求对象 + :return: + """ async with async_db_session() as db: roles = request.user.roles menu_ids = [] @@ -57,6 +84,12 @@ async def get_user_menu_tree(*, request: Request) -> list[dict[str, Any]]: @staticmethod async def create(*, obj: CreateMenuParam) -> None: + """ + 创建菜单 + + :param obj: 菜单创建参数 + :return: + """ async with async_db_session.begin() as db: title = await menu_dao.get_by_title(db, obj.title) if title: @@ -69,6 +102,13 @@ async def create(*, obj: CreateMenuParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateMenuParam) -> int: + """ + 更新菜单 + + :param pk: 菜单 ID + :param obj: 菜单更新参数 + :return: + """ async with async_db_session.begin() as db: menu = await menu_dao.get(db, pk) if not menu: @@ -87,6 +127,13 @@ async def update(*, pk: int, obj: UpdateMenuParam) -> int: @staticmethod async def delete(*, request: Request, pk: int) -> int: + """ + 删除菜单 + + :param request: FastAPI 请求对象 + :param pk: 菜单 ID + :return: + """ async with async_db_session.begin() as db: children = await menu_dao.get_children(db, pk) if children: diff --git a/backend/app/admin/service/oauth2_service.py b/backend/app/admin/service/oauth2_service.py index 539b6535a..5adcb1ae3 100644 --- a/backend/app/admin/service/oauth2_service.py +++ b/backend/app/admin/service/oauth2_service.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any + from fast_captcha import text_captcha from fastapi import BackgroundTasks, Request, Response @@ -20,15 +22,27 @@ class OAuth2Service: + """OAuth2 认证服务类""" + @staticmethod async def create_with_login( *, request: Request, response: Response, background_tasks: BackgroundTasks, - user: dict, + user: dict[str, Any], social: UserSocialType, ) -> GetLoginToken | None: + """ + 创建 OAuth2 用户并登录 + + :param request: FastAPI 请求对象 + :param response: FastAPI 响应对象 + :param background_tasks: FastAPI 后台任务 + :param user: OAuth2 用户信息 + :param social: 社交平台类型 + :return: + """ async with async_db_session.begin() as db: # 获取 OAuth2 平台用户信息 social_id = user.get('id') @@ -37,7 +51,7 @@ async def create_with_login( social_username = user.get('login') social_nickname = user.get('name') social_email = user.get('email') - if social == UserSocialType.linuxdo: # 不提供明文邮箱的平台 + if social == UserSocialType.linux_do: # 不提供明文邮箱的平台 social_email = f'{social_username}@linux.do' if not social_email: raise AuthorizationError(msg=f'授权失败,{social.value} 账户未绑定邮箱') diff --git a/backend/app/admin/service/opera_log_service.py b/backend/app/admin/service/opera_log_service.py index a5637f7aa..cc2e2dd50 100644 --- a/backend/app/admin/service/opera_log_service.py +++ b/backend/app/admin/service/opera_log_service.py @@ -8,23 +8,46 @@ class OperaLogService: + """操作日志服务类""" + @staticmethod async def get_select(*, username: str | None = None, status: int | None = None, ip: str | None = None) -> Select: + """ + 获取操作日志列表查询条件 + + :param username: 用户名 + :param status: 状态 + :param ip: IP 地址 + :return: + """ return await opera_log_dao.get_list(username=username, status=status, ip=ip) @staticmethod - async def create(*, obj_in: CreateOperaLogParam): + async def create(*, obj: CreateOperaLogParam) -> None: + """ + 创建操作日志 + + :param obj: 操作日志创建参数 + :return: + """ async with async_db_session.begin() as db: - await opera_log_dao.create(db, obj_in) + await opera_log_dao.create(db, obj) @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除操作日志 + + :param pk: 日志 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await opera_log_dao.delete(db, pk) return count @staticmethod async def delete_all() -> int: + """清空所有操作日志""" async with async_db_session.begin() as db: count = await opera_log_dao.delete_all(db) return count diff --git a/backend/app/admin/service/role_service.py b/backend/app/admin/service/role_service.py index 119bb0b4e..2f9a70a7b 100644 --- a/backend/app/admin/service/role_service.py +++ b/backend/app/admin/service/role_service.py @@ -22,8 +22,16 @@ class RoleService: + """角色服务类""" + @staticmethod async def get(*, pk: int) -> Role: + """ + 获取角色详情 + + :param pk: 角色 ID + :return: + """ async with async_db_session() as db: role = await role_dao.get_with_relation(db, pk) if not role: @@ -32,22 +40,42 @@ async def get(*, pk: int) -> Role: @staticmethod async def get_all() -> Sequence[Role]: + """获取所有角色""" async with async_db_session() as db: roles = await role_dao.get_all(db) return roles @staticmethod async def get_by_user(*, pk: int) -> Sequence[Role]: + """ + 获取用户的角色列表 + + :param pk: 用户 ID + :return: + """ async with async_db_session() as db: roles = await role_dao.get_by_user(db, user_id=pk) return roles @staticmethod - async def get_select(*, name: str = None, status: int = None) -> Select: + async def get_select(*, name: str | None = None, status: int | None = None) -> Select: + """ + 获取角色列表查询条件 + + :param name: 角色名称 + :param status: 状态 + :return: + """ return await role_dao.get_list(name=name, status=status) @staticmethod async def create(*, obj: CreateRoleParam) -> None: + """ + 创建角色 + + :param obj: 角色创建参数 + :return: + """ async with async_db_session.begin() as db: role = await role_dao.get_by_name(db, obj.name) if role: @@ -56,6 +84,13 @@ async def create(*, obj: CreateRoleParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateRoleParam) -> int: + """ + 更新角色 + + :param pk: 角色 ID + :param obj: 角色更新参数 + :return: + """ async with async_db_session.begin() as db: role = await role_dao.get(db, pk) if not role: @@ -69,6 +104,14 @@ async def update(*, pk: int, obj: UpdateRoleParam) -> int: @staticmethod async def update_role_menu(*, request: Request, pk: int, menu_ids: UpdateRoleMenuParam) -> int: + """ + 更新角色菜单 + + :param request: FastAPI 请求对象 + :param pk: 角色 ID + :param menu_ids: 菜单 ID 列表 + :return: + """ async with async_db_session.begin() as db: role = await role_dao.get(db, pk) if not role: @@ -84,6 +127,14 @@ async def update_role_menu(*, request: Request, pk: int, menu_ids: UpdateRoleMen @staticmethod async def update_role_rule(*, request: Request, pk: int, rule_ids: UpdateRoleRuleParam) -> int: + """ + 更新角色数据权限 + + :param request: FastAPI 请求对象 + :param pk: 角色 ID + :param rule_ids: 权限规则 ID 列表 + :return: + """ async with async_db_session.begin() as db: role = await role_dao.get(db, pk) if not role: @@ -99,6 +150,13 @@ async def update_role_rule(*, request: Request, pk: int, rule_ids: UpdateRoleRul @staticmethod async def delete(*, request: Request, pk: list[int]) -> int: + """ + 删除角色 + + :param request: FastAPI 请求对象 + :param pk: 角色 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await role_dao.delete(db, pk) await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{request.user.id}') diff --git a/backend/app/admin/service/user_service.py b/backend/app/admin/service/user_service.py index b9757e147..5d371b189 100644 --- a/backend/app/admin/service/user_service.py +++ b/backend/app/admin/service/user_service.py @@ -25,8 +25,16 @@ class UserService: + """用户服务类""" + @staticmethod async def register(*, obj: RegisterUserParam) -> None: + """ + 注册新用户 + + :param obj: 用户注册参数 + :return: + """ async with async_db_session.begin() as db: if not obj.password: raise errors.ForbiddenError(msg='密码为空') @@ -44,6 +52,13 @@ async def register(*, obj: RegisterUserParam) -> None: @staticmethod async def add(*, request: Request, obj: AddUserParam) -> None: + """ + 添加新用户 + + :param request: FastAPI 请求对象 + :param obj: 用户添加参数 + :return: + """ async with async_db_session.begin() as db: superuser_verify(request) username = await user_dao.get_by_username(db, obj.username) @@ -69,13 +84,20 @@ async def add(*, request: Request, obj: AddUserParam) -> None: @staticmethod async def pwd_reset(*, request: Request, obj: ResetPasswordParam) -> int: + """ + 重置用户密码 + + :param request: FastAPI 请求对象 + :param obj: 密码重置参数 + :return: + """ async with async_db_session.begin() as db: user = await user_dao.get(db, request.user.id) + if not user: + raise errors.NotFoundError(msg='用户不存在') if not password_verify(obj.old_password, user.password): raise errors.ForbiddenError(msg='原密码错误') - np1 = obj.new_password - np2 = obj.confirm_password - if np1 != np2: + if obj.new_password != obj.confirm_password: raise errors.ForbiddenError(msg='密码输入不一致') new_pwd = get_hash_password(obj.new_password, user.salt) count = await user_dao.reset_password(db, request.user.id, new_pwd) @@ -90,6 +112,12 @@ async def pwd_reset(*, request: Request, obj: ResetPasswordParam) -> int: @staticmethod async def get_userinfo(*, username: str) -> User: + """ + 获取用户信息 + + :param username: 用户名 + :return: + """ async with async_db_session() as db: user = await user_dao.get_with_relation(db, username=username) if not user: @@ -98,10 +126,17 @@ async def get_userinfo(*, username: str) -> User: @staticmethod async def update(*, request: Request, username: str, obj: UpdateUserParam) -> int: + """ + 更新用户信息 + + :param request: FastAPI 请求对象 + :param username: 用户名 + :param obj: 用户更新参数 + :return: + """ async with async_db_session.begin() as db: - if not request.user.is_superuser: - if request.user.username != username: - raise errors.ForbiddenError(msg='你只能修改自己的信息') + if not request.user.is_superuser and request.user.username != username: + raise errors.ForbiddenError(msg='你只能修改自己的信息') input_user = await user_dao.get_with_relation(db, username=username) if not input_user: raise errors.NotFoundError(msg='用户不存在') @@ -123,10 +158,17 @@ async def update(*, request: Request, username: str, obj: UpdateUserParam) -> in @staticmethod async def update_roles(*, request: Request, username: str, obj: UpdateUserRoleParam) -> None: + """ + 更新用户角色 + + :param request: FastAPI 请求对象 + :param username: 用户名 + :param obj: 角色更新参数 + :return: + """ async with async_db_session.begin() as db: - if not request.user.is_superuser: - if request.user.username != username: - raise errors.AuthorizationError + if not request.user.is_superuser and request.user.username != username: + raise errors.AuthorizationError input_user = await user_dao.get_with_relation(db, username=username) if not input_user: raise errors.NotFoundError(msg='用户不存在') @@ -139,10 +181,17 @@ async def update_roles(*, request: Request, username: str, obj: UpdateUserRolePa @staticmethod async def update_avatar(*, request: Request, username: str, avatar: AvatarParam) -> int: + """ + 更新用户头像 + + :param request: FastAPI 请求对象 + :param username: 用户名 + :param avatar: 头像参数 + :return: + """ async with async_db_session.begin() as db: - if not request.user.is_superuser: - if request.user.username != username: - raise errors.AuthorizationError + if not request.user.is_superuser and request.user.username != username: + raise errors.AuthorizationError input_user = await user_dao.get_by_username(db, username) if not input_user: raise errors.NotFoundError(msg='用户不存在') @@ -152,96 +201,138 @@ async def update_avatar(*, request: Request, username: str, avatar: AvatarParam) @staticmethod async def get_select(*, dept: int, username: str = None, phone: str = None, status: int = None) -> Select: + """ + 获取用户列表查询条件 + + :param dept: 部门 ID + :param username: 用户名 + :param phone: 手机号 + :param status: 状态 + :return: + """ return await user_dao.get_list(dept=dept, username=username, phone=phone, status=status) @staticmethod async def update_permission(*, request: Request, pk: int) -> int: + """ + 更新用户权限 + + :param request: FastAPI 请求对象 + :param pk: 用户 ID + :return: + """ async with async_db_session.begin() as db: superuser_verify(request) - if not await user_dao.get(db, pk): + user = await user_dao.get(db, pk) + if not user: raise errors.NotFoundError(msg='用户不存在') - else: - if pk == request.user.id: - raise errors.ForbiddenError(msg='非法操作') - super_status = await user_dao.get_super(db, pk) - count = await user_dao.set_super(db, pk, False if super_status else True) - await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{pk}') - return count + if pk == request.user.id: + raise errors.ForbiddenError(msg='非法操作') + super_status = await user_dao.get_super(db, pk) + count = await user_dao.set_super(db, pk, not super_status) + await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{pk}') + return count @staticmethod async def update_staff(*, request: Request, pk: int) -> int: + """ + 更新用户职员状态 + + :param request: FastAPI 请求对象 + :param pk: 用户 ID + :return: + """ async with async_db_session.begin() as db: superuser_verify(request) - if not await user_dao.get(db, pk): + user = await user_dao.get(db, pk) + if not user: raise errors.NotFoundError(msg='用户不存在') - else: - if pk == request.user.id: - raise errors.ForbiddenError(msg='非法操作') - staff_status = await user_dao.get_staff(db, pk) - count = await user_dao.set_staff(db, pk, False if staff_status else True) - await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{pk}') - return count + if pk == request.user.id: + raise errors.ForbiddenError(msg='非法操作') + staff_status = await user_dao.get_staff(db, pk) + count = await user_dao.set_staff(db, pk, not staff_status) + await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{pk}') + return count @staticmethod async def update_status(*, request: Request, pk: int) -> int: + """ + 更新用户状态 + + :param request: FastAPI 请求对象 + :param pk: 用户 ID + :return: + """ async with async_db_session.begin() as db: superuser_verify(request) - if not await user_dao.get(db, pk): + user = await user_dao.get(db, pk) + if not user: raise errors.NotFoundError(msg='用户不存在') - else: - if pk == request.user.id: - raise errors.ForbiddenError(msg='非法操作') - status = await user_dao.get_status(db, pk) - count = await user_dao.set_status(db, pk, False if status else True) - await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{pk}') - return count + if pk == request.user.id: + raise errors.ForbiddenError(msg='非法操作') + status = await user_dao.get_status(db, pk) + count = await user_dao.set_status(db, pk, 0 if status == 1 else 1) + await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{pk}') + return count @staticmethod async def update_multi_login(*, request: Request, pk: int) -> int: + """ + 更新用户多端登录状态 + + :param request: FastAPI 请求对象 + :param pk: 用户 ID + :return: + """ async with async_db_session.begin() as db: superuser_verify(request) - if not await user_dao.get(db, pk): + user = await user_dao.get(db, pk) + if not user: raise errors.NotFoundError(msg='用户不存在') + user_id = request.user.id + multi_login = await user_dao.get_multi_login(db, pk) if pk != user_id else request.user.is_multi_login + count = await user_dao.set_multi_login(db, pk, not multi_login) + # 删除当前用户缓存 + await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{request.user.id}') + token = get_token(request) + token_payload = jwt_decode(token) + latest_multi_login = await user_dao.get_multi_login(db, pk) + # 超级用户修改自身时,除当前 token 外,其他 token 失效 + if pk == user_id: + if not latest_multi_login: + key_prefix = f'{settings.TOKEN_REDIS_PREFIX}:{pk}' + await redis_client.delete_prefix(key_prefix, exclude=f'{key_prefix}:{token_payload.session_uuid}') + refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) + if refresh_token: + key_prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{pk}' + await redis_client.delete_prefix(key_prefix, exclude=f'{key_prefix}:{refresh_token}') + # 超级用户修改他人时,其他 token 将全部失效 else: - user_id = request.user.id - multi_login = await user_dao.get_multi_login(db, pk) if pk != user_id else request.user.is_multi_login - count = await user_dao.set_multi_login(db, pk, False if multi_login else True) - await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{request.user.id}') - token = get_token(request) - token_payload = jwt_decode(token) - latest_multi_login = await user_dao.get_multi_login(db, pk) - # 超级用户修改自身时,除当前token外,其他token失效 - if pk == user_id: - if not latest_multi_login: - key_prefix = f'{settings.TOKEN_REDIS_PREFIX}:{pk}' - await redis_client.delete_prefix( - key_prefix, exclude=f'{key_prefix}:{token_payload.session_uuid}' - ) - refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) - if refresh_token: - key_prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{pk}' - await redis_client.delete_prefix(key_prefix, exclude=f'{key_prefix}:{refresh_token}') - # 超级用户修改他人时,其他token将全部失效 - else: - if not latest_multi_login: - key_prefix = [f'{settings.TOKEN_REDIS_PREFIX}:{pk}'] - refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) - if refresh_token: - key_prefix.append(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{pk}') - for prefix in key_prefix: - await redis_client.delete_prefix(prefix) - return count + if not latest_multi_login: + key_prefix = [f'{settings.TOKEN_REDIS_PREFIX}:{pk}'] + refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY) + if refresh_token: + key_prefix.append(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{pk}') + for prefix in key_prefix: + await redis_client.delete_prefix(prefix) + return count @staticmethod async def delete(*, username: str) -> int: + """ + 删除用户 + + :param username: 用户名 + :return: + """ async with async_db_session.begin() as db: - input_user = await user_dao.get_by_username(db, username) - if not input_user: + user = await user_dao.get_by_username(db, username) + if not user: raise errors.NotFoundError(msg='用户不存在') - count = await user_dao.delete(db, input_user.id) + count = await user_dao.delete(db, user.id) key_prefix = [ - f'{settings.TOKEN_REDIS_PREFIX}:{input_user.id}', - f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{input_user.id}', + f'{settings.TOKEN_REDIS_PREFIX}:{user.id}', + f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user.id}', ] for key in key_prefix: await redis_client.delete_prefix(key) diff --git a/backend/app/admin/tests/utils/db.py b/backend/app/admin/tests/utils/db.py index 203d3a54d..9c03c0171 100644 --- a/backend/app/admin/tests/utils/db.py +++ b/backend/app/admin/tests/utils/db.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import AsyncGenerator + +from sqlalchemy.ext.asyncio.session import AsyncSession from backend.database.db import create_async_engine_and_session, create_database_url @@ -8,7 +11,7 @@ _, async_test_db_session = create_async_engine_and_session(TEST_SQLALCHEMY_DATABASE_URL) -async def override_get_db(): +async def override_get_db() -> AsyncGenerator[AsyncSession, None]: """session 生成器""" async with async_test_db_session() as session: yield session diff --git a/backend/app/generator/api/v1/gen.py b/backend/app/generator/api/v1/gen.py index b9270af8e..3713db314 100644 --- a/backend/app/generator/api/v1/gen.py +++ b/backend/app/generator/api/v1/gen.py @@ -18,7 +18,7 @@ @router.get('/tables', summary='获取数据库表') async def get_all_tables( - table_schema: Annotated[str, Query(..., description='数据库名')] = 'fba', + table_schema: Annotated[str, Query(description='数据库名')] = 'fba', ) -> ResponseSchemaModel[list[str]]: data = await gen_service.get_tables(table_schema=table_schema) return response_base.success(data=data) @@ -38,13 +38,13 @@ async def import_table(obj: ImportParam) -> ResponseModel: @router.get('/preview/{pk}', summary='生成代码预览', dependencies=[DependsJwtAuth]) -async def preview_code(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseSchemaModel[dict[str, bytes]]: +async def preview_code(pk: Annotated[int, Path(description='业务 ID')]) -> ResponseSchemaModel[dict[str, bytes]]: data = await gen_service.preview(pk=pk) return response_base.success(data=data) @router.get('/generate/{pk}/path', summary='获取代码生成路径', dependencies=[DependsJwtAuth]) -async def generate_path(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseSchemaModel[list[str]]: +async def generate_path(pk: Annotated[int, Path(description='业务 ID')]) -> ResponseSchemaModel[list[str]]: data = await gen_service.get_generate_path(pk=pk) return response_base.success(data=data) @@ -58,13 +58,13 @@ async def generate_path(pk: Annotated[int, Path(..., description='业务ID')]) - DependsRBAC, ], ) -async def generate_code(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseModel: +async def generate_code(pk: Annotated[int, Path(description='业务 ID')]) -> ResponseModel: await gen_service.generate(pk=pk) return response_base.success() @router.get('/download/{pk}', summary='下载代码', dependencies=[DependsJwtAuth]) -async def download_code(pk: Annotated[int, Path(..., description='业务ID')]): +async def download_code(pk: Annotated[int, Path(description='业务 ID')]): bio = await gen_service.download(pk=pk) return StreamingResponse( bio, diff --git a/backend/app/generator/api/v1/gen_business.py b/backend/app/generator/api/v1/gen_business.py index 30b8024e9..33975a236 100644 --- a/backend/app/generator/api/v1/gen_business.py +++ b/backend/app/generator/api/v1/gen_business.py @@ -27,13 +27,17 @@ async def get_all_businesses() -> ResponseSchemaModel[list[GetGenBusinessDetail] @router.get('/{pk}', summary='获取代码生成业务详情', dependencies=[DependsJwtAuth]) -async def get_business(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetGenBusinessDetail]: +async def get_business( + pk: Annotated[int, Path(description='业务 ID')], +) -> ResponseSchemaModel[GetGenBusinessDetail]: data = await gen_business_service.get(pk=pk) return response_base.success(data=data) @router.get('/{pk}/models', summary='获取代码生成业务所有模型', dependencies=[DependsJwtAuth]) -async def get_business_all_models(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[list[GetGenModelDetail]]: +async def get_business_all_models( + pk: Annotated[int, Path(description='业务 ID')], +) -> ResponseSchemaModel[list[GetGenModelDetail]]: data = await gen_model_service.get_by_business(business_id=pk) return response_base.success(data=data) @@ -60,7 +64,9 @@ async def create_business(obj: CreateGenBusinessParam) -> ResponseModel: DependsRBAC, ], ) -async def update_business(pk: Annotated[int, Path(...)], obj: UpdateGenBusinessParam) -> ResponseModel: +async def update_business( + pk: Annotated[int, Path(description='业务 ID')], obj: UpdateGenBusinessParam +) -> ResponseModel: count = await gen_business_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -75,7 +81,7 @@ async def update_business(pk: Annotated[int, Path(...)], obj: UpdateGenBusinessP DependsRBAC, ], ) -async def delete_business(pk: Annotated[int, Path(...)]) -> ResponseModel: +async def delete_business(pk: Annotated[int, Path(description='业务 ID')]) -> ResponseModel: count = await gen_business_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/generator/api/v1/gen_model.py b/backend/app/generator/api/v1/gen_model.py index 95c29ad35..2b1b97509 100644 --- a/backend/app/generator/api/v1/gen_model.py +++ b/backend/app/generator/api/v1/gen_model.py @@ -21,7 +21,7 @@ async def get_model_types() -> ResponseSchemaModel[list[str]]: @router.get('/{pk}', summary='获取代码生成模型详情', dependencies=[DependsJwtAuth]) -async def get_model(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetGenModelDetail]: +async def get_model(pk: Annotated[int, Path(description='模型 ID')]) -> ResponseSchemaModel[GetGenModelDetail]: data = await gen_model_service.get(pk=pk) return response_base.success(data=data) @@ -47,7 +47,7 @@ async def create_model(obj: CreateGenModelParam) -> ResponseModel: DependsRBAC, ], ) -async def update_model(pk: Annotated[int, Path(...)], obj: UpdateGenModelParam) -> ResponseModel: +async def update_model(pk: Annotated[int, Path(description='模型 ID')], obj: UpdateGenModelParam) -> ResponseModel: count = await gen_model_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -62,7 +62,7 @@ async def update_model(pk: Annotated[int, Path(...)], obj: UpdateGenModelParam) DependsRBAC, ], ) -async def delete_model(pk: Annotated[int, Path(...)]) -> ResponseModel: +async def delete_model(pk: Annotated[int, Path(description='模型 ID')]) -> ResponseModel: count = await gen_model_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/app/generator/conf.py b/backend/app/generator/conf.py index 954493e4e..4ba3517c0 100644 --- a/backend/app/generator/conf.py +++ b/backend/app/generator/conf.py @@ -6,9 +6,9 @@ class GeneratorSettings(BaseSettings): - """Admin Settings""" + """代码生成配置""" - # 模版目录 + # 模版 TEMPLATE_BACKEND_DIR_NAME: str = 'py' # 代码下载 @@ -17,7 +17,7 @@ class GeneratorSettings(BaseSettings): @lru_cache def get_generator_settings() -> GeneratorSettings: - """获取 generator 配置""" + """获取代码生成配置""" return GeneratorSettings() diff --git a/backend/app/generator/crud/crud_gen.py b/backend/app/generator/crud/crud_gen.py index 5fe1ae9b6..d789da267 100644 --- a/backend/app/generator/crud/crud_gen.py +++ b/backend/app/generator/crud/crud_gen.py @@ -9,8 +9,17 @@ class CRUDGen: + """代码生成 CRUD 类""" + @staticmethod - async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]: + async def get_all_tables(db: AsyncSession, table_schema: str) -> list[str]: + """ + 获取所有表名 + + :param db: 数据库会话 + :param table_schema: 数据库 schema 名称 + :return: + """ if settings.DATABASE_TYPE == 'mysql': sql = """ SELECT table_name AS table_name FROM information_schema.tables @@ -30,6 +39,13 @@ async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]: @staticmethod async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]: + """ + 获取表信息 + + :param db: 数据库会话 + :param table_name: 表名 + :return: + """ if settings.DATABASE_TYPE == 'mysql': sql = """ SELECT table_name AS table_name, table_comment AS table_comment FROM information_schema.tables @@ -51,6 +67,14 @@ async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]: @staticmethod async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[Row[tuple]]: + """ + 获取所有列信息 + + :param db: 数据库会话 + :param table_schema: 数据库 schema 名称 + :param table_name: 表名 + :return: + """ if settings.DATABASE_TYPE == 'mysql': sql = """ SELECT column_name AS column_name, diff --git a/backend/app/generator/crud/crud_gen_business.py b/backend/app/generator/crud/crud_gen_business.py index a2e95c6b5..80ea8bc98 100644 --- a/backend/app/generator/crud/crud_gen_business.py +++ b/backend/app/generator/crud/crud_gen_business.py @@ -10,61 +10,64 @@ class CRUDGenBusiness(CRUDPlus[GenBusiness]): + """代码生成业务 CRUD 类""" + async def get(self, db: AsyncSession, pk: int) -> GenBusiness | None: """ - 获取代码生成业务表 + 获取代码生成业务 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 代码生成业务 ID :return: """ return await self.select_model(db, pk) async def get_by_name(self, db: AsyncSession, name: str) -> GenBusiness | None: """ - 通过 name 获取代码生成业务表 + 通过 name 获取代码生成业务 - :param db: - :param name: + :param db: 数据库会话 + :param name: 表名 :return: """ return await self.select_model_by_column(db, table_name_en=name) async def get_all(self, db: AsyncSession) -> Sequence[GenBusiness]: """ - 获取所有代码生成业务表 + 获取所有代码生成业务 + :param db: 数据库会话 :return: """ return await self.select_models(db) - async def create(self, db: AsyncSession, obj_in: CreateGenBusinessParam) -> None: + async def create(self, db: AsyncSession, obj: CreateGenBusinessParam) -> None: """ - 创建代码生成业务表 + 创建代码生成业务 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建代码生成业务参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateGenBusinessParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateGenBusinessParam) -> int: """ - 更新代码生成业务表 + 更新代码生成业务 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: 代码生成业务 ID + :param obj: 更新代码生成业务参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: int) -> int: """ - 删除代码生成业务表 + 删除代码生成业务 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 代码生成业务 ID :return: """ return await self.delete_model(db, pk) diff --git a/backend/app/generator/crud/crud_gen_model.py b/backend/app/generator/crud/crud_gen_model.py index b681268cd..419c16d36 100644 --- a/backend/app/generator/crud/crud_gen_model.py +++ b/backend/app/generator/crud/crud_gen_model.py @@ -10,53 +10,57 @@ class CRUDGenModel(CRUDPlus[GenModel]): + """代码生成模型 CRUD 类""" + async def get(self, db: AsyncSession, pk: int) -> GenModel | None: """ 获取代码生成模型列 + :param db: 数据库会话 + :param pk: 代码生成模型 ID :return: """ return await self.select_model(db, pk) - async def get_all_by_business_id(self, db: AsyncSession, business_id: int) -> Sequence[GenModel]: + async def get_all_by_business(self, db: AsyncSession, business_id: int) -> Sequence[GenModel]: """ 获取所有代码生成模型列 - :param db: - :param business_id: + :param db: 数据库会话 + :param business_id: 业务 ID :return: """ return await self.select_models_order(db, sort_columns='sort', gen_business_id=business_id) - async def create(self, db: AsyncSession, obj_in: CreateGenModelParam, pd_type: str | None = None) -> None: + async def create(self, db: AsyncSession, obj: CreateGenModelParam, pd_type: str | None = None) -> None: """ - 创建代码生成模型表 + 创建代码生成模型 - :param db: - :param obj_in: - :param pd_type: + :param db: 数据库会话 + :param obj: 创建代码生成模型参数 + :param pd_type: Pydantic 类型 :return: """ - await self.create_model(db, obj_in, pd_type=pd_type) + await self.create_model(db, obj, pd_type=pd_type) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateGenModelParam, pd_type: str | None = None) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateGenModelParam, pd_type: str | None = None) -> int: """ - 更细代码生成模型表 + 更新代码生成模型 - :param db: - :param pk: - :param obj_in: - :param pd_type: + :param db: 数据库会话 + :param pk: 代码生成模型 ID + :param obj: 更新代码生成模型参数 + :param pd_type: Pydantic 类型 :return: """ - return await self.update_model(db, pk, obj_in, pd_type=pd_type) + return await self.update_model(db, pk, obj, pd_type=pd_type) async def delete(self, db: AsyncSession, pk: int) -> int: """ - 删除代码生成模型表 + 删除代码生成模型 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 代码生成模型 ID :return: """ return await self.delete_model(db, pk) diff --git a/backend/app/generator/model/gen_business.py b/backend/app/generator/model/gen_business.py index 15a9dfa5b..d0b47e77d 100644 --- a/backend/app/generator/model/gen_business.py +++ b/backend/app/generator/model/gen_business.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import TYPE_CHECKING + from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.dialects.postgresql import TEXT @@ -7,6 +9,9 @@ from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.generator.model import GenModel + class GenBusiness(Base): """代码生成业务表""" @@ -28,4 +33,4 @@ class GenBusiness(Base): LONGTEXT().with_variant(TEXT, 'postgresql'), default=None, comment='备注' ) # 代码生成业务模型一对多 - gen_model: Mapped[list['GenModel']] = relationship(init=False, back_populates='gen_business') # noqa: F821 + gen_model: Mapped[list['GenModel']] = relationship(init=False, back_populates='gen_business') diff --git a/backend/app/generator/model/gen_model.py b/backend/app/generator/model/gen_model.py index 74905f096..5433469db 100644 --- a/backend/app/generator/model/gen_model.py +++ b/backend/app/generator/model/gen_model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Union +from typing import TYPE_CHECKING, Union from sqlalchemy import ForeignKey, String from sqlalchemy.dialects.mysql import LONGTEXT @@ -9,6 +9,9 @@ from backend.common.model import DataClassBase, id_key +if TYPE_CHECKING: + from backend.app.generator.model import GenBusiness + class GenModel(DataClassBase): """代码生成模型表""" @@ -32,4 +35,4 @@ class GenModel(DataClassBase): gen_business_id: Mapped[int] = mapped_column( ForeignKey('sys_gen_business.id', ondelete='CASCADE'), default=0, comment='代码生成业务ID' ) - gen_business: Mapped[Union['GenBusiness', None]] = relationship(init=False, back_populates='gen_model') # noqa: F821 + gen_business: Mapped[Union['GenBusiness', None]] = relationship(init=False, back_populates='gen_model') diff --git a/backend/app/generator/schema/gen.py b/backend/app/generator/schema/gen.py index fce60d38a..85e18cb9a 100644 --- a/backend/app/generator/schema/gen.py +++ b/backend/app/generator/schema/gen.py @@ -6,6 +6,8 @@ class ImportParam(SchemaBase): + """导入参数""" + app: str = Field(description='应用名称,用于代码生成到指定 app') table_name: str = Field(description='数据库表名') table_schema: str = Field(description='数据库名') diff --git a/backend/app/generator/schema/gen_business.py b/backend/app/generator/schema/gen_business.py index 982c6cdd8..53cc72472 100644 --- a/backend/app/generator/schema/gen_business.py +++ b/backend/app/generator/schema/gen_business.py @@ -9,35 +9,40 @@ class GenBusinessSchemaBase(SchemaBase): - app_name: str - table_name_en: str - table_name_zh: str - table_simple_name_zh: str - table_comment: str | None = None - schema_name: str | None = None - default_datetime_column: bool = Field(default=True) - api_version: str = Field(default='v1') - gen_path: str | None = None - remark: str | None = None + """代码生成业务基础模型""" + + app_name: str = Field(description='应用名称(英文)') + table_name_en: str = Field(description='表名称(英文)') + table_name_zh: str = Field(description='表名称(中文)') + table_simple_name_zh: str = Field(description='表名称(中文简称)') + table_comment: str | None = Field(None, description='表描述') + schema_name: str | None = Field(None, description='Schema 名称 (默认为英文表名称)') + default_datetime_column: bool = Field(True, description='是否存在默认时间列') + api_version: str = Field('v1', description='代码生成 api 版本') + gen_path: str | None = Field(None, description='代码生成路径(默认为 app 根路径)') + remark: str | None = Field(None, description='备注') @model_validator(mode='after') def check_schema_name(self) -> Self: + """检查并设置 schema 名称""" if self.schema_name is None: self.schema_name = self.table_name_en return self class CreateGenBusinessParam(GenBusinessSchemaBase): - pass + """创建代码生成业务参数""" class UpdateGenBusinessParam(GenBusinessSchemaBase): - pass + """更新代码生成业务参数""" class GetGenBusinessDetail(GenBusinessSchemaBase): + """获取代码生成业务详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='主键 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/app/generator/schema/gen_model.py b/backend/app/generator/schema/gen_model.py index 524f96fae..c2c3783c3 100644 --- a/backend/app/generator/schema/gen_model.py +++ b/backend/app/generator/schema/gen_model.py @@ -7,32 +7,37 @@ class GenModelSchemaBase(SchemaBase): - name: str - comment: str | None = None - type: str - default: str | None = None - sort: int - length: int - is_pk: bool = Field(default=False) - is_nullable: bool = Field(default=False) - gen_business_id: int | None = Field(ge=1) + """代码生成模型基础模型""" + + name: str = Field(description='列名称') + comment: str | None = Field(None, description='列描述') + type: str = Field(description='SQLA 模型列类型') + default: str | None = Field(None, description='列默认值') + sort: int = Field(description='列排序') + length: int = Field(description='列长度') + is_pk: bool = Field(False, description='是否主键') + is_nullable: bool = Field(False, description='是否可为空') + gen_business_id: int = Field(description='代码生成业务ID') @field_validator('type') @classmethod - def type_update(cls, v): + def type_update(cls, v: str) -> str: + """更新列类型""" return sql_type_to_sqlalchemy(v) class CreateGenModelParam(GenModelSchemaBase): - pass + """创建代码生成模型参数""" class UpdateGenModelParam(GenModelSchemaBase): - pass + """更新代码生成模型参数""" class GetGenModelDetail(GenModelSchemaBase): + """获取代码生成模型详情""" + model_config = ConfigDict(from_attributes=True) - id: int - pd_type: str + id: int = Field(description='主键 ID') + pd_type: str = Field(description='列类型对应的 pydantic 类型') diff --git a/backend/app/generator/service/gen_business_service.py b/backend/app/generator/service/gen_business_service.py index ec8ece114..c2062f2b1 100644 --- a/backend/app/generator/service/gen_business_service.py +++ b/backend/app/generator/service/gen_business_service.py @@ -10,8 +10,16 @@ class GenBusinessService: + """代码生成业务服务类""" + @staticmethod async def get(*, pk: int) -> GenBusiness: + """ + 获取指定 ID 的业务 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: @@ -20,12 +28,18 @@ async def get(*, pk: int) -> GenBusiness: @staticmethod async def get_all() -> Sequence[GenBusiness]: + """获取所有业务""" async with async_db_session() as db: - businesses = await gen_business_dao.get_all(db) - return businesses + return await gen_business_dao.get_all(db) @staticmethod async def create(*, obj: CreateGenBusinessParam) -> None: + """ + 创建业务 + + :param obj: 创建业务参数 + :return: + """ async with async_db_session.begin() as db: business = await gen_business_dao.get_by_name(db, obj.table_name_en) if business: @@ -34,15 +48,26 @@ async def create(*, obj: CreateGenBusinessParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateGenBusinessParam) -> int: + """ + 更新业务 + + :param pk: 业务 ID + :param obj: 更新业务参数 + :return: + """ async with async_db_session.begin() as db: - count = await gen_business_dao.update(db, pk, obj) - return count + return await gen_business_dao.update(db, pk, obj) @staticmethod async def delete(*, pk: int) -> int: + """ + 删除业务 + + :param pk: 业务 ID + :return: + """ async with async_db_session.begin() as db: - count = await gen_business_dao.delete(db, pk) - return count + return await gen_business_dao.delete(db, pk) gen_business_service: GenBusinessService = GenBusinessService() diff --git a/backend/app/generator/service/gen_model_service.py b/backend/app/generator/service/gen_model_service.py index 3d284f3b3..70e1f7555 100644 --- a/backend/app/generator/service/gen_model_service.py +++ b/backend/app/generator/service/gen_model_service.py @@ -12,50 +12,85 @@ class GenModelService: + """代码生成模型服务类""" + @staticmethod async def get(*, pk: int) -> GenModel: + """ + 获取指定 ID 的模型 + + :param pk: 模型 ID + :return: + """ async with async_db_session() as db: - gen_model = await gen_model_dao.get(db, pk) - return gen_model + model = await gen_model_dao.get(db, pk) + if not model: + raise errors.NotFoundError(msg='代码生成模型不存在') + return model @staticmethod async def get_types() -> list[str]: + """获取所有 MySQL 列类型""" types = GenModelMySQLColumnType.get_member_keys() types.sort() return types @staticmethod async def get_by_business(*, business_id: int) -> Sequence[GenModel]: + """ + 获取指定业务的所有模型 + + :param business_id: 业务 ID + :return: + """ async with async_db_session() as db: - gen_models = await gen_model_dao.get_all_by_business_id(db, business_id) - return gen_models + return await gen_model_dao.get_all_by_business(db, business_id) @staticmethod async def create(*, obj: CreateGenModelParam) -> None: + """ + 创建模型 + + :param obj: 创建模型参数 + :return: + """ async with async_db_session.begin() as db: - gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id) + gen_models = await gen_model_dao.get_all_by_business(db, obj.gen_business_id) if obj.name in [gen_model.name for gen_model in gen_models]: raise errors.ForbiddenError(msg='禁止添加相同列到同一模型表') + pd_type = sql_type_to_pydantic(obj.type) await gen_model_dao.create(db, obj, pd_type=pd_type) @staticmethod async def update(*, pk: int, obj: UpdateGenModelParam) -> int: + """ + 更新模型 + + :param pk: 模型 ID + :param obj: 更新模型参数 + :return: + """ async with async_db_session.begin() as db: model = await gen_model_dao.get(db, pk) if obj.name != model.name: - gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id) + gen_models = await gen_model_dao.get_all_by_business(db, obj.gen_business_id) if obj.name in [gen_model.name for gen_model in gen_models]: raise errors.ForbiddenError(msg='模型列名已存在') + pd_type = sql_type_to_pydantic(obj.type) - count = await gen_model_dao.update(db, pk, obj, pd_type=pd_type) - return count + return await gen_model_dao.update(db, pk, obj, pd_type=pd_type) @staticmethod async def delete(*, pk: int) -> int: + """ + 删除模型 + + :param pk: 模型 ID + :return: + """ async with async_db_session.begin() as db: - count = await gen_model_dao.delete(db, pk) - return count + return await gen_model_dao.delete(db, pk) gen_model_service: GenModelService = GenModelService() diff --git a/backend/app/generator/service/gen_service.py b/backend/app/generator/service/gen_service.py index 2a712fd66..19a480bbb 100644 --- a/backend/app/generator/service/gen_service.py +++ b/backend/app/generator/service/gen_service.py @@ -5,7 +5,6 @@ import zipfile from pathlib import Path -from typing import Sequence import aiofiles @@ -20,27 +19,43 @@ from backend.app.generator.schema.gen_model import CreateGenModelParam from backend.app.generator.service.gen_model_service import gen_model_service from backend.common.exception import errors -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH from backend.database.db import async_db_session from backend.utils.gen_template import gen_template from backend.utils.type_conversion import sql_type_to_pydantic class GenService: + """代码生成服务类""" + @staticmethod - async def get_tables(*, table_schema: str) -> Sequence[str]: + async def get_tables(*, table_schema: str) -> list[str]: + """ + 获取指定 schema 下的所有表名 + + :param table_schema: 数据库 schema 名称 + :return: + """ async with async_db_session() as db: return await gen_dao.get_all_tables(db, table_schema) @staticmethod async def import_business_and_model(*, obj: ImportParam) -> None: + """ + 导入业务和模型数据 + + :param obj: 导入参数对象 + :return: + """ async with async_db_session.begin() as db: table_info = await gen_dao.get_table(db, obj.table_name) if not table_info: raise errors.NotFoundError(msg='数据库表不存在') + business_info = await gen_business_dao.get_by_name(db, obj.table_name) if business_info: raise errors.ForbiddenError(msg='已存在相同数据库表业务') + table_name = table_info[0] business_data = { 'app_name': obj.app, @@ -52,6 +67,7 @@ async def import_business_and_model(*, obj: ImportParam) -> None: new_business = GenBusiness(**CreateGenBusinessParam(**business_data).model_dump()) db.add(new_business) await db.flush() + column_info = await gen_dao.get_all_columns(db, obj.table_schema, table_name) for column in column_info: column_type = column[-1].split('(')[0].upper() @@ -70,20 +86,34 @@ async def import_business_and_model(*, obj: ImportParam) -> None: @staticmethod async def render_tpl_code(*, business: GenBusiness) -> dict[str, str]: + """ + 渲染模板代码 + + :param business: 业务对象 + :return: + """ gen_models = await gen_model_service.get_by_business(business_id=business.id) if not gen_models: raise errors.NotFoundError(msg='代码生成模型表为空') + gen_vars = gen_template.get_vars(business, gen_models) - tpl_code_map = {} - for tpl_path in gen_template.get_template_paths(): - tpl_code_map[tpl_path] = await gen_template.get_template(tpl_path).render_async(**gen_vars) - return tpl_code_map + return { + tpl_path: await gen_template.get_template(tpl_path).render_async(**gen_vars) + for tpl_path in gen_template.get_template_paths() + } async def preview(self, *, pk: int) -> dict[str, bytes]: + """ + 预览生成的代码 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') + tpl_code_map = await self.render_tpl_code(business=business) return { tpl.replace('.jinja', '.py') if tpl.startswith('py') else ...: code.encode('utf-8') @@ -92,42 +122,50 @@ async def preview(self, *, pk: int) -> dict[str, bytes]: @staticmethod async def get_generate_path(*, pk: int) -> list[str]: + """ + 获取代码生成路径 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') - gen_path = business.gen_path - if not gen_path: - # 伪加密路径 - gen_path = 'current-backend-app-path' + + gen_path = business.gen_path or 'fba-backend-app-path' target_files = gen_template.get_code_gen_paths(business) - code_gen_paths = [] - for target_file in target_files: - code_gen_paths.append(os.path.join(gen_path, *target_file.split('/')[1:])) - return code_gen_paths + return [os.path.join(gen_path, *target_file.split('/')[1:]) for target_file in target_files] async def generate(self, *, pk: int) -> None: + """ + 生成代码文件 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') + tpl_code_map = await self.render_tpl_code(business=business) - gen_path = business.gen_path - if not gen_path: - gen_path = os.path.join(BasePath, 'app') + gen_path = business.gen_path or os.path.join(BASE_PATH, 'app') + for tpl_path, code in tpl_code_map.items(): code_filepath = os.path.join( gen_path, *gen_template.get_code_gen_path(tpl_path, business).split('/')[1:], ) code_folder = Path(str(code_filepath)).parent - if not code_folder.exists(): - code_folder.mkdir(parents=True, exist_ok=True) + code_folder.mkdir(parents=True, exist_ok=True) + # 写入 init 文件 init_filepath = code_folder.joinpath('__init__.py') if not init_filepath.exists(): async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f: await f.write(gen_template.init_content) + if 'api' in str(code_folder): # api __init__.py api_init_filepath = code_folder.parent.joinpath('__init__.py') @@ -136,12 +174,14 @@ async def generate(self, *, pk: int) -> None: await f.write(gen_template.init_content) # app __init__.py app_init_filepath = api_init_filepath.parent.joinpath('__init__.py') - if not app_init_filepath: + if not app_init_filepath.exists(): async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f: await f.write(gen_template.init_content) - # 写入代码文件呢 + + # 写入代码文件 async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f: await f.write(code) + # model init 文件补充 if code_folder.name == 'model': async with aiofiles.open(init_filepath, 'a', encoding='utf-8') as f: @@ -151,33 +191,42 @@ async def generate(self, *, pk: int) -> None: ) async def download(self, *, pk: int) -> io.BytesIO: + """ + 下载生成的代码 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') + bio = io.BytesIO() - zf = zipfile.ZipFile(bio, 'w') - tpl_code_map = await self.render_tpl_code(business=business) - for tpl_path, code in tpl_code_map.items(): - # 写入代码文件 - new_code_path = gen_template.get_code_gen_path(tpl_path, business) - zf.writestr(new_code_path, code) - # 写入 init 文件 - init_filepath = os.path.join(*new_code_path.split('/')[:-1], '__init__.py') - if 'model' not in new_code_path.split('/'): - zf.writestr(init_filepath, gen_template.init_content) - else: - zf.writestr( - init_filepath, - f'{gen_template.init_content}' - f'from backend.app.{business.app_name}.model.{business.table_name_en} ' - f'import {to_pascal(business.table_name_en)}\n', - ) - if 'api' in new_code_path: - # api __init__.py - api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py') - zf.writestr(api_init_filepath, gen_template.init_content) - zf.close() + with zipfile.ZipFile(bio, 'w') as zf: + tpl_code_map = await self.render_tpl_code(business=business) + for tpl_path, code in tpl_code_map.items(): + # 写入代码文件 + new_code_path = gen_template.get_code_gen_path(tpl_path, business) + zf.writestr(new_code_path, code) + + # 写入 init 文件 + init_filepath = os.path.join(*new_code_path.split('/')[:-1], '__init__.py') + if 'model' not in new_code_path.split('/'): + zf.writestr(init_filepath, gen_template.init_content) + else: + zf.writestr( + init_filepath, + f'{gen_template.init_content}' + f'from backend.app.{business.app_name}.model.{business.table_name_en} ' + f'import {to_pascal(business.table_name_en)}\n', + ) + + if 'api' in new_code_path: + # api __init__.py + api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py') + zf.writestr(api_init_filepath, gen_template.init_content) + bio.seek(0) return bio diff --git a/backend/app/task/api/router.py b/backend/app/task/api/router.py index 18cb9ba66..1b5bbf6d4 100644 --- a/backend/app/task/api/router.py +++ b/backend/app/task/api/router.py @@ -5,6 +5,6 @@ from backend.app.task.api.v1.task import router as task_router from backend.core.conf import settings -v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH) +v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH, tags=['任务']) -v1.include_router(task_router, prefix='/tasks', tags=['任务']) +v1.include_router(task_router, prefix='/tasks') diff --git a/backend/app/task/api/v1/task.py b/backend/app/task/api/v1/task.py index de27ef7d9..5f66c5ec0 100644 --- a/backend/app/task/api/v1/task.py +++ b/backend/app/task/api/v1/task.py @@ -27,7 +27,7 @@ async def get_all_tasks() -> ResponseSchemaModel[list[str]]: description='此接口被视为作废,建议使用 flower 查看任务详情', dependencies=[DependsJwtAuth], ) -async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> ResponseSchemaModel[TaskResult]: +async def get_task_detail(tid: Annotated[str, Path(description='任务 UUID')]) -> ResponseSchemaModel[TaskResult]: status = task_service.get_detail(tid=tid) return response_base.success(data=status) @@ -40,7 +40,7 @@ async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> DependsRBAC, ], ) -async def revoke_task(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel: +async def revoke_task(tid: Annotated[str, Path(description='任务 UUID')]) -> ResponseModel: task_service.revoke(tid=tid) return response_base.success() diff --git a/backend/app/task/celery.py b/backend/app/task/celery.py index 8750d0d46..7ddcfd83a 100644 --- a/backend/app/task/celery.py +++ b/backend/app/task/celery.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any + import celery import celery_aio_pool @@ -9,61 +11,64 @@ __all__ = ['celery_app'] -def init_celery() -> celery.Celery: - """初始化 celery 应用""" - - # TODO: Update this work if celery version >= 6.0.0 - # https://github.com/fastapi-practices/fastapi_best_architecture/issues/321 - # https://github.com/celery/celery/issues/7874 - celery.app.trace.build_tracer = celery_aio_pool.build_async_tracer - celery.app.trace.reset_worker_optimizations() - - # Celery Schedule Tasks - # https://docs.celeryq.dev/en/stable/userguide/periodic-tasks.html - beat_schedule = task_settings.CELERY_SCHEDULE - - # Celery Config - # https://docs.celeryq.dev/en/stable/userguide/configuration.html - broker_url = ( - ( +def get_broker_url() -> str: + """获取消息代理 URL""" + if task_settings.CELERY_BROKER == 'redis': + return ( f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:' f'{settings.REDIS_PORT}/{task_settings.CELERY_BROKER_REDIS_DATABASE}' ) - if task_settings.CELERY_BROKER == 'redis' - else ( - f'amqp://{task_settings.RABBITMQ_USERNAME}:{task_settings.RABBITMQ_PASSWORD}@' - f'{task_settings.RABBITMQ_HOST}:{task_settings.RABBITMQ_PORT}' - ) + return ( + f'amqp://{task_settings.RABBITMQ_USERNAME}:{task_settings.RABBITMQ_PASSWORD}@' + f'{task_settings.RABBITMQ_HOST}:{task_settings.RABBITMQ_PORT}' ) - result_backend = ( + + +def get_result_backend() -> str: + """获取结果后端 URL""" + return ( f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:' f'{settings.REDIS_PORT}/{task_settings.CELERY_BACKEND_REDIS_DATABASE}' ) - result_backend_transport_options = { - 'global_keyprefix': f'{task_settings.CELERY_BACKEND_REDIS_PREFIX}', + + +def get_result_backend_transport_options() -> dict[str, Any]: + """获取结果后端传输选项""" + return { + 'global_keyprefix': task_settings.CELERY_BACKEND_REDIS_PREFIX, 'retry_policy': { 'timeout': task_settings.CELERY_BACKEND_REDIS_TIMEOUT, }, } + +def init_celery() -> celery.Celery: + """初始化 Celery 应用""" + + # TODO: Update this work if celery version >= 6.0.0 + # https://github.com/fastapi-practices/fastapi_best_architecture/issues/321 + # https://github.com/celery/celery/issues/7874 + celery.app.trace.build_tracer = celery_aio_pool.build_async_tracer + celery.app.trace.reset_worker_optimizations() + app = celery.Celery( 'fba_celery', enable_utc=False, timezone=settings.DATETIME_TIMEZONE, - beat_schedule=beat_schedule, - broker_url=broker_url, + beat_schedule=task_settings.CELERY_SCHEDULE, + broker_url=get_broker_url(), broker_connection_retry_on_startup=True, - result_backend=result_backend, - result_backend_transport_options=result_backend_transport_options, + result_backend=get_result_backend(), + result_backend_transport_options=get_result_backend_transport_options(), task_cls='app.task.celery_task.base:TaskBase', task_track_started=True, ) - # Load task modules + # 自动发现任务 app.autodiscover_tasks(task_settings.CELERY_TASK_PACKAGES) return app -# 创建 celery 实例 +# 创建 Celery 实例 celery_app: celery.Celery = init_celery() diff --git a/backend/app/task/celery_task/base.py b/backend/app/task/celery_task/base.py index fcca1e031..4e21f27cd 100644 --- a/backend/app/task/celery_task/base.py +++ b/backend/app/task/celery_task/base.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any from celery import Task from sqlalchemy.exc import SQLAlchemyError @@ -9,16 +10,37 @@ class TaskBase(Task): - """任务基类""" + """Celery 任务基类""" autoretry_for = (SQLAlchemyError,) max_retries = task_settings.CELERY_TASK_MAX_RETRIES - async def before_start(self, task_id, args, kwargs): + async def before_start(self, task_id: str, args, kwargs) -> None: + """ + 任务开始前执行钩子 + + :param task_id: 任务 ID + :return: + """ await task_notification(msg=f'任务 {task_id} 开始执行') - async def on_success(self, retval, task_id, args, kwargs): + async def on_success(self, retval: Any, task_id: str, args, kwargs) -> None: + """ + 任务成功后执行钩子 + + :param retval: 任务返回值 + :param task_id: 任务 ID + :return: + """ await task_notification(msg=f'任务 {task_id} 执行成功') - async def on_failure(self, exc, task_id, args, kwargs, einfo): + async def on_failure(self, exc: Exception, task_id: str, args, kwargs, einfo) -> None: + """ + 任务失败后执行钩子 + + :param exc: 异常对象 + :param task_id: 任务 ID + :param einfo: 异常信息 + :return: + """ await task_notification(msg=f'任务 {task_id} 执行失败') diff --git a/backend/app/task/celery_task/tasks.py b/backend/app/task/celery_task/tasks.py index 231fd0d67..08d79e295 100644 --- a/backend/app/task/celery_task/tasks.py +++ b/backend/app/task/celery_task/tasks.py @@ -7,5 +7,6 @@ @celery_app.task(name='task_demo_async') async def task_demo_async() -> str: + """异步示例任务,模拟耗时操作""" await sleep(20) return 'test async' diff --git a/backend/app/task/conf.py b/backend/app/task/conf.py index 73fb98ae0..38e9e1263 100644 --- a/backend/app/task/conf.py +++ b/backend/app/task/conf.py @@ -1,35 +1,35 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from functools import lru_cache -from typing import Literal +from typing import Any, Literal from celery.schedules import crontab from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH class TaskSettings(BaseSettings): - """Task Settings""" + """Celery 任务配置""" - model_config = SettingsConfigDict(env_file=f'{BasePath}/.env', env_file_encoding='utf-8', extra='ignore') + model_config = SettingsConfigDict(env_file=f'{BASE_PATH}/.env', env_file_encoding='utf-8', extra='ignore') - # Env Config + # .env 环境 ENVIRONMENT: Literal['dev', 'pro'] - # Env Celery - CELERY_BROKER_REDIS_DATABASE: int # 仅在 dev 模式时生效 + # .env Redis 配置 + CELERY_BROKER_REDIS_DATABASE: int CELERY_BACKEND_REDIS_DATABASE: int - # Env Rabbitmq + # .env RabbitMQ 配置 # docker run -d --hostname fba-mq --name fba-mq -p 5672:5672 -p 15672:15672 rabbitmq:latest RABBITMQ_HOST: str RABBITMQ_PORT: int RABBITMQ_USERNAME: str RABBITMQ_PASSWORD: str - # Celery + # Celery 基础配置 CELERY_BROKER: Literal['rabbitmq', 'redis'] = 'redis' CELERY_BACKEND_REDIS_PREFIX: str = 'fba:celery:' CELERY_BACKEND_REDIS_TIMEOUT: int = 5 @@ -38,7 +38,9 @@ class TaskSettings(BaseSettings): 'app.task.celery_task.db_log', ] CELERY_TASK_MAX_RETRIES: int = 5 - CELERY_SCHEDULE: dict = { + + # Celery 定时任务配置 + CELERY_SCHEDULE: dict[str, dict[str, Any]] = { 'exec-every-10-seconds': { 'task': 'task_demo_async', 'schedule': 10, @@ -55,7 +57,8 @@ class TaskSettings(BaseSettings): @model_validator(mode='before') @classmethod - def validate_celery_broker(cls, values): + def validate_celery_broker(cls, values: Any) -> Any: + """生产环境强制使用 RabbitMQ 作为消息代理""" if values['ENVIRONMENT'] == 'pro': values['CELERY_BROKER'] = 'rabbitmq' return values @@ -63,7 +66,7 @@ def validate_celery_broker(cls, values): @lru_cache def get_task_settings() -> TaskSettings: - """获取 task 配置""" + """获取 Celery 任务配置""" return TaskSettings() diff --git a/backend/app/task/schema/task.py b/backend/app/task/schema/task.py index 850e58606..fe7b37fed 100644 --- a/backend/app/task/schema/task.py +++ b/backend/app/task/schema/task.py @@ -1,23 +1,29 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any + from pydantic import Field from backend.common.schema import SchemaBase class RunParam(SchemaBase): + """任务运行参数""" + name: str = Field(description='任务名称') - args: list | None = Field(default=None, description='任务函数位置参数') - kwargs: dict | None = Field(default=None, description='任务函数关键字参数') + args: list[Any] | None = Field(None, description='任务函数位置参数') + kwargs: dict[str, Any] | None = Field(None, description='任务函数关键字参数') class TaskResult(SchemaBase): - result: str - traceback: str - status: str - name: str - args: list | None - kwargs: dict | None - worker: str - retries: int | None - queue: str | None + """任务执行结果""" + + result: str = Field(description='任务执行结果') + traceback: str = Field(description='错误堆栈信息') + status: str = Field(description='任务状态') + name: str = Field(description='任务名称') + args: list[Any] | None = Field(None, description='任务函数位置参数') + kwargs: dict[str, Any] | None = Field(None, description='任务函数关键字参数') + worker: str = Field(description='执行任务的 worker') + retries: int | None = Field(None, description='重试次数') + queue: str | None = Field(None, description='任务队列') diff --git a/backend/app/task/service/task_service.py b/backend/app/task/service/task_service.py index a9f740541..5aeebfc4f 100644 --- a/backend/app/task/service/task_service.py +++ b/backend/app/task/service/task_service.py @@ -13,14 +13,21 @@ class TaskService: @staticmethod async def get_list() -> list[str]: + """获取所有已注册的 Celery 任务列表""" registered_tasks = await run_in_threadpool(celery_app.control.inspect().registered) if not registered_tasks: - raise errors.ForbiddenError(msg='celery 服务未启动') + raise errors.ForbiddenError(msg='Celery 服务未启动') tasks = list(registered_tasks.values())[0] return tasks @staticmethod def get_detail(*, tid: str) -> TaskResult: + """ + 获取指定任务的详细信息 + + :param tid: 任务 UUID + :return: + """ try: result = AsyncResult(id=tid, app=celery_app) except NotRegistered: @@ -38,7 +45,13 @@ def get_detail(*, tid: str) -> TaskResult: ) @staticmethod - def revoke(*, tid: str): + def revoke(*, tid: str) -> None: + """ + 撤销指定的任务 + + :param tid: 任务 UUID + :return: + """ try: result = AsyncResult(id=tid, app=celery_app) except NotRegistered: @@ -47,6 +60,12 @@ def revoke(*, tid: str): @staticmethod def run(*, obj: RunParam) -> str: + """ + 运行指定的任务 + + :param obj: 任务运行参数 + :return: + """ task: AsyncResult = celery_app.send_task(name=obj.name, args=obj.args, kwargs=obj.kwargs) return task.task_id diff --git a/backend/common/enums.py b/backend/common/enums.py index 78e9d5d25..cdc5fe6a4 100644 --- a/backend/common/enums.py +++ b/backend/common/enums.py @@ -2,27 +2,38 @@ # -*- coding: utf-8 -*- from enum import Enum from enum import IntEnum as SourceIntEnum -from typing import Type +from typing import Any, Type, TypeVar + +T = TypeVar('T', bound=Enum) class _EnumBase: + """枚举基类,提供通用方法""" + @classmethod - def get_member_keys(cls: Type[Enum]) -> list[str]: + def get_member_keys(cls: Type[T]) -> list[str]: + """获取枚举成员名称列表""" return [name for name in cls.__members__.keys()] @classmethod - def get_member_values(cls: Type[Enum]) -> list: + def get_member_values(cls: Type[T]) -> list: + """获取枚举成员值列表""" return [item.value for item in cls.__members__.values()] + @classmethod + def get_member_dict(cls: Type[T]) -> dict[str, Any]: + """获取枚举成员字典""" + return {name: item.value for name, item in cls.__members__.items()} + class IntEnum(_EnumBase, SourceIntEnum): - """整型枚举""" + """整型枚举基类""" pass class StrEnum(_EnumBase, str, Enum): - """字符串枚举""" + """字符串枚举基类""" pass @@ -56,7 +67,7 @@ class RoleDataRuleExpressionType(IntEnum): class MethodType(StrEnum): - """请求方法""" + """HTTP 请求方法""" GET = 'GET' POST = 'POST' @@ -67,7 +78,7 @@ class MethodType(StrEnum): class LoginLogStatusType(IntEnum): - """登陆日志状态""" + """登录日志状态""" fail = 0 success = 1 @@ -100,7 +111,7 @@ class UserSocialType(StrEnum): """用户社交类型""" github = 'GitHub' - linuxdo = 'LinuxDo' + linux_do = 'LinuxDo' class FileType(StrEnum): @@ -176,7 +187,7 @@ class GenModelMySQLColumnType(StrEnum): class GenModelPostgreSQLColumnType(StrEnum): - """代码生成模型列类型(PostgreSQL),仅作为数据保留,并未实施""" + """代码生成模型列类型(PostgreSQL)""" # Python 类型映射 BIGINT = 'int' diff --git a/backend/common/exception/errors.py b/backend/common/exception/errors.py index a763a646a..ecd020d20 100644 --- a/backend/common/exception/errors.py +++ b/backend/common/exception/errors.py @@ -1,12 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -""" -全局业务异常类 - -业务代码执行异常时,可以使用 raise xxxError 触发内部错误,它尽可能实现带有后台任务的异常,但它不适用于**自定义响应状态码** -如果要求使用**自定义响应状态码**,则可以通过 return response_base.fail(res=CustomResponseCode.xxx) 直接返回 -""" # noqa: E501 - from typing import Any from fastapi import HTTPException @@ -16,6 +9,8 @@ class BaseExceptionMixin(Exception): + """基础异常混入类""" + code: int def __init__(self, *, msg: str = None, data: Any = None, background: BackgroundTask | None = None): @@ -26,17 +21,23 @@ def __init__(self, *, msg: str = None, data: Any = None, background: BackgroundT class HTTPError(HTTPException): + """HTTP 异常""" + def __init__(self, *, code: int, msg: Any = None, headers: dict[str, Any] | None = None): super().__init__(status_code=code, detail=msg, headers=headers) class CustomError(BaseExceptionMixin): + """自定义异常""" + def __init__(self, *, error: CustomErrorCode, data: Any = None, background: BackgroundTask | None = None): self.code = error.code super().__init__(msg=error.msg, data=data, background=background) class RequestError(BaseExceptionMixin): + """请求异常""" + code = StandardResponseCode.HTTP_400 def __init__(self, *, msg: str = 'Bad Request', data: Any = None, background: BackgroundTask | None = None): @@ -44,6 +45,8 @@ def __init__(self, *, msg: str = 'Bad Request', data: Any = None, background: Ba class ForbiddenError(BaseExceptionMixin): + """禁止访问异常""" + code = StandardResponseCode.HTTP_403 def __init__(self, *, msg: str = 'Forbidden', data: Any = None, background: BackgroundTask | None = None): @@ -51,6 +54,8 @@ def __init__(self, *, msg: str = 'Forbidden', data: Any = None, background: Back class NotFoundError(BaseExceptionMixin): + """资源不存在异常""" + code = StandardResponseCode.HTTP_404 def __init__(self, *, msg: str = 'Not Found', data: Any = None, background: BackgroundTask | None = None): @@ -58,6 +63,8 @@ def __init__(self, *, msg: str = 'Not Found', data: Any = None, background: Back class ServerError(BaseExceptionMixin): + """服务器异常""" + code = StandardResponseCode.HTTP_500 def __init__( @@ -67,6 +74,8 @@ def __init__( class GatewayError(BaseExceptionMixin): + """网关异常""" + code = StandardResponseCode.HTTP_502 def __init__(self, *, msg: str = 'Bad Gateway', data: Any = None, background: BackgroundTask | None = None): @@ -74,6 +83,8 @@ def __init__(self, *, msg: str = 'Bad Gateway', data: Any = None, background: Ba class AuthorizationError(BaseExceptionMixin): + """授权异常""" + code = StandardResponseCode.HTTP_401 def __init__(self, *, msg: str = 'Permission Denied', data: Any = None, background: BackgroundTask | None = None): @@ -81,6 +92,8 @@ def __init__(self, *, msg: str = 'Permission Denied', data: Any = None, backgrou class TokenError(HTTPError): + """Token 异常""" + code = StandardResponseCode.HTTP_401 def __init__(self, *, msg: str = 'Not Authenticated', headers: dict[str, Any] | None = None): diff --git a/backend/common/exception/exception_handler.py b/backend/common/exception/exception_handler.py index 24c89f46e..c213202f7 100644 --- a/backend/common/exception/exception_handler.py +++ b/backend/common/exception/exception_handler.py @@ -3,7 +3,6 @@ from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from pydantic import ValidationError -from pydantic.errors import PydanticUserError from starlette.exceptions import HTTPException from starlette.middleware.cors import CORSMiddleware from uvicorn.protocols.http.h11_impl import STATUS_PHRASES @@ -12,7 +11,6 @@ from backend.common.response.response_code import CustomResponseCode, StandardResponseCode from backend.common.response.response_schema import response_base from backend.common.schema import ( - CUSTOM_USAGE_ERROR_MESSAGES, CUSTOM_VALIDATION_ERROR_MESSAGES, ) from backend.core.conf import settings @@ -20,36 +18,34 @@ from backend.utils.trace_id import get_request_trace_id -def _get_exception_code(status_code: int): +def _get_exception_code(status_code: int) -> int: """ - 获取返回状态码, OpenAPI, Uvicorn... 可用状态码基于 RFC 定义, 详细代码见下方链接 + 获取返回状态码(可用状态码基于 RFC 定义) - `python 状态码标准支持 `__ + `python 状态码标准支持 `__ `IANA 状态码注册表 `__ - :param status_code: + :param status_code: HTTP 状态码 :return: """ try: STATUS_PHRASES[status_code] + return status_code except Exception: - code = StandardResponseCode.HTTP_400 - else: - code = status_code - return code + return StandardResponseCode.HTTP_400 -async def _validation_exception_handler(request: Request, e: RequestValidationError | ValidationError): +async def _validation_exception_handler(request: Request, exc: RequestValidationError | ValidationError): """ 数据验证异常处理 - :param e: + :param request: 请求对象 + :param exc: 验证异常 :return: """ errors = [] - for error in e.errors(): + for error in exc.errors(): custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type']) if custom_message: ctx = error.get('ctx') @@ -87,10 +83,10 @@ def register_exception(app: FastAPI): @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): """ - 全局HTTP异常处理 + 全局 HTTP 异常处理 - :param request: - :param exc: + :param request: FastAPI 请求对象 + :param exc: HTTP 异常 :return: """ if settings.ENVIRONMENT == 'dev': @@ -113,10 +109,10 @@ async def http_exception_handler(request: Request, exc: HTTPException): @app.exception_handler(RequestValidationError) async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError): """ - fastapi 数据验证异常处理 + FastAPI 数据验证异常处理 - :param request: - :param exc: + :param request: FastAPI 请求对象 + :param exc: 验证异常 :return: """ return await _validation_exception_handler(request, exc) @@ -124,42 +120,21 @@ async def fastapi_validation_exception_handler(request: Request, exc: RequestVal @app.exception_handler(ValidationError) async def pydantic_validation_exception_handler(request: Request, exc: ValidationError): """ - pydantic 数据验证异常处理 + Pydantic 数据验证异常处理 - :param request: - :param exc: + :param request: 请求对象 + :param exc: 验证异常 :return: """ return await _validation_exception_handler(request, exc) - @app.exception_handler(PydanticUserError) - async def pydantic_user_error_handler(request: Request, exc: PydanticUserError): - """ - Pydantic 用户异常处理 - - :param request: - :param exc: - :return: - """ - content = { - 'code': StandardResponseCode.HTTP_500, - 'msg': CUSTOM_USAGE_ERROR_MESSAGES.get(exc.code), - 'data': None, - } - request.state.__request_pydantic_user_error__ = content - content.update(trace_id=get_request_trace_id(request)) - return MsgSpecJSONResponse( - status_code=StandardResponseCode.HTTP_500, - content=content, - ) - @app.exception_handler(AssertionError) async def assertion_error_handler(request: Request, exc: AssertionError): """ 断言错误处理 - :param request: - :param exc: + :param request: FastAPI 请求对象 + :param exc: 断言错误 :return: """ if settings.ENVIRONMENT == 'dev': @@ -183,8 +158,8 @@ async def custom_exception_handler(request: Request, exc: BaseExceptionMixin): """ 全局自定义异常处理 - :param request: - :param exc: + :param request: FastAPI 请求对象 + :param exc: 自定义异常 :return: """ content = { @@ -205,8 +180,8 @@ async def all_unknown_exception_handler(request: Request, exc: Exception): """ 全局未知异常处理 - :param request: - :param exc: + :param request: FastAPI 请求对象 + :param exc: 未知异常 :return: """ if settings.ENVIRONMENT == 'dev': @@ -233,10 +208,11 @@ async def cors_custom_code_500_exception_handler(request, exc): 跨域自定义 500 异常处理 `Related issue `_ + `Solution `_ - :param request: - :param exc: + :param request: FastAPI 请求对象 + :param exc: 自定义异常 :return: """ if isinstance(exc, BaseExceptionMixin): diff --git a/backend/common/log.py b/backend/common/log.py index bd7283a2e..69ebeb0ed 100644 --- a/backend/common/log.py +++ b/backend/common/log.py @@ -14,18 +14,19 @@ class InterceptHandler(logging.Handler): """ - Default handler from examples in loguru documentation. - See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging + 日志拦截处理器,用于将标准库的日志重定向到 loguru + + 参考:https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging """ def emit(self, record: logging.LogRecord): - # Get corresponding Loguru level if it exists + # 获取对应的 Loguru 级别(如果存在) try: level = logger.level(record.levelname).name except ValueError: level = record.levelno - # Find caller from where originated the logged message. + # 查找记录日志消息的调用者 frame, depth = inspect.currentframe(), 0 while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__): frame = frame.f_back @@ -34,16 +35,19 @@ def emit(self, record: logging.LogRecord): logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) -def setup_logging(): +def setup_logging() -> None: """ - From https://github.com/benoitc/gunicorn/issues/1572#issuecomment-638391953 - https://github.com/pawamoy/pawamoy.github.io/issues/17 + 设置日志处理器 + + 参考: + - https://github.com/benoitc/gunicorn/issues/1572#issuecomment-638391953 + - https://github.com/pawamoy/pawamoy.github.io/issues/17 """ - # Set the logging handler and level + # 设置根日志处理器和级别 logging.root.handlers = [InterceptHandler()] logging.root.setLevel(settings.LOG_STD_LEVEL) - # Remove all log handlers and propagate to root logger + # 配置日志传播规则 for name in logging.root.manager.loggerDict.keys(): logging.getLogger(name).handlers = [] if 'uvicorn.access' in name or 'watchfiles.main' in name: @@ -54,17 +58,15 @@ def setup_logging(): # Debug log handlers # logging.debug(f'{logging.getLogger(name)}, {logging.getLogger(name).propagate}') - # Define the correlation_id default filter function + # 定义 correlation_id 默认过滤函数 # https://github.com/snok/asgi-correlation-id/issues/7 def correlation_id_filter(record): cid = correlation_id.get(settings.LOG_CID_DEFAULT_VALUE) record['correlation_id'] = cid[: settings.LOG_CID_UUID_LENGTH] return record - # Remove default loguru logger - logger.remove() - - # Set the loguru default handlers + # 配置 loguru 处理器 + logger.remove() # 移除默认处理器 logger.configure( handlers=[ { @@ -77,16 +79,17 @@ def correlation_id_filter(record): ) -def set_custom_logfile(): +def set_custom_logfile() -> None: + """设置自定义日志文件""" log_path = path_conf.LOG_DIR if not os.path.exists(log_path): os.mkdir(log_path) - # log files + # 日志文件 log_access_file = os.path.join(log_path, settings.LOG_ACCESS_FILENAME) log_error_file = os.path.join(log_path, settings.LOG_ERROR_FILENAME) - # set loguru logger default config + # 日志文件通用配置 # https://loguru.readthedocs.io/en/stable/api/logger.html#loguru._logger.Logger.add log_config = { 'format': settings.LOG_FILE_FORMAT, @@ -96,7 +99,7 @@ def set_custom_logfile(): 'compression': 'tar.gz', } - # stdout file + # 标准输出文件 logger.add( str(log_access_file), level=settings.LOG_ACCESS_FILE_LEVEL, @@ -106,7 +109,7 @@ def set_custom_logfile(): **log_config, ) - # stderr file + # 标准错误文件 logger.add( str(log_error_file), level=settings.LOG_ERROR_FILE_LEVEL, @@ -117,4 +120,5 @@ def set_custom_logfile(): ) +# 创建 logger 实例 log = logger diff --git a/backend/common/model.py b/backend/common/model.py index 57ca8c5b6..f006b506a 100644 --- a/backend/common/model.py +++ b/backend/common/model.py @@ -13,7 +13,7 @@ # MappedBase -> id: Mapped[id_key] # DataClassBase && Base -> id: Mapped[id_key] = mapped_column(init=False) id_key = Annotated[ - int, mapped_column(primary_key=True, index=True, autoincrement=True, sort_order=-999, comment='主键id') + int, mapped_column(primary_key=True, index=True, autoincrement=True, sort_order=-999, comment='主键 ID') ] @@ -38,31 +38,39 @@ class DateTimeMixin(MappedAsDataclass): class MappedBase(AsyncAttrs, DeclarativeBase): """ - 生命式基类, 作为所有基类或数据模型类的父类而存在 + 声明式基类, 作为所有基类或数据模型类的父类而存在 `AsyncAttrs `__ + `DeclarativeBase `__ + `mapped_column() `__ """ @declared_attr.directive def __tablename__(cls) -> str: + """生成表名""" return cls.__name__.lower() + @declared_attr.directive + def __table_args__(cls) -> dict: + """表配置""" + return {'comment': cls.__doc__ or ''} + class DataClassBase(MappedAsDataclass, MappedBase): """ - 声明性数据类基类, 它将带有数据类集成, 允许使用更高级配置, 但你必须注意它的一些特性, 尤其是和 DeclarativeBase 一起使用时 + 声明性数据类基类, 带有数据类集成, 允许使用更高级配置, 但你必须注意它的一些特性, 尤其是和 DeclarativeBase 一起使用时 `MappedAsDataclass `__ - """ # noqa: E501 + """ __abstract__ = True class Base(DataClassBase, DateTimeMixin): """ - 声明性 Mixin 数据类基类, 带有数据类集成, 并包含 MiXin 数据类基础表结构, 你可以简单的理解它为含有基础表结构的数据类基类 - """ # noqa: E501 + 声明性数据类基类, 带有数据类集成, 并包含 MiXin 数据类基础表结构 + """ __abstract__ = True diff --git a/backend/common/pagination.py b/backend/common/pagination.py index 823408068..359ec5927 100644 --- a/backend/common/pagination.py +++ b/backend/common/pagination.py @@ -3,7 +3,7 @@ from __future__ import annotations from math import ceil -from typing import TYPE_CHECKING, Generic, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar from fastapi import Depends, Query from fastapi_pagination import pagination_ctx @@ -21,8 +21,10 @@ class _CustomPageParams(BaseModel, AbstractParams): - page: int = Query(1, ge=1, description='Page number') - size: int = Query(20, gt=0, le=100, description='Page size') # 默认 20 条记录 + """自定义分页参数""" + + page: int = Query(1, ge=1, description='页码') + size: int = Query(20, gt=0, le=100, description='每页数量') def to_raw_params(self) -> RawParams: return RawParams( @@ -32,23 +34,29 @@ def to_raw_params(self) -> RawParams: class _Links(BaseModel): - first: str = Field(..., description='首页链接') - last: str = Field(..., description='尾页链接') - self: str = Field(..., description='当前页链接') + """分页链接""" + + first: str = Field(description='首页链接') + last: str = Field(description='尾页链接') + self: str = Field(description='当前页链接') next: str | None = Field(None, description='下一页链接') prev: str | None = Field(None, description='上一页链接') class _PageDetails(BaseModel): - items: list = Field([], description='当前页数据') - total: int = Field(..., description='总条数') - page: int = Field(..., description='当前页') - size: int = Field(..., description='每页数量') - total_pages: int = Field(..., description='总页数') - links: _Links + """分页详情""" + + items: list = Field([], description='当前页数据列表') + total: int = Field(description='数据总条数') + page: int = Field(description='当前页码') + size: int = Field(description='每页数量') + total_pages: int = Field(description='总页数') + links: _Links = Field(description='分页链接') class _CustomPage(_PageDetails, AbstractPage[T], Generic[T]): + """自定义分页类""" + __params_type__ = _CustomPageParams @classmethod @@ -60,19 +68,19 @@ def create( ) -> _CustomPage[T]: page = params.page size = params.size - total_pages = ceil(total / params.size) + total_pages = ceil(total / size) links = create_links( first={'page': 1, 'size': size}, - last={'page': f'{ceil(total / params.size)}', 'size': size} if total > 0 else {'page': 1, 'size': size}, - next={'page': f'{page + 1}', 'size': size} if (page + 1) <= total_pages else None, - prev={'page': f'{page - 1}', 'size': size} if (page - 1) >= 1 else None, + last={'page': total_pages, 'size': size} if total > 0 else {'page': 1, 'size': size}, + next={'page': page + 1, 'size': size} if (page + 1) <= total_pages else None, + prev={'page': page - 1, 'size': size} if (page - 1) >= 1 else None, ).model_dump() return cls( items=items, total=total, - page=params.page, - size=params.size, + page=page, + size=size, total_pages=total_pages, links=links, # type: ignore ) @@ -80,7 +88,7 @@ def create( class PageData(_PageDetails, Generic[SchemaT]): """ - 包含 data schema 的统一返回模型,适用于分页接口 + 包含返回数据 schema 的统一返回模型,仅适用于分页接口 E.g. :: @@ -103,12 +111,12 @@ def test() -> ResponseSchemaModel[PageData[GetApiDetail]]: items: Sequence[SchemaT] -async def paging_data(db: AsyncSession, select: Select) -> dict: +async def paging_data(db: AsyncSession, select: Select) -> dict[str, Any]: """ 基于 SQLAlchemy 创建分页数据 - :param db: - :param select: + :param db: 数据库会话 + :param select: SQL 查询语句 :return: """ paginated_data: _CustomPage = await paginate(db, select) diff --git a/backend/common/response/response_code.py b/backend/common/response/response_code.py index 5aba844f8..5ec754be1 100644 --- a/backend/common/response/response_code.py +++ b/backend/common/response/response_code.py @@ -9,17 +9,13 @@ class CustomCodeBase(Enum): """自定义状态码基类""" @property - def code(self): - """ - 获取状态码 - """ + def code(self) -> int: + """获取状态码""" return self.value[0] @property - def msg(self): - """ - 获取状态码信息 - """ + def msg(self) -> str: + """获取状态码信息""" return self.value[1] diff --git a/backend/common/response/response_schema.py b/backend/common/response/response_schema.py index a75a48eef..73c0e6594 100644 --- a/backend/common/response/response_schema.py +++ b/backend/common/response/response_schema.py @@ -3,7 +3,7 @@ from typing import Any, Generic, TypeVar from fastapi import Response -from pydantic import BaseModel +from pydantic import BaseModel, Field from backend.common.response.response_code import CustomResponse, CustomResponseCode from backend.utils.serializers import MsgSpecJSONResponse @@ -13,9 +13,9 @@ class ResponseModel(BaseModel): """ - 通用型统一返回模型,不包含 data schema + 不包含返回数据 schema 的通用型统一返回模型 - E.g. :: + 示例:: @router.get('/test', response_model=ResponseModel) def test(): @@ -33,16 +33,16 @@ def test() -> ResponseModel: return ResponseModel(code=res.code, msg=res.msg, data={'test': 'test'}) """ - code: int = CustomResponseCode.HTTP_200.code - msg: str = CustomResponseCode.HTTP_200.msg - data: Any | None = None + code: int = Field(CustomResponseCode.HTTP_200.code, description='返回状态码') + msg: str = Field(CustomResponseCode.HTTP_200.msg, description='返回信息') + data: Any | None = Field(None, description='返回数据') class ResponseSchemaModel(ResponseModel, Generic[SchemaT]): """ - 包含 data schema 的统一返回模型,适用于非分页接口 + 包含返回数据 schema 的通用型统一返回模型,仅适用于非分页接口 - E.g. :: + 示例:: @router.get('/test', response_model=ResponseSchemaModel[GetApiDetail]) def test(): @@ -85,6 +85,13 @@ def success( res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_200, data: Any | None = None, ) -> ResponseModel | ResponseSchemaModel: + """ + 成功响应 + + :param res: 返回信息 + :param data: 返回数据 + :return: + """ return self.__response(res=res, data=data) def fail( @@ -93,6 +100,13 @@ def fail( res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_400, data: Any = None, ) -> ResponseModel | ResponseSchemaModel: + """ + 失败响应 + + :param res: 返回信息 + :param data: 返回数据 + :return: + """ return self.__response(res=res, data=data) @staticmethod @@ -108,8 +122,8 @@ def fast_success( 使用此返回方法时,不能指定接口参数 response_model 和箭头返回类型 - :param res: - :param data: + :param res: 返回信息 + :param data: 返回数据 :return: """ return MsgSpecJSONResponse({'code': res.code, 'msg': res.msg, 'data': data}) diff --git a/backend/common/schema.py b/backend/common/schema.py index 09f9b61d5..87ec90b7a 100644 --- a/backend/common/schema.py +++ b/backend/common/schema.py @@ -7,7 +7,8 @@ from backend.core.conf import settings -# 自定义验证错误信息不包含验证预期内容(也就是输入内容),受支持的预期内容字段参考以下链接 +# 自定义验证错误信息 +# 不包含验证预期内容(也就是输入内容),受支持的预期内容字段参考以下链接 # https://github.com/pydantic/pydantic-core/blob/a5cb7382643415b716b1a7a5392914e50f726528/tests/test_errors.py#L266 # 替换预期内容字段方式,参考以下链接 # https://github.com/pydantic/pydantic/blob/caa78016433ec9b16a973f92f187a7b6bfde6cb5/docs/errors/errors.md?plain=1#L232 @@ -108,50 +109,20 @@ 'value_error': '值输入错误', } -CUSTOM_USAGE_ERROR_MESSAGES = { - 'class-not-fully-defined': '类属性类型未完全定义', - 'custom-json-schema': '__modify_schema__ 方法在V2中已被弃用', - 'decorator-missing-field': '定义了无效字段验证器', - 'discriminator-no-field': '鉴别器字段未全部定义', - 'discriminator-alias-type': '鉴别器字段使用非字符串类型定义', - 'discriminator-needs-literal': '鉴别器字段需要使用字面值定义', - 'discriminator-alias': '鉴别器字段别名定义不一致', - 'discriminator-validator': '鉴别器字段禁止定义字段验证器', - 'model-field-overridden': '无类型定义字段禁止重写', - 'model-field-missing-annotation': '缺少字段类型定义', - 'config-both': '重复定义配置项', - 'removed-kwargs': '调用已移除的关键字配置参数', - 'invalid-for-json-schema': '存在无效的 JSON 类型', - 'base-model-instantiated': '禁止实例化基础模型', - 'undefined-annotation': '缺少类型定义', - 'schema-for-unknown-type': '未知类型定义', - 'create-model-field-definitions': '字段定义错误', - 'create-model-config-base': '配置项定义错误', - 'validator-no-fields': '字段验证器未指定字段', - 'validator-invalid-fields': '字段验证器字段定义错误', - 'validator-instance-method': '字段验证器必须为类方法', - 'model-serializer-instance-method': '序列化器必须为实例方法', - 'validator-v1-signature': 'V1字段验证器错误已被弃用', - 'validator-signature': '字段验证器签名错误', - 'field-serializer-signature': '字段序列化器签名无法识别', - 'model-serializer-signature': '模型序列化器签名无法识别', - 'multiple-field-serializers': '字段序列化器重复定义', - 'invalid_annotated_type': '无效的类型定义', - 'type-adapter-config-unused': '类型适配器配置项定义错误', - 'root-model-extra': '根模型禁止定义额外字段', -} - - CustomPhoneNumber = Annotated[str, Field(pattern=r'^1[3-9]\d{9}$')] class CustomEmailStr(EmailStr): + """自定义邮箱类型""" + @classmethod def _validate(cls, __input_value: str) -> str: return None if __input_value == '' else validate_email(__input_value)[1] class SchemaBase(BaseModel): + """基础模型配置""" + model_config = ConfigDict( use_enum_values=True, json_encoders={datetime: lambda x: x.strftime(settings.DATETIME_FORMAT)}, diff --git a/backend/common/security/jwt.py b/backend/common/security/jwt.py index cca671129..17869808d 100644 --- a/backend/common/security/jwt.py +++ b/backend/common/security/jwt.py @@ -3,6 +3,7 @@ import json from datetime import timedelta +from typing import Any from uuid import uuid4 from fastapi import Depends, Request @@ -32,10 +33,10 @@ def get_hash_password(password: str, salt: bytes | None) -> str: """ - Encrypt passwords using the hash algorithm + 使用哈希算法加密密码 - :param password: - :param salt: + :param password: 密码 + :param salt: 盐值 :return: """ return password_hash.hash(password, salt=salt) @@ -43,33 +44,68 @@ def get_hash_password(password: str, salt: bytes | None) -> str: def password_verify(plain_password: str, hashed_password: str) -> bool: """ - Password verification + 密码验证 - :param plain_password: The password to verify - :param hashed_password: The hash ciphers to compare + :param plain_password: 待验证的密码 + :param hashed_password: 哈希密码 :return: """ return password_hash.verify(plain_password, hashed_password) -async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> AccessToken: +def jwt_encode(payload: dict[str, Any]) -> str: """ - Generate encryption token + 生成 JWT token - :param user_id: The user id of the JWT - :param multi_login: Multipoint login for user - :param kwargs: Token extra information + :param payload: 载荷 :return: """ - expire = timezone.now() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS) - session_uuid = str(uuid4()) - access_token = jwt.encode( - {'session_uuid': session_uuid, 'exp': expire, 'sub': user_id}, + return jwt.encode( + payload, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM, ) - if multi_login is False: + +def jwt_decode(token: str) -> TokenPayload: + """ + 解析 JWT token + + :param token: JWT token + :return: + """ + try: + payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM]) + session_uuid = payload.get('session_uuid') or 'debug' + user_id = payload.get('sub') + expire_time = payload.get('exp') + if not user_id: + raise TokenError(msg='Token 无效') + except ExpiredSignatureError: + raise TokenError(msg='Token 已过期') + except (JWTError, Exception): + raise TokenError(msg='Token 无效') + return TokenPayload(id=int(user_id), session_uuid=session_uuid, expire_time=expire_time) + + +async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> AccessToken: + """ + 生成加密 token + + :param user_id: 用户 ID + :param multi_login: 是否允许多端登录 + :param kwargs: token 额外信息 + :return: + """ + expire = timezone.now() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS) + session_uuid = str(uuid4()) + access_token = jwt_encode({ + 'session_uuid': session_uuid, + 'exp': expire, + 'sub': user_id, + }) + + if not multi_login: await redis_client.delete_prefix(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}') await redis_client.setex( @@ -91,20 +127,16 @@ async def create_access_token(user_id: str, multi_login: bool, **kwargs) -> Acce async def create_refresh_token(user_id: str, multi_login: bool) -> RefreshToken: """ - Generate encryption refresh token, only used to create a new token + 生成加密刷新 token,仅用于创建新的 token - :param user_id: The user id of the JWT - :param multi_login: multipoint login for user + :param user_id: 用户 ID + :param multi_login: 是否允许多端登录 :return: """ expire = timezone.now() + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS) - refresh_token = jwt.encode( - {'exp': expire, 'sub': user_id}, - settings.TOKEN_SECRET_KEY, - settings.TOKEN_ALGORITHM, - ) + refresh_token = jwt_encode({'exp': expire, 'sub': user_id}) - if multi_login is False: + if not multi_login: key_prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}' await redis_client.delete_prefix(key_prefix) @@ -118,12 +150,12 @@ async def create_refresh_token(user_id: str, multi_login: bool) -> RefreshToken: async def create_new_token(user_id: str, refresh_token: str, multi_login: bool, **kwargs) -> NewToken: """ - Generate new token + 生成新的 token - :param user_id: - :param refresh_token: - :param multi_login: - :param kwargs: Access token extra information + :param user_id: 用户 ID + :param refresh_token: 刷新 token + :param multi_login: 是否允许多端登录 + :param kwargs: token 附加信息 :return: """ redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{user_id}:{refresh_token}') @@ -137,46 +169,38 @@ async def create_new_token(user_id: str, refresh_token: str, multi_login: bool, ) -def get_token(request: Request) -> str: +async def revoke_token(user_id: str, session_uuid: str) -> None: """ - Get token for request header + 撤销 token + :param user_id: 用户 ID + :param session_uuid: 会话 ID :return: """ - authorization = request.headers.get('Authorization') - scheme, token = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() != 'bearer': - raise TokenError(msg='Token 无效') - return token + token_key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{session_uuid}' + await redis_client.delete(token_key) -def jwt_decode(token: str) -> TokenPayload: +def get_token(request: Request) -> str: """ - Decode token + 获取请求头中的 token - :param token: + :param request: FastAPI 请求对象 :return: """ - try: - payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM]) - session_uuid = payload.get('session_uuid') or 'debug' - user_id = payload.get('sub') - expire_time = payload.get('exp') - if not user_id: - raise TokenError(msg='Token 无效') - except ExpiredSignatureError: - raise TokenError(msg='Token 已过期') - except (JWTError, Exception): + authorization = request.headers.get('Authorization') + scheme, token = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != 'bearer': raise TokenError(msg='Token 无效') - return TokenPayload(id=int(user_id), session_uuid=session_uuid, expire_time=expire_time) + return token async def get_current_user(db: AsyncSession, pk: int) -> User: """ - Get the current user through token + 获取当前用户 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 用户 ID :return: """ from backend.app.admin.crud.crud_user import user_dao @@ -200,9 +224,9 @@ async def get_current_user(db: AsyncSession, pk: int) -> User: def superuser_verify(request: Request) -> bool: """ - Verify the current user permissions through token + 验证当前用户权限 - :param request: + :param request: FastAPI 请求对象 :return: """ superuser = request.user.is_superuser @@ -213,16 +237,20 @@ def superuser_verify(request: Request) -> bool: async def jwt_authentication(token: str) -> GetUserInfoWithRelationDetail: """ - JWT authentication + JWT 认证 - :param token: + :param token: JWT token :return: """ token_payload = jwt_decode(token) user_id = token_payload.id redis_token = await redis_client.get(f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token_payload.session_uuid}') - if not redis_token or token != redis_token: + if not redis_token: raise TokenError(msg='Token 已过期') + + if token != redis_token: + raise TokenError(msg='Token 已失效') + cache_user = await redis_client.get(f'{settings.JWT_USER_REDIS_PREFIX}:{user_id}') if not cache_user: async with async_db_session() as db: diff --git a/backend/common/security/permission.py b/backend/common/security/permission.py index 7a526a634..ec233255a 100644 --- a/backend/common/security/permission.py +++ b/backend/common/security/permission.py @@ -17,33 +17,48 @@ class RequestPermission: """ - 请求权限,仅用于角色菜单RBAC + 请求权限验证器,用于角色菜单 RBAC 权限控制 - Tip: + 注意: 使用此请求权限时,需要将 `Depends(RequestPermission('xxx'))` 在 `DependsRBAC` 之前设置, - 因为 fastapi 当前版本的接口依赖注入按正序执行,意味着 RBAC 标识会在验证前被设置 + 因为 FastAPI 当前版本的接口依赖注入按正序执行,意味着 RBAC 标识会在验证前被设置 """ - def __init__(self, value: str): + def __init__(self, value: str) -> None: + """ + 初始化请求权限验证器 + + :param value: 权限标识 + :return: + """ self.value = value - async def __call__(self, request: Request): + async def __call__(self, request: Request) -> None: + """ + 验证请求权限 + + :param request: FastAPI 请求对象 + :return: + """ if settings.RBAC_ROLE_MENU_MODE: if not isinstance(self.value, str): raise ServerError - # 附加权限标识 + # 附加权限标识到请求状态 request.state.permission = self.value def filter_data_permission(request: Request) -> ColumnElement[bool]: """ - 过滤数据权限 + 过滤数据权限,控制用户可见数据范围 - 使用场景:用户登录前台后,控制其能看到哪些数据 + 使用场景: + - 用户登录前台后,控制其能看到哪些数据 + - 根据用户角色和规则过滤数据访问权限 - :param request: + :param request: FastAPI 请求对象 :return: """ + # 获取用户角色和规则 data_rules = [] for role in request.user.roles: data_rules.extend(role.rules) @@ -57,10 +72,13 @@ def filter_data_permission(request: Request) -> ColumnElement[bool]: where_or_list = [] for rule in user_data_rules: + # 验证规则模型 rule_model = rule.model if rule_model not in settings.DATA_PERMISSION_MODELS: raise errors.NotFoundError(msg='数据规则模型不存在') model_ins = dynamic_import_data_model(settings.DATA_PERMISSION_MODELS[rule_model]) + + # 验证规则列 model_columns = [ key for key in model_ins.__table__.columns.keys() if key not in settings.DATA_PERMISSION_COLUMN_EXCLUDE ] @@ -68,11 +86,9 @@ def filter_data_permission(request: Request) -> ColumnElement[bool]: if column not in model_columns: raise errors.NotFoundError(msg='数据规则模型列不存在') - # 获取模型的列对象 + # 构建过滤条件 column_obj = getattr(model_ins, column) rule_expression = rule.expression - - # 根据表达式类型构建条件 condition = None if rule_expression == RoleDataRuleExpressionType.eq: condition = column_obj == rule.value @@ -93,14 +109,14 @@ def filter_data_permission(request: Request) -> ColumnElement[bool]: values = rule.value.split(',') if isinstance(rule.value, str) else rule.value condition = ~column_obj.in_(values) + # 根据运算符添加到对应列表 if condition is not None: - rule_operator = rule.operator - if rule_operator == RoleDataRuleOperatorType.AND: + if rule.operator == RoleDataRuleOperatorType.AND: where_and_list.append(condition) - elif rule_operator == RoleDataRuleOperatorType.OR: + elif rule.operator == RoleDataRuleOperatorType.OR: where_or_list.append(condition) - # 组合条件 + # 组合所有条件 where_list = [] if where_and_list: where_list.append(and_(*where_and_list)) diff --git a/backend/common/security/rbac.py b/backend/common/security/rbac.py index d34c73888..80cbda4b7 100644 --- a/backend/common/security/rbac.py +++ b/backend/common/security/rbac.py @@ -13,8 +13,8 @@ async def rbac_verify(request: Request, _token: str = DependsJwtAuth) -> None: """ RBAC 权限校验(鉴权顺序很重要,谨慎修改) - :param request: - :param _token: + :param request: FastAPI 请求对象 + :param _token: JWT 令牌 :return: """ path = request.url.path diff --git a/backend/common/socketio/actions.py b/backend/common/socketio/actions.py index 9007f74e8..8abc25621 100644 --- a/backend/common/socketio/actions.py +++ b/backend/common/socketio/actions.py @@ -7,7 +7,7 @@ async def task_notification(msg: str): """ 任务通知 - :param msg: + :param msg: 通知信息 :return: """ await sio.emit('task_notification', {'msg': msg}) diff --git a/backend/common/socketio/server.py b/backend/common/socketio/server.py index 5369cec97..6bc91f15e 100644 --- a/backend/common/socketio/server.py +++ b/backend/common/socketio/server.py @@ -8,8 +8,9 @@ from backend.core.conf import settings from backend.database.redis import redis_client +# 创建 Socket.IO 服务器实例 sio = socketio.AsyncServer( - # 此配置是为了集成 celery 实现消息订阅,如果你不使用 celery,可以直接删除此配置,不会造成任何影响 + # 集成 Celery 实现消息订阅 client_manager=socketio.AsyncRedisManager( f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:' f'{settings.REDIS_PORT}/{task_settings.CELERY_BROKER_REDIS_DATABASE}' @@ -30,15 +31,15 @@ @sio.event async def connect(sid, environ, auth): - """当客户端连接时触发""" + """处理 WebSocket 连接事件""" if not auth: - log.error('ws 连接失败:无授权') + log.error('WebSocket 连接失败:无授权') return False session_uuid = auth.get('session_uuid') token = auth.get('token') if not token or not session_uuid: - log.error('ws 连接失败:授权失败,请检查') + log.error('WebSocket 连接失败:授权失败,请检查') return False # 免授权直连 @@ -49,7 +50,7 @@ async def connect(sid, environ, auth): try: await jwt_authentication(token) except Exception as e: - log.info(f'ws 连接失败:{e}') + log.info(f'WebSocket 连接失败:{str(e)}') return False await redis_client.sadd(settings.TOKEN_ONLINE_REDIS_PREFIX, session_uuid) @@ -57,6 +58,6 @@ async def connect(sid, environ, auth): @sio.event -async def disconnect(sid): - """当客户端断开连接时触发""" +async def disconnect(sid: str) -> None: + """处理 WebSocket 断开连接事件""" await redis_client.spop(settings.TOKEN_ONLINE_REDIS_PREFIX) diff --git a/backend/core/conf.py b/backend/core/conf.py index 28b70e923..4de1e3a7d 100644 --- a/backend/core/conf.py +++ b/backend/core/conf.py @@ -6,36 +6,39 @@ from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH class Settings(BaseSettings): - """Global Settings""" + """全局配置""" - model_config = SettingsConfigDict(env_file=f'{BasePath}/.env', env_file_encoding='utf-8', extra='ignore') + model_config = SettingsConfigDict( + env_file=f'{BASE_PATH}/.env', + env_file_encoding='utf-8', + extra='ignore', + case_sensitive=True, + ) - # Env Config + # .env 环境 ENVIRONMENT: Literal['dev', 'pro'] - # Env Database Type + # .env 数据库 DATABASE_TYPE: Literal['mysql', 'postgresql'] - - # Env Database DATABASE_HOST: str DATABASE_PORT: int DATABASE_USER: str DATABASE_PASSWORD: str - # Env Redis + # .env Redis REDIS_HOST: str REDIS_PORT: int REDIS_PASSWORD: str REDIS_DATABASE: int - # Env Token + # .env Token TOKEN_SECRET_KEY: str # 密钥 secrets.token_urlsafe(32) - # Env Opera Log + # .env 操作日志加密密钥 OPERA_LOG_ENCRYPT_SECRET_KEY: str # 密钥 os.urandom(32), 需使用 bytes.hex() 方法转换为 str # FastAPI @@ -48,14 +51,7 @@ class Settings(BaseSettings): FASTAPI_OPENAPI_URL: str | None = '/openapi' FASTAPI_STATIC_FILES: bool = True - # Upload - UPLOAD_READ_SIZE: int = 1024 # 上传文件时分片读取大小 - UPLOAD_IMAGE_EXT_INCLUDE: list[str] = ['jpg', 'jpeg', 'png', 'gif', 'webp'] - UPLOAD_IMAGE_SIZE_MAX: int = 1024 * 1024 * 5 - UPLOAD_VIDEO_EXT_INCLUDE: list[str] = ['mp4', 'mov', 'avi', 'flv'] - UPLOAD_VIDEO_SIZE_MAX: int = 1024 * 1024 * 20 - - # Database + # 数据库 DATABASE_ECHO: bool = False DATABASE_POOL_ECHO: bool = False DATABASE_SCHEMA: str = 'fba' @@ -64,24 +60,21 @@ class Settings(BaseSettings): # Redis REDIS_TIMEOUT: int = 5 - # Socketio - WS_NO_AUTH_MARKER: str = 'internal' - # Token - TOKEN_ALGORITHM: str = 'HS256' # 算法 - TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒 - TOKEN_REFRESH_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # refresh token 过期时间,单位:秒 + TOKEN_ALGORITHM: str = 'HS256' + TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 # 1 天 + TOKEN_REFRESH_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 7 天 TOKEN_REDIS_PREFIX: str = 'fba:token' TOKEN_EXTRA_INFO_REDIS_PREFIX: str = 'fba:token_extra_info' TOKEN_ONLINE_REDIS_PREFIX: str = 'fba:token_online' TOKEN_REFRESH_REDIS_PREFIX: str = 'fba:refresh_token' - TOKEN_REQUEST_PATH_EXCLUDE: list[str] = [ # JWT / RBAC 白名单 + TOKEN_REQUEST_PATH_EXCLUDE: list[str] = [ # JWT / RBAC 路由白名单 f'{FASTAPI_API_V1_PATH}/auth/login', ] # JWT JWT_USER_REDIS_PREFIX: str = 'fba:user' - JWT_USER_REDIS_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 + JWT_USER_REDIS_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 7 天 # RBAC RBAC_ROLE_MENU_MODE: bool = False @@ -90,51 +83,52 @@ class Settings(BaseSettings): 'sys:monitor:server', ] - # Cookies + # Cookie COOKIE_REFRESH_TOKEN_KEY: str = 'fba_refresh_token' - COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS: int = TOKEN_REFRESH_EXPIRE_SECONDS - - # Log - LOG_CID_DEFAULT_VALUE: str = '-' - LOG_CID_UUID_LENGTH: int = 32 # must <= 32 - LOG_STD_LEVEL: str = 'INFO' - LOG_ACCESS_FILE_LEVEL: str = 'INFO' - LOG_ERROR_FILE_LEVEL: str = 'ERROR' - LOG_STD_FORMAT: str = ( - '{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | ' - ' {correlation_id} | {message}' - ) - LOG_FILE_FORMAT: str = ( - '{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | ' - ' {correlation_id} | {message}' - ) - LOG_ACCESS_FILENAME: str = 'fba_access.log' - LOG_ERROR_FILENAME: str = 'fba_error.log' + COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 7 天 - # Middleware - MIDDLEWARE_CORS: bool = True - MIDDLEWARE_ACCESS: bool = True + # 数据权限配置 + DATA_PERMISSION_MODELS: dict[str, str] = { # 允许进行数据过滤的 SQLA 模型,它必须以模块字符串的方式定义 + 'Api': 'backend.plugin.casbin.model.Api', + } + DATA_PERMISSION_COLUMN_EXCLUDE: list[str] = [ # 排除允许进行数据过滤的 SQLA 模型列 + 'id', + 'sort', + 'created_time', + 'updated_time', + ] - # Trace ID - TRACE_ID_REQUEST_HEADER_KEY: str = 'X-Request-ID' + # Socket.IO + WS_NO_AUTH_MARKER: str = 'internal' # CORS - CORS_ALLOWED_ORIGINS: list[str] = [ + CORS_ALLOWED_ORIGINS: list[str] = [ # 末尾不带斜杠 'http://127.0.0.1:8000', - 'http://localhost:5173', # 前端地址,末尾不要带 '/' + 'http://localhost:5173', ] CORS_EXPOSE_HEADERS: list[str] = [ - TRACE_ID_REQUEST_HEADER_KEY, + 'X-Request-ID', ] - # DateTime + # 中间件配置 + MIDDLEWARE_CORS: bool = True + MIDDLEWARE_ACCESS: bool = True + + # 请求限制配置 + REQUEST_LIMITER_REDIS_PREFIX: str = 'fba:limiter' + + # 时间配置 DATETIME_TIMEZONE: str = 'Asia/Shanghai' DATETIME_FORMAT: str = '%Y-%m-%d %H:%M:%S' - # Request limiter - REQUEST_LIMITER_REDIS_PREFIX: str = 'fba:limiter' + # 文件上传 + UPLOAD_READ_SIZE: int = 1024 + UPLOAD_IMAGE_EXT_INCLUDE: list[str] = ['jpg', 'jpeg', 'png', 'gif', 'webp'] + UPLOAD_IMAGE_SIZE_MAX: int = 5 * 1024 * 1024 # 5 MB + UPLOAD_VIDEO_EXT_INCLUDE: list[str] = ['mp4', 'mov', 'avi', 'flv'] + UPLOAD_VIDEO_SIZE_MAX: int = 20 * 1024 * 1024 # 20 MB - # Demo mode (Only GET, OPTIONS requests are allowed) + # 演示模式配置 DEMO_MODE: bool = False DEMO_MODE_EXCLUDE: set[tuple[str, str]] = { ('POST', f'{FASTAPI_API_V1_PATH}/auth/login'), @@ -142,17 +136,37 @@ class Settings(BaseSettings): ('GET', f'{FASTAPI_API_V1_PATH}/auth/captcha'), } - # Ip location + # IP 定位配置 IP_LOCATION_PARSE: Literal['online', 'offline', 'false'] = 'offline' IP_LOCATION_REDIS_PREFIX: str = 'fba:ip:location' - IP_LOCATION_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒 + IP_LOCATION_EXPIRE_SECONDS: int = 60 * 60 * 24 # 1 天 + + # 追踪 ID + TRACE_ID_REQUEST_HEADER_KEY: str = 'X-Request-ID' + + # 日志 + LOG_CID_DEFAULT_VALUE: str = '-' + LOG_CID_UUID_LENGTH: int = 32 # 日志 correlation_id 长度,必须小于等于 32 + LOG_STD_LEVEL: str = 'INFO' + LOG_ACCESS_FILE_LEVEL: str = 'INFO' + LOG_ERROR_FILE_LEVEL: str = 'ERROR' + LOG_STD_FORMAT: str = ( + '{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | ' + ' {correlation_id} | {message}' + ) + LOG_FILE_FORMAT: str = ( + '{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | ' + ' {correlation_id} | {message}' + ) + LOG_ACCESS_FILENAME: str = 'fba_access.log' + LOG_ERROR_FILENAME: str = 'fba_error.log' - # Opera log + # 操作日志 OPERA_LOG_PATH_EXCLUDE: list[str] = [ '/favicon.ico', - FASTAPI_DOCS_URL, - FASTAPI_REDOC_URL, - FASTAPI_OPENAPI_URL, + '/docs', + '/redoc', + '/openapi', f'{FASTAPI_API_V1_PATH}/auth/login/swagger', f'{FASTAPI_API_V1_PATH}/oauth2/github/callback', f'{FASTAPI_API_V1_PATH}/oauth2/linux-do/callback', @@ -165,27 +179,15 @@ class Settings(BaseSettings): 'confirm_password', ] - # Data permission - DATA_PERMISSION_MODELS: dict[ - str, str - ] = { # 允许进行数据过滤的 SQLA 模型,它必须以模块字符串的方式定义(它应该只用于前台数据,这里只是为了演示) - 'Api': 'backend.plugin.casbin.model.Api', - } - DATA_PERMISSION_COLUMN_EXCLUDE: list[str] = [ # 排除允许进行数据过滤的 SQLA 模型列 - 'id', - 'sort', - 'created_time', - 'updated_time', - ] - - # Plugin + # 插件配置 PLUGIN_PIP_CHINA: bool = True PLUGIN_PIP_INDEX_URL: str = 'https://mirrors.aliyun.com/pypi/simple/' @model_validator(mode='before') @classmethod def check_env(cls, values: Any) -> Any: - if values['ENVIRONMENT'] == 'pro': + """生产环境下禁用 OpenAPI 文档和静态文件服务""" + if values.get('ENVIRONMENT') == 'pro': values['FASTAPI_OPENAPI_URL'] = None values['FASTAPI_STATIC_FILES'] = False return values @@ -193,9 +195,9 @@ def check_env(cls, values: Any) -> Any: @lru_cache def get_settings() -> Settings: - """获取全局配置""" + """获取全局配置单例""" return Settings() -# 创建配置实例 +# 创建全局配置实例 settings = get_settings() diff --git a/backend/core/path_conf.py b/backend/core/path_conf.py index edbc271ac..903e13406 100644 --- a/backend/core/path_conf.py +++ b/backend/core/path_conf.py @@ -1,30 +1,27 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os - from pathlib import Path -# 获取项目根目录 -# 或使用绝对路径,指到backend目录为止,例如windows:BasePath = D:\git_project\fastapi_mysql\backend -BasePath = Path(__file__).resolve().parent.parent +# 项目根目录 +BASE_PATH = Path(__file__).resolve().parent.parent # alembic 迁移文件存放路径 -ALEMBIC_VERSION_DIR = os.path.join(BasePath, 'alembic', 'versions') +ALEMBIC_VERSION_DIR = BASE_PATH / 'alembic' / 'versions' # 日志文件路径 -LOG_DIR = os.path.join(BasePath, 'log') - -# 离线 IP 数据库路径 -IP2REGION_XDB = os.path.join(BasePath, 'static', 'ip2region.xdb') +LOG_DIR = BASE_PATH / 'log' # 静态资源目录 -STATIC_DIR = os.path.join(BasePath, 'static') +STATIC_DIR = BASE_PATH / 'static' # 上传文件目录 -UPLOAD_DIR = os.path.join(BasePath, 'static', 'upload') +UPLOAD_DIR = STATIC_DIR / 'upload' # jinja2 模版文件路径 -JINJA2_TEMPLATE_DIR = os.path.join(BasePath, 'templates') +JINJA2_TEMPLATE_DIR = BASE_PATH / 'templates' # 插件目录 -PLUGIN_DIR = os.path.join(BasePath, 'plugin') +PLUGIN_DIR = BASE_PATH / 'plugin' + +# 离线 IP 数据库路径 +IP2REGION_XDB = STATIC_DIR / 'ip2region.xdb' diff --git a/backend/core/registrar.py b/backend/core/registrar.py index b81af8500..f02d20215 100644 --- a/backend/core/registrar.py +++ b/backend/core/registrar.py @@ -3,6 +3,7 @@ import os from contextlib import asynccontextmanager +from typing import AsyncGenerator import socketio @@ -30,10 +31,11 @@ @asynccontextmanager -async def register_init(app: FastAPI): +async def register_init(app: FastAPI) -> AsyncGenerator[None, None]: """ 启动初始化 + :param app: FastAPI 应用实例 :return: """ # 创建数据库表 @@ -55,8 +57,8 @@ async def register_init(app: FastAPI): await FastAPILimiter.close() -def register_app(): - # FastAPI +def register_app() -> FastAPI: + """注册 FastAPI 应用""" app = FastAPI( title=settings.FASTAPI_TITLE, version=settings.FASTAPI_VERSION, @@ -68,79 +70,71 @@ def register_app(): lifespan=register_init, ) - # socketio + # 注册组件 register_socket_app(app) - - # 日志 register_logger() - - # 静态文件 register_static_file(app) - - # 中间件 register_middleware(app) - - # 路由 register_router(app) - - # 分页 register_page(app) - - # 全局异常处理 register_exception(app) return app def register_logger() -> None: - """ - 系统日志 - - :return: - """ + """注册日志""" setup_logging() set_custom_logfile() -def register_static_file(app: FastAPI): +def register_static_file(app: FastAPI) -> None: """ - 静态资源服务,生产应使用 nginx 代理静态资源服务 + 注册静态资源服务 - :param app: + :param app: FastAPI 应用实例 :return: """ # 上传静态资源 if not os.path.exists(UPLOAD_DIR): os.makedirs(UPLOAD_DIR) app.mount('/static/upload', StaticFiles(directory=UPLOAD_DIR), name='upload') + # 固有静态资源 if settings.FASTAPI_STATIC_FILES: app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static') -def register_middleware(app: FastAPI): +def register_middleware(app: FastAPI) -> None: """ - 中间件,执行顺序从下往上 + 注册中间件(执行顺序从下往上) - :param app: + :param app: FastAPI 应用实例 :return: """ - # Opera log (required) + # Opera log (必须) app.add_middleware(OperaLogMiddleware) - # JWT auth (required) + + # JWT auth (必须) app.add_middleware( - AuthenticationMiddleware, backend=JwtAuthMiddleware(), on_error=JwtAuthMiddleware.auth_exception_handler + AuthenticationMiddleware, + backend=JwtAuthMiddleware(), + on_error=JwtAuthMiddleware.auth_exception_handler, ) + # Access log if settings.MIDDLEWARE_ACCESS: from backend.middleware.access_middleware import AccessMiddleware app.add_middleware(AccessMiddleware) + # State app.add_middleware(StateMiddleware) - # Trace ID (required) + + # Trace ID (必须) app.add_middleware(CorrelationIdMiddleware, validator=False) - # CORS: Always at the end + + # CORS(必须放在最下面) if settings.MIDDLEWARE_CORS: from fastapi.middleware.cors import CORSMiddleware @@ -154,19 +148,20 @@ def register_middleware(app: FastAPI): ) -def register_router(app: FastAPI): +def register_router(app: FastAPI) -> None: """ - 路由 + 注册路由 - :param app: FastAPI + :param app: FastAPI 应用实例 :return: """ dependencies = [Depends(demo_site)] if settings.DEMO_MODE else None - # API + # 插件路由 plugin_router_inject() - from backend.app.router import router # 必须在插件路由注入后导入 + # 系统路由(必须在插件路由注入后导入) + from backend.app.router import router app.include_router(router, dependencies=dependencies) @@ -175,21 +170,21 @@ def register_router(app: FastAPI): simplify_operation_ids(app) -def register_page(app: FastAPI): +def register_page(app: FastAPI) -> None: """ - 分页查询 + 注册分页查询功能 - :param app: + :param app: FastAPI 应用实例 :return: """ add_pagination(app) -def register_socket_app(app: FastAPI): +def register_socket_app(app: FastAPI) -> None: """ - socket 应用 + 注册 Socket.IO 应用 - :param app: + :param app: FastAPI 应用实例 :return: """ from backend.common.socketio.server import sio diff --git a/backend/database/db.py b/backend/database/db.py index ef521a998..529bd9c14 100644 --- a/backend/database/db.py +++ b/backend/database/db.py @@ -1,6 +1,8 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import sys -from typing import Annotated +from typing import Annotated, AsyncGenerator from uuid import uuid4 from fastapi import Depends @@ -33,7 +35,12 @@ def create_database_url(unittest: bool = False) -> URL: def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: - """创建数据库引擎和 Session""" + """ + 创建数据库引擎和 Session + + :param url: 数据库连接 URL + :return: + """ try: # 数据库引擎 engine = create_async_engine( @@ -49,17 +56,21 @@ def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_ pool_pre_ping=True, # 低:False 高:True pool_use_lifo=False, # 低:False 高:True ) - # log.success('数据库连接成功') except Exception as e: log.error('❌ 数据库链接失败 {}', e) sys.exit() else: - db_session = async_sessionmaker(bind=engine, autoflush=False, expire_on_commit=False) + db_session = async_sessionmaker( + bind=engine, + class_=AsyncSession, + autoflush=False, # 禁用自动刷新 + expire_on_commit=False, # 禁用提交时过期 + ) return engine, db_session -async def get_db(): - """session 生成器""" +async def get_db() -> AsyncGenerator[AsyncSession, None]: + """获取数据库会话""" async with async_db_session() as session: yield session diff --git a/backend/database/redis.py b/backend/database/redis.py index 7f587d64d..045469438 100644 --- a/backend/database/redis.py +++ b/backend/database/redis.py @@ -10,22 +10,26 @@ class RedisCli(Redis): - def __init__(self): + """Redis 客户端""" + + def __init__(self) -> None: + """初始化 Redis 客户端""" super(RedisCli, self).__init__( host=settings.REDIS_HOST, port=settings.REDIS_PORT, password=settings.REDIS_PASSWORD, db=settings.REDIS_DATABASE, socket_timeout=settings.REDIS_TIMEOUT, + socket_connect_timeout=5, # 连接超时 + socket_keepalive=True, # 保持连接 + health_check_interval=30, # 健康检查间隔 decode_responses=True, # 转码 utf-8 + retry_on_timeout=True, # 超时重试 + max_connections=20, # 最大连接数 ) - async def open(self): - """ - 触发初始化连接 - - :return: - """ + async def open(self) -> None: + """触发初始化连接""" try: await self.ping() except TimeoutError: @@ -38,12 +42,12 @@ async def open(self): log.error('❌ 数据库 redis 连接异常 {}', e) sys.exit() - async def delete_prefix(self, prefix: str, exclude: str | list = None): + async def delete_prefix(self, prefix: str, exclude: str | list[str] | None = None) -> None: """ - 删除指定前缀的所有key + 删除指定前缀的所有 key - :param prefix: - :param exclude: + :param prefix: 前缀 + :param exclude: 排除的 key :return: """ keys = [] diff --git a/backend/middleware/access_middleware.py b/backend/middleware/access_middleware.py index ad3cbedc3..61fa38a44 100644 --- a/backend/middleware/access_middleware.py +++ b/backend/middleware/access_middleware.py @@ -11,6 +11,13 @@ class AccessMiddleware(BaseHTTPMiddleware): """请求日志中间件""" async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + 处理请求并记录访问日志 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ start_time = timezone.now() response = await call_next(request) end_time = timezone.now() diff --git a/backend/middleware/jwt_auth_middleware.py b/backend/middleware/jwt_auth_middleware.py index 09356e83e..537ee7804 100644 --- a/backend/middleware/jwt_auth_middleware.py +++ b/backend/middleware/jwt_auth_middleware.py @@ -18,7 +18,17 @@ class _AuthenticationError(AuthenticationError): """重写内部认证错误类""" - def __init__(self, *, code: int = None, msg: str = None, headers: dict[str, Any] | None = None): + def __init__( + self, *, code: int | None = None, msg: str | None = None, headers: dict[str, Any] | None = None + ) -> None: + """ + 初始化认证错误 + + :param code: 错误码 + :param msg: 错误信息 + :param headers: 响应头 + :return: + """ self.code = code self.msg = msg self.headers = headers @@ -29,20 +39,32 @@ class JwtAuthMiddleware(AuthenticationBackend): @staticmethod def auth_exception_handler(conn: HTTPConnection, exc: _AuthenticationError) -> Response: - """覆盖内部认证错误处理""" + """ + 覆盖内部认证错误处理 + + :param conn: HTTP 连接对象 + :param exc: 认证错误对象 + :return: + """ return MsgSpecJSONResponse(content={'code': exc.code, 'msg': exc.msg, 'data': None}, status_code=exc.code) async def authenticate(self, request: Request) -> tuple[AuthCredentials, GetUserInfoWithRelationDetail] | None: + """ + 认证请求 + + :param request: FastAPI 请求对象 + :return: + """ token = request.headers.get('Authorization') if not token: - return + return None if request.url.path in settings.TOKEN_REQUEST_PATH_EXCLUDE: - return + return None scheme, token = get_authorization_scheme_param(token) if scheme.lower() != 'bearer': - return + return None try: user = await jwt_authentication(token) diff --git a/backend/middleware/opera_log_middleware.py b/backend/middleware/opera_log_middleware.py index 747a9642b..3ac0477e6 100644 --- a/backend/middleware/opera_log_middleware.py +++ b/backend/middleware/opera_log_middleware.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from asyncio import create_task +from typing import Any from asgiref.sync import sync_to_async from fastapi import Response @@ -22,7 +23,14 @@ class OperaLogMiddleware(BaseHTTPMiddleware): """操作日志中间件""" - async def dispatch(self, request: Request, call_next) -> Response: + async def dispatch(self, request: Request, call_next: Any) -> Response: + """ + 处理请求并记录操作日志 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ # 排除记录白名单 path = request.url.path if path in settings.OPERA_LOG_PATH_EXCLUDE or not path.startswith(f'{settings.FASTAPI_API_V1_PATH}'): @@ -70,17 +78,22 @@ async def dispatch(self, request: Request, call_next) -> Response: cost_time=cost_time, opera_time=start_time, ) - create_task(opera_log_service.create(obj_in=opera_log_in)) # noqa: ignore + create_task(opera_log_service.create(obj=opera_log_in)) # noqa: ignore # 错误抛出 - err = request_next.err - if err: - raise err from None + if request_next.err: + raise request_next.err from None return request_next.response - async def execute_request(self, request: Request, call_next) -> RequestCallNext: - """执行请求""" + async def execute_request(self, request: Request, call_next: Any) -> RequestCallNext: + """ + 执行请求并处理异常 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ code = 200 msg = 'Success' status = StatusType.enable @@ -90,10 +103,10 @@ async def execute_request(self, request: Request, call_next) -> RequestCallNext: response = await call_next(request) code, msg = self.request_exception_handler(request, code, msg) except Exception as e: - log.error(f'请求异常: {e}') + log.error(f'请求异常: {str(e)}') # code 处理包含 SQLAlchemy 和 Pydantic - code = getattr(e, 'code', None) or code - msg = getattr(e, 'msg', None) or msg + code = getattr(e, 'code', code) + msg = getattr(e, 'msg', msg) status = StatusType.disable err = e @@ -101,11 +114,17 @@ async def execute_request(self, request: Request, call_next) -> RequestCallNext: @staticmethod def request_exception_handler(request: Request, code: int, msg: str) -> tuple[str, str]: - """请求异常处理器""" + """ + 请求异常处理器 + + :param request: FastAPI 请求对象 + :param code: 错误码 + :param msg: 错误信息 + :return: + """ exception_states = [ '__request_http_exception__', '__request_validation_exception__', - '__request_pydantic_user_error__', '__request_assertion_error__', '__request_custom_exception__', '__request_all_unknown_exception__', @@ -121,8 +140,13 @@ def request_exception_handler(request: Request, code: int, msg: str) -> tuple[st return code, msg @staticmethod - async def get_request_args(request: Request) -> dict: - """获取请求参数""" + async def get_request_args(request: Request) -> dict[str, Any]: + """ + 获取请求参数 + + :param request: FastAPI 请求对象 + :return: + """ args = dict(request.query_params) args.update(request.path_params) # Tip: .body() 必须在 .form() 之前获取 @@ -131,51 +155,46 @@ async def get_request_args(request: Request) -> dict: form_data = await request.form() if len(form_data) > 0: args.update({k: v.filename if isinstance(v, UploadFile) else v for k, v in form_data.items()}) - else: - if body_data: - content_type = request.headers.get('Content-Type', '').split(';')[0].strip().lower() - if content_type == 'application/json': - json_data = await request.json() - if isinstance(json_data, bytes): - json_data = json_data.decode('utf-8') - if isinstance(json_data, dict): - args.update(json_data) - else: - # 注意:非字典数据默认使用 body 作为键 - args.update({'body': json_data}) + elif body_data: + content_type = request.headers.get('Content-Type', '').split(';') + if 'application/json' in content_type: + json_data = await request.json() + if isinstance(json_data, dict): + args.update(json_data) else: + # 注意:非字典数据默认使用 body 作为键 args.update({'body': str(body_data)}) + else: + args.update({'body': str(body_data)}) return args @staticmethod @sync_to_async - def desensitization(args: dict) -> dict | None: + def desensitization(args: dict[str, Any]) -> dict[str, Any] | None: """ 脱敏处理 - :param args: + :param args: 需要脱敏的参数字典 :return: """ if not args: - args = None - else: - match settings.OPERA_LOG_ENCRYPT_TYPE: - case OperaLogCipherType.aes: - for key in args.keys(): - if key in settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE: - args[key] = (AESCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(args[key])).hex() - case OperaLogCipherType.md5: - for key in args.keys(): - if key in settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE: - args[key] = Md5Cipher.encrypt(args[key]) - case OperaLogCipherType.itsdangerous: - for key in args.keys(): - if key in settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE: - args[key] = ItsDCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(args[key]) - case OperaLogCipherType.plan: - pass - case _: - for key in args.keys(): - if key in settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE: - args[key] = '******' + return None + + encrypt_type = settings.OPERA_LOG_ENCRYPT_TYPE + encrypt_key_include = settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE + encrypt_secret_key = settings.OPERA_LOG_ENCRYPT_SECRET_KEY + + for key, value in args.items(): + if key in encrypt_key_include: + match encrypt_type: + case OperaLogCipherType.aes: + args[key] = (AESCipher(encrypt_secret_key).encrypt(value)).hex() + case OperaLogCipherType.md5: + args[key] = Md5Cipher.encrypt(value) + case OperaLogCipherType.itsdangerous: + args[key] = ItsDCipher(encrypt_secret_key).encrypt(value) + case OperaLogCipherType.plan: + pass + case _: + args[key] = '******' return args diff --git a/backend/middleware/state_middleware.py b/backend/middleware/state_middleware.py index b2246b7ec..b7307524e 100644 --- a/backend/middleware/state_middleware.py +++ b/backend/middleware/state_middleware.py @@ -7,9 +7,16 @@ class StateMiddleware(BaseHTTPMiddleware): - """请求 state 中间件""" + """请求 state 中间件,用于解析和设置请求的附加信息""" async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + 处理请求并设置请求状态信息 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ ip_info = await parse_ip_info(request) ua_info = parse_user_agent_info(request) diff --git a/backend/plugin/casbin/api/v1/sys/api.py b/backend/plugin/casbin/api/v1/sys/api.py index 6f0ce1bdb..6bf968be5 100644 --- a/backend/plugin/casbin/api/v1/sys/api.py +++ b/backend/plugin/casbin/api/v1/sys/api.py @@ -23,14 +23,14 @@ async def get_all_apis() -> ResponseSchemaModel[list[GetApiDetail]]: @router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth]) -async def get_api(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetApiDetail]: +async def get_api(pk: Annotated[int, Path(description='API ID')]) -> ResponseSchemaModel[GetApiDetail]: api = await api_service.get(pk=pk) return response_base.success(data=api) @router.get( '', - summary='(模糊条件)分页获取所有接口', + summary='分页获取所有接口', dependencies=[ DependsJwtAuth, DependsPagination, @@ -39,9 +39,9 @@ async def get_api(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetApiDe async def get_pagination_apis( request: Request, db: CurrentSession, - name: Annotated[str | None, Query()] = None, - method: Annotated[str | None, Query()] = None, - path: Annotated[str | None, Query()] = None, + name: Annotated[str | None, Query(description='API 名称')] = None, + method: Annotated[str | None, Query(description='请求方法')] = None, + path: Annotated[str | None, Query(description='API 路径')] = None, ) -> ResponseSchemaModel[PageData[GetApiDetail]]: api_select = await api_service.get_select(request=request, name=name, method=method, path=path) page_data = await paging_data(db, api_select) @@ -69,7 +69,7 @@ async def create_api(obj: CreateApiParam) -> ResponseModel: DependsRBAC, ], ) -async def update_api(pk: Annotated[int, Path(...)], obj: UpdateApiParam) -> ResponseModel: +async def update_api(pk: Annotated[int, Path(description='API ID')], obj: UpdateApiParam) -> ResponseModel: count = await api_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -78,13 +78,13 @@ async def update_api(pk: Annotated[int, Path(...)], obj: UpdateApiParam) -> Resp @router.delete( '', - summary='(批量)删除接口', + summary='批量删除接口', dependencies=[ Depends(RequestPermission('sys:api:del')), DependsRBAC, ], ) -async def delete_api(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_api(pk: Annotated[list[int], Query(description='API ID 列表')]) -> ResponseModel: count = await api_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/plugin/casbin/api/v1/sys/casbin.py b/backend/plugin/casbin/api/v1/sys/casbin.py index e5929c91a..88069adad 100644 --- a/backend/plugin/casbin/api/v1/sys/casbin.py +++ b/backend/plugin/casbin/api/v1/sys/casbin.py @@ -12,11 +12,11 @@ from backend.common.security.rbac import DependsRBAC from backend.database.db import CurrentSession from backend.plugin.casbin.schema.casbin_rule import ( + CreateGroupParam, CreatePolicyParam, - CreateUserRoleParam, DeleteAllPoliciesParam, + DeleteGroupParam, DeletePolicyParam, - DeleteUserRoleParam, GetPolicyDetail, UpdatePoliciesParam, UpdatePolicyParam, @@ -28,7 +28,7 @@ @router.get( '', - summary='(模糊条件)分页获取所有权限策略', + summary='分页获取所有权限策略', dependencies=[ DependsJwtAuth, DependsPagination, @@ -36,17 +36,17 @@ ) async def get_pagination_casbin( db: CurrentSession, - ptype: Annotated[str | None, Query(description='策略类型, p / g')] = None, - sub: Annotated[str | None, Query(description='用户 uuid / 角色')] = None, + ptype: Annotated[str | None, Query(description='策略类型:p / g')] = None, + sub: Annotated[str | None, Query(description='用户 UUID / 角色 ID')] = None, ) -> ResponseSchemaModel[PageData[GetPolicyDetail]]: casbin_select = await casbin_service.get_casbin_list(ptype=ptype, sub=sub) page_data = await paging_data(db, casbin_select) return response_base.success(data=page_data) -@router.get('/policies', summary='获取所有P权限策略', dependencies=[DependsJwtAuth]) +@router.get('/policies', summary='获取所有 P 权限策略', dependencies=[DependsJwtAuth]) async def get_all_policies( - role: Annotated[int | None, Query(description='角色ID')] = None, + role: Annotated[int | None, Query(description='角色 ID')] = None, ) -> ResponseSchemaModel[list[list[str]]]: policies = await casbin_service.get_policy_list(role=role) return response_base.success(data=policies) @@ -54,29 +54,20 @@ async def get_all_policies( @router.post( '/policy', - summary='添加P权限策略', + summary='添加 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:add')), DependsRBAC, ], ) async def create_policy(p: CreatePolicyParam) -> ResponseSchemaModel[bool]: - """ - p 策略: - - - 推荐添加基于角色的访问权限, 需配合添加 g 策略才能真正拥有访问权限,适合配置全局接口访问策略
- **格式**: 角色 role + 访问路径 path + 访问方法 method - - - 如果添加基于用户的访问权限, 不需配合添加 g 策略就能真正拥有权限,适合配置指定用户接口访问策略
- **格式**: 用户 uuid + 访问路径 path + 访问方法 method - """ data = await casbin_service.create_policy(p=p) return response_base.success(data=data) @router.post( '/policies', - summary='添加多组P权限策略', + summary='添加多组 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:group:add')), DependsRBAC, @@ -89,7 +80,7 @@ async def create_policies(ps: list[CreatePolicyParam]) -> ResponseSchemaModel[bo @router.put( '/policy', - summary='更新P权限策略', + summary='更新 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:edit')), DependsRBAC, @@ -102,7 +93,7 @@ async def update_policy(obj: UpdatePolicyParam) -> ResponseSchemaModel[bool]: @router.put( '/policies', - summary='更新多组P权限策略', + summary='更新多组 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:group:edit')), DependsRBAC, @@ -115,7 +106,7 @@ async def update_policies(obj: UpdatePoliciesParam) -> ResponseSchemaModel[bool] @router.delete( '/policy', - summary='删除P权限策略', + summary='删除 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:del')), DependsRBAC, @@ -128,7 +119,7 @@ async def delete_policy(p: DeletePolicyParam) -> ResponseSchemaModel[bool]: @router.delete( '/policies', - summary='删除多组P权限策略', + summary='删除多组 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:group:del')), DependsRBAC, @@ -141,7 +132,7 @@ async def delete_policies(ps: list[DeletePolicyParam]) -> ResponseSchemaModel[bo @router.delete( '/policies/all', - summary='删除所有P权限策略', + summary='删除所有 P 权限策略', dependencies=[ Depends(RequestPermission('casbin:p:empty')), DependsRBAC, @@ -154,7 +145,7 @@ async def delete_all_policies(sub: DeleteAllPoliciesParam) -> ResponseModel: return response_base.fail() -@router.get('/groups', summary='获取所有G权限策略', dependencies=[DependsJwtAuth]) +@router.get('/groups', summary='获取所有 G 权限策略', dependencies=[DependsJwtAuth]) async def get_all_groups() -> ResponseSchemaModel[list[list[str]]]: data = await casbin_service.get_group_list() return response_base.success(data=data) @@ -162,74 +153,65 @@ async def get_all_groups() -> ResponseSchemaModel[list[list[str]]]: @router.post( '/group', - summary='添加G权限策略', + summary='添加 G 权限策略', dependencies=[ Depends(RequestPermission('casbin:g:add')), DependsRBAC, ], ) -async def create_group(g: CreateUserRoleParam) -> ResponseSchemaModel[bool]: - """ - g 策略 (**依赖 p 策略**): - - - 如果在 p 策略中添加了基于角色的访问权限, 则还需要在 g 策略中添加基于用户组的访问权限, 才能真正拥有访问权限
- **格式**: 用户 uuid + 角色 role - - - 如果在 p 策略中添加了基于用户的访问权限, 则不添加相应的 g 策略能直接拥有访问权限
- 但是拥有的不是用户角色的所有权限, 而只是单一的对应的 p 策略所添加的访问权限 - """ +async def create_group(g: CreateGroupParam) -> ResponseSchemaModel[bool]: data = await casbin_service.create_group(g=g) return response_base.success(data=data) @router.post( '/groups', - summary='添加多组G权限策略', + summary='添加多组 G 权限策略', dependencies=[ Depends(RequestPermission('casbin:g:group:add')), DependsRBAC, ], ) -async def create_groups(gs: list[CreateUserRoleParam]) -> ResponseSchemaModel[bool]: +async def create_groups(gs: list[CreateGroupParam]) -> ResponseSchemaModel[bool]: data = await casbin_service.create_groups(gs=gs) return response_base.success(data=data) @router.delete( '/group', - summary='删除G权限策略', + summary='删除 G 权限策略', dependencies=[ Depends(RequestPermission('casbin:g:del')), DependsRBAC, ], ) -async def delete_group(g: DeleteUserRoleParam) -> ResponseSchemaModel[bool]: +async def delete_group(g: DeleteGroupParam) -> ResponseSchemaModel[bool]: data = await casbin_service.delete_group(g=g) return response_base.success(data=data) @router.delete( '/groups', - summary='删除多组G权限策略', + summary='删除多组 G 权限策略', dependencies=[ Depends(RequestPermission('casbin:g:group:del')), DependsRBAC, ], ) -async def delete_groups(gs: list[DeleteUserRoleParam]) -> ResponseSchemaModel[bool]: +async def delete_groups(gs: list[DeleteGroupParam]) -> ResponseSchemaModel[bool]: data = await casbin_service.delete_groups(gs=gs) return response_base.success(data=data) @router.delete( '/groups/all', - summary='删除所有G权限策略', + summary='删除所有 G 权限策略', dependencies=[ Depends(RequestPermission('casbin:g:empty')), DependsRBAC, ], ) -async def delete_all_groups(uuid: Annotated[UUID, Query(...)]) -> ResponseModel: +async def delete_all_groups(uuid: Annotated[UUID, Query()]) -> ResponseModel: count = await casbin_service.delete_all_groups(uuid=uuid) if count > 0: return response_base.success() diff --git a/backend/plugin/casbin/conf.py b/backend/plugin/casbin/conf.py index e29837b9f..ee031e466 100644 --- a/backend/plugin/casbin/conf.py +++ b/backend/plugin/casbin/conf.py @@ -8,7 +8,7 @@ class CasbinSettings(BaseSettings): - """Casbin Settings""" + """Casbin 配置""" # RBAC RBAC_CASBIN_EXCLUDE: set[tuple[str, str]] = { @@ -19,7 +19,7 @@ class CasbinSettings(BaseSettings): @lru_cache def get_casbin_settings() -> CasbinSettings: - """获取 xxx 配置""" + """获取 Casbin 配置""" return CasbinSettings() diff --git a/backend/plugin/casbin/crud/crud_api.py b/backend/plugin/casbin/crud/crud_api.py index 8112de5ab..343aab9b0 100644 --- a/backend/plugin/casbin/crud/crud_api.py +++ b/backend/plugin/casbin/crud/crud_api.py @@ -13,12 +13,14 @@ class CRUDApi(CRUDPlus[Api]): + """API 数据库操作类""" + async def get(self, db: AsyncSession, pk: int) -> Api | None: """ 获取 API - :param db: - :param pk: + :param db: 数据库会话 + :param pk: API ID :return: """ return await self.select_model(db, pk) @@ -27,10 +29,10 @@ async def get_list(self, request: Request, name: str = None, method: str = None, """ 获取 API 列表 - :param request: - :param name: - :param method: - :param path: + :param request: FastAPI 请求对象 + :param name: API 名称 + :param method: 请求方法 + :param path: API 路径 :return: """ filters = {} @@ -47,48 +49,48 @@ async def get_all(self, db: AsyncSession) -> Sequence[Api]: """ 获取所有 API - :param db: + :param db: 数据库会话 :return: """ return await self.select_models(db) async def get_by_name(self, db: AsyncSession, name: str) -> Api | None: """ - 通过 name 获取 API + 通过名称获取 API - :param db: - :param name: + :param db: 数据库会话 + :param name: API 名称 :return: """ return await self.select_model_by_column(db, name=name) - async def create(self, db: AsyncSession, obj_in: CreateApiParam) -> None: + async def create(self, db: AsyncSession, obj: CreateApiParam) -> None: """ 创建 API - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建 API 参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateApiParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateApiParam) -> int: """ 更新 API - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: API ID + :param obj: 更新 API 参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ 删除 API - :param db: - :param pk: + :param db: 数据库会话 + :param pk: API ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) diff --git a/backend/plugin/casbin/crud/crud_casbin.py b/backend/plugin/casbin/crud/crud_casbin.py index 5af907540..3a658010d 100644 --- a/backend/plugin/casbin/crud/crud_casbin.py +++ b/backend/plugin/casbin/crud/crud_casbin.py @@ -11,35 +11,38 @@ class CRUDCasbin(CRUDPlus[CasbinRule]): + """Casbin 规则数据库操作类""" + async def get_list(self, ptype: str, sub: str) -> Select: """ 获取策略列表 - :param ptype: - :param sub: + :param ptype: 策略类型 + :param sub: 用户 UUID / 角色 ID :return: """ return await self.select_order('id', 'desc', ptype=ptype, v0__like=f'%{sub}%') async def delete_policies_by_sub(self, db: AsyncSession, sub: DeleteAllPoliciesParam) -> int: """ - 删除角色所有P策略 + 删除角色所有 P 策略 - :param db: - :param sub: + :param db: 数据库会话 + :param sub: 删除所有 P 策略参数 :return: """ - where_list = [sub.role] + filters = [sub.role] if sub.uuid: - where_list.append(sub.uuid) - return await self.delete_model_by_column(db, allow_multiple=True, v0__mor={'eq': where_list}) + filters.append(sub.uuid) + + return await self.delete_model_by_column(db, allow_multiple=True, v0__mor={'eq': filters}) async def delete_groups_by_uuid(self, db: AsyncSession, uuid: UUID) -> int: """ - 删除用户所有G策略 + 删除用户所有 G 策略 - :param db: - :param uuid: + :param db: 数据库会话 + :param uuid: 用户 UUID :return: """ return await self.delete_model_by_column(db, allow_multiple=True, v0=str(uuid)) diff --git a/backend/plugin/casbin/model/api.py b/backend/plugin/casbin/model/api.py index ac948efa5..45127b869 100644 --- a/backend/plugin/casbin/model/api.py +++ b/backend/plugin/casbin/model/api.py @@ -9,12 +9,12 @@ class Api(Base): - """系统api""" + """API 表""" __tablename__ = 'sys_api' id: Mapped[id_key] = mapped_column(init=False) - name: Mapped[str] = mapped_column(String(50), unique=True, comment='api名称') + name: Mapped[str] = mapped_column(String(50), unique=True, comment='API 名称') method: Mapped[str] = mapped_column(String(16), comment='请求方法') - path: Mapped[str] = mapped_column(String(500), comment='api路径') + path: Mapped[str] = mapped_column(String(500), comment='API 路径') remark: Mapped[str | None] = mapped_column(LONGTEXT().with_variant(TEXT, 'postgresql'), comment='备注') diff --git a/backend/plugin/casbin/model/casbin_rule.py b/backend/plugin/casbin/model/casbin_rule.py index 3f8afcce5..7569f2d4e 100644 --- a/backend/plugin/casbin/model/casbin_rule.py +++ b/backend/plugin/casbin/model/casbin_rule.py @@ -9,20 +9,20 @@ class CasbinRule(MappedBase): - """重写 casbin 中的 CasbinRule model 类, 使用自定义 Base, 避免产生 alembic 迁移问题""" + """Casbin 规则表""" __tablename__ = 'sys_casbin_rule' id: Mapped[id_key] ptype: Mapped[str] = mapped_column(String(255), comment='策略类型: p / g') - v0: Mapped[str] = mapped_column(String(255), comment='角色ID / 用户uuid') - v1: Mapped[str] = mapped_column(LONGTEXT().with_variant(TEXT, 'postgresql'), comment='api路径 / 角色名称') + v0: Mapped[str] = mapped_column(String(255), comment='用户 UUID / 角色 ID') + v1: Mapped[str] = mapped_column(LONGTEXT().with_variant(TEXT, 'postgresql'), comment='API 路径 / 角色名称') v2: Mapped[str | None] = mapped_column(String(255), comment='请求方法') - v3: Mapped[str | None] = mapped_column(String(255)) - v4: Mapped[str | None] = mapped_column(String(255)) - v5: Mapped[str | None] = mapped_column(String(255)) + v3: Mapped[str | None] = mapped_column(String(255), comment='预留字段') + v4: Mapped[str | None] = mapped_column(String(255), comment='预留字段') + v5: Mapped[str | None] = mapped_column(String(255), comment='预留字段') - def __str__(self): + def __str__(self) -> str: arr = [self.ptype] for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5): if v is None: @@ -30,5 +30,5 @@ def __str__(self): arr.append(v) return ', '.join(arr) - def __repr__(self): - return ''.format(self.id, str(self)) + def __repr__(self) -> str: + return f'' diff --git a/backend/plugin/casbin/schema/api.py b/backend/plugin/casbin/schema/api.py index 8237e00e8..330360094 100644 --- a/backend/plugin/casbin/schema/api.py +++ b/backend/plugin/casbin/schema/api.py @@ -9,23 +9,27 @@ class ApiSchemaBase(SchemaBase): - name: str - method: MethodType = Field(default=MethodType.GET, description='请求方法') - path: str = Field(description='api路径') - remark: str | None = None + """API 基础模型""" + + name: str = Field(description='API 名称') + method: MethodType = Field(MethodType.GET, description='请求方法') + path: str = Field(description='API 路径') + remark: str | None = Field(None, description='备注') class CreateApiParam(ApiSchemaBase): - pass + """创建 API 参数""" class UpdateApiParam(ApiSchemaBase): - pass + """更新 API 参数""" class GetApiDetail(ApiSchemaBase): + """API 详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='API ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/plugin/casbin/schema/casbin_rule.py b/backend/plugin/casbin/schema/casbin_rule.py index 8a569e11d..f5a74e4f9 100644 --- a/backend/plugin/casbin/schema/casbin_rule.py +++ b/backend/plugin/casbin/schema/casbin_rule.py @@ -7,47 +7,59 @@ class CreatePolicyParam(SchemaBase): - sub: str = Field(description='用户uuid / 角色ID') - path: str = Field(description='api 路径') - method: MethodType = Field(default=MethodType.GET, description='请求方法') + """创建 P 策略参数""" + + sub: str = Field(description='用户 UUID / 角色 ID') + path: str = Field(description='API 路径') + method: MethodType = Field(MethodType.GET, description='请求方法') class UpdatePolicyParam(SchemaBase): - old: CreatePolicyParam - new: CreatePolicyParam + """更新 P 策略参数""" + + old: CreatePolicyParam = Field(description='原策略') + new: CreatePolicyParam = Field(description='新策略') class UpdatePoliciesParam(SchemaBase): - old: list[CreatePolicyParam] - new: list[CreatePolicyParam] + """批量更新策略参数""" + + old: list[CreatePolicyParam] = Field(description='原策略列表') + new: list[CreatePolicyParam] = Field(description='新策略列表') class DeletePolicyParam(CreatePolicyParam): - pass + """删除策略参数""" class DeleteAllPoliciesParam(SchemaBase): - uuid: str | None = None - role: str + """删除所有策略参数""" + uuid: str | None = Field(None, description='用户 UUID') + role: str = Field(description='角色') -class CreateUserRoleParam(SchemaBase): - uuid: str = Field(description='用户 uuid') + +class CreateGroupParam(SchemaBase): + """创建 G 策略参数""" + + uuid: str = Field(description='用户 UUID') role: str = Field(description='角色') -class DeleteUserRoleParam(CreateUserRoleParam): - pass +class DeleteGroupParam(CreateGroupParam): + """删除 G 策略参数""" class GetPolicyDetail(SchemaBase): + """策略详情""" + model_config = ConfigDict(from_attributes=True) - id: int + id: int = Field(description='规则 ID') ptype: str = Field(description='规则类型, p / g') - v0: str = Field(description='用户 uuid / 角色') - v1: str = Field(description='api 路径 / 角色') - v2: str | None = None - v3: str | None = None - v4: str | None = None - v5: str | None = None + v0: str = Field(description='用户 UUID / 角色 ID') + v1: str = Field(description='API 路径 / 角色') + v2: str | None = Field(None, description='请求方法') + v3: str | None = Field(None, description='预留字段') + v4: str | None = Field(None, description='预留字段') + v5: str | None = Field(None, description='预留字段') diff --git a/backend/plugin/casbin/service/api_service.py b/backend/plugin/casbin/service/api_service.py index 6cdfc2073..47af85b08 100644 --- a/backend/plugin/casbin/service/api_service.py +++ b/backend/plugin/casbin/service/api_service.py @@ -13,8 +13,16 @@ class ApiService: + """API 服务类""" + @staticmethod async def get(*, pk: int) -> Api: + """ + 获取 API + + :param pk: API ID + :return: + """ async with async_db_session() as db: api = await api_dao.get(db, pk) if not api: @@ -23,16 +31,32 @@ async def get(*, pk: int) -> Api: @staticmethod async def get_select(*, request: Request, name: str = None, method: str = None, path: str = None) -> Select: + """ + 获取 API 查询对象 + + :param request: 请求对象 + :param name: API 名称 + :param method: 请求方法 + :param path: API 路径 + :return: + """ return await api_dao.get_list(request=request, name=name, method=method, path=path) @staticmethod async def get_all() -> Sequence[Api]: + """获取所有 API""" async with async_db_session() as db: apis = await api_dao.get_all(db) return apis @staticmethod async def create(*, obj: CreateApiParam) -> None: + """ + 创建 API + + :param obj: 创建 API 参数 + :return: + """ async with async_db_session.begin() as db: api = await api_dao.get_by_name(db, obj.name) if api: @@ -41,6 +65,13 @@ async def create(*, obj: CreateApiParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateApiParam) -> int: + """ + 更新 API + + :param pk: API ID + :param obj: 更新 API 参数 + :return: + """ async with async_db_session.begin() as db: api = await api_dao.get(db, pk) if not api: @@ -50,6 +81,12 @@ async def update(*, pk: int, obj: UpdateApiParam) -> int: @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除 API + + :param pk: API ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await api_dao.delete(db, pk) return count diff --git a/backend/plugin/casbin/service/casbin_service.py b/backend/plugin/casbin/service/casbin_service.py index f93670bb5..b15371a59 100644 --- a/backend/plugin/casbin/service/casbin_service.py +++ b/backend/plugin/casbin/service/casbin_service.py @@ -8,11 +8,11 @@ from backend.database.db import async_db_session from backend.plugin.casbin.crud.crud_casbin import casbin_dao from backend.plugin.casbin.schema.casbin_rule import ( + CreateGroupParam, CreatePolicyParam, - CreateUserRoleParam, DeleteAllPoliciesParam, + DeleteGroupParam, DeletePolicyParam, - DeleteUserRoleParam, UpdatePoliciesParam, UpdatePolicyParam, ) @@ -20,12 +20,27 @@ class CasbinService: + """Casbin 权限服务类""" + @staticmethod async def get_casbin_list(*, ptype: str, sub: str) -> Select: + """ + 获取 Casbin 规则列表 + + :param ptype: 策略类型 + :param sub: 用户 UUID / 角色 ID + :return: + """ return await casbin_dao.get_list(ptype, sub) @staticmethod async def get_policy_list(*, role: int | None = None) -> list: + """ + 获取 P 策略列表 + + :param role: 角色ID + :return: + """ enforcer = await casbin_enforcer() if role is not None: data = enforcer.get_filtered_named_policy('p', 0, str(role)) @@ -35,6 +50,12 @@ async def get_policy_list(*, role: int | None = None) -> list: @staticmethod async def create_policy(*, p: CreatePolicyParam) -> bool: + """ + 创建 P 策略 + + :param p: 策略参数 + :return: + """ enforcer = await casbin_enforcer() data = await enforcer.add_policy(p.sub, p.path, p.method) if not data: @@ -43,6 +64,12 @@ async def create_policy(*, p: CreatePolicyParam) -> bool: @staticmethod async def create_policies(*, ps: list[CreatePolicyParam]) -> bool: + """ + 批量创建 P 策略 + + :param ps: 策略参数列表 + :return: + """ enforcer = await casbin_enforcer() data = await enforcer.add_policies([list(p.model_dump().values()) for p in ps]) if not data: @@ -51,6 +78,12 @@ async def create_policies(*, ps: list[CreatePolicyParam]) -> bool: @staticmethod async def update_policy(*, obj: UpdatePolicyParam) -> bool: + """ + 更新 P 策略 + + :param obj: 更新 P 策略参数 + :return: + """ old_obj = obj.old new_obj = obj.new enforcer = await casbin_enforcer() @@ -58,20 +91,34 @@ async def update_policy(*, obj: UpdatePolicyParam) -> bool: if not _p: raise errors.NotFoundError(msg='权限不存在') data = await enforcer.update_policy( - [old_obj.sub, old_obj.path, old_obj.method], [new_obj.sub, new_obj.path, new_obj.method] + [old_obj.sub, old_obj.path, old_obj.method], + [new_obj.sub, new_obj.path, new_obj.method], ) return data @staticmethod async def update_policies(*, obj: UpdatePoliciesParam) -> bool: + """ + 批量更新 P 策略 + + :param obj: 更新 P 策略参数 + :return: + """ enforcer = await casbin_enforcer() data = await enforcer.update_policies( - [list(o.model_dump().values()) for o in obj.old], [list(n.model_dump().values()) for n in obj.new] + [list(o.model_dump().values()) for o in obj.old], + [list(n.model_dump().values()) for n in obj.new], ) return data @staticmethod async def delete_policy(*, p: DeletePolicyParam) -> bool: + """ + 删除 P 策略 + + :param p: 删除参数 + :return: + """ enforcer = await casbin_enforcer() _p = enforcer.has_policy(p.sub, p.path, p.method) if not _p: @@ -81,6 +128,12 @@ async def delete_policy(*, p: DeletePolicyParam) -> bool: @staticmethod async def delete_policies(*, ps: list[DeletePolicyParam]) -> bool: + """ + 批量删除 P 策略 + + :param ps: 删除参数列表 + :return: + """ enforcer = await casbin_enforcer() data = await enforcer.remove_policies([list(p.model_dump().values()) for p in ps]) if not data: @@ -89,18 +142,31 @@ async def delete_policies(*, ps: list[DeletePolicyParam]) -> bool: @staticmethod async def delete_all_policies(*, sub: DeleteAllPoliciesParam) -> int: + """ + 删除所有 P 策略 + + :param sub: 删除参数 + :return: + """ async with async_db_session.begin() as db: count = await casbin_dao.delete_policies_by_sub(db, sub) return count @staticmethod async def get_group_list() -> list: + """获取 G 策略列表""" enforcer = await casbin_enforcer() data = enforcer.get_grouping_policy() return data @staticmethod - async def create_group(*, g: CreateUserRoleParam) -> bool: + async def create_group(*, g: CreateGroupParam) -> bool: + """ + 创建 G 策略 + + :param g: 创建 G 策略参数 + :return: + """ enforcer = await casbin_enforcer() data = await enforcer.add_grouping_policy(g.uuid, g.role) if not data: @@ -108,7 +174,13 @@ async def create_group(*, g: CreateUserRoleParam) -> bool: return data @staticmethod - async def create_groups(*, gs: list[CreateUserRoleParam]) -> bool: + async def create_groups(*, gs: list[CreateGroupParam]) -> bool: + """ + 批量创建 G 策略 + + :param gs: 创建参数列表 + :return: + """ enforcer = await casbin_enforcer() data = await enforcer.add_grouping_policies([list(g.model_dump().values()) for g in gs]) if not data: @@ -116,7 +188,13 @@ async def create_groups(*, gs: list[CreateUserRoleParam]) -> bool: return data @staticmethod - async def delete_group(*, g: DeleteUserRoleParam) -> bool: + async def delete_group(*, g: DeleteGroupParam) -> bool: + """ + 删除 G 策略 + + :param g: 删除参数 + :return: + """ enforcer = await casbin_enforcer() _g = enforcer.has_grouping_policy(g.uuid, g.role) if not _g: @@ -125,7 +203,13 @@ async def delete_group(*, g: DeleteUserRoleParam) -> bool: return data @staticmethod - async def delete_groups(*, gs: list[DeleteUserRoleParam]) -> bool: + async def delete_groups(*, gs: list[DeleteGroupParam]) -> bool: + """ + 批量删除 G 策略 + + :param gs: 删除参数列表 + :return: 是否成功 + """ enforcer = await casbin_enforcer() data = await enforcer.remove_grouping_policies([list(g.model_dump().values()) for g in gs]) if not data: @@ -134,6 +218,12 @@ async def delete_groups(*, gs: list[DeleteUserRoleParam]) -> bool: @staticmethod async def delete_all_groups(*, uuid: UUID) -> int: + """ + 删除所有 G 策略 + + :param uuid: 用户uuid + :return: 删除数量 + """ async with async_db_session.begin() as db: count = await casbin_dao.delete_groups_by_uuid(db, uuid) return count diff --git a/backend/plugin/casbin/utils/rbac.py b/backend/plugin/casbin/utils/rbac.py index 2932d7dac..1eb81f787 100644 --- a/backend/plugin/casbin/utils/rbac.py +++ b/backend/plugin/casbin/utils/rbac.py @@ -12,11 +12,7 @@ async def casbin_enforcer() -> casbin.AsyncEnforcer: - """ - 获取 casbin 执行器 - - :return: - """ + """获取 casbin 执行器""" # 模型定义:https://casbin.org/zh/docs/category/model _CASBIN_RBAC_MODEL_CONF_TEXT = """ [request_definition] @@ -45,7 +41,7 @@ async def casbin_verify(request: Request) -> None: """ Casbin 权限校验 - :param request: + :param request: FastAPI 请求对象 :return: """ method = request.method diff --git a/backend/plugin/notice/api/v1/sys/notice.py b/backend/plugin/notice/api/v1/sys/notice.py index e4ecbcb0e..2832130d5 100644 --- a/backend/plugin/notice/api/v1/sys/notice.py +++ b/backend/plugin/notice/api/v1/sys/notice.py @@ -17,20 +17,20 @@ @router.get('/{pk}', summary='获取通知公告详情', dependencies=[DependsJwtAuth]) -async def get_notice(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetNoticeDetail]: +async def get_notice(pk: Annotated[int, Path(description='通知公告 ID')]) -> ResponseSchemaModel[GetNoticeDetail]: notice = await notice_service.get(pk=pk) return response_base.success(data=notice) @router.get( '', - summary='(模糊条件)分页获取所有通知公告', + summary='分页获取所有通知公告', dependencies=[ DependsJwtAuth, DependsPagination, ], ) -async def get_pagination_notice(db: CurrentSession) -> ResponseSchemaModel[PageData[GetNoticeDetail]]: +async def get_pagination_notices(db: CurrentSession) -> ResponseSchemaModel[PageData[GetNoticeDetail]]: notice_select = await notice_service.get_select() page_data = await paging_data(db, notice_select) return response_base.success(data=page_data) @@ -57,7 +57,7 @@ async def create_notice(obj: CreateNoticeParam) -> ResponseModel: DependsRBAC, ], ) -async def update_notice(pk: Annotated[int, Path(...)], obj: UpdateNoticeParam) -> ResponseModel: +async def update_notice(pk: Annotated[int, Path(description='通知公告 ID')], obj: UpdateNoticeParam) -> ResponseModel: count = await notice_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -66,13 +66,13 @@ async def update_notice(pk: Annotated[int, Path(...)], obj: UpdateNoticeParam) - @router.delete( '', - summary='(批量)删除通知公告', + summary='批量删除通知公告', dependencies=[ Depends(RequestPermission('sys:notice:del')), DependsRBAC, ], ) -async def delete_notice(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_notice(pk: Annotated[list[int], Query(description='通知公告 ID 列表')]) -> ResponseModel: count = await notice_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/plugin/notice/crud/crud_notice.py b/backend/plugin/notice/crud/crud_notice.py index a2835da69..756ec9aa7 100644 --- a/backend/plugin/notice/crud/crud_notice.py +++ b/backend/plugin/notice/crud/crud_notice.py @@ -11,60 +11,58 @@ class CRUDNotice(CRUDPlus[Notice]): + """通知公告数据库操作类""" + async def get(self, db: AsyncSession, pk: int) -> Notice | None: """ - 获取系统通知公告 + 获取通知公告 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 通知公告 ID :return: """ return await self.select_model(db, pk) async def get_list(self) -> Select: - """ - 获取系统通知公告列表 - - :return: - """ + """获取通知公告列表""" return await self.select_order('created_time', 'desc') async def get_all(self, db: AsyncSession) -> Sequence[Notice]: """ - 获取所有系统通知公告 + 获取所有通知公告 - :param db: + :param db: 数据库会话 :return: """ return await self.select_models(db) - async def create(self, db: AsyncSession, obj_in: CreateNoticeParam) -> None: + async def create(self, db: AsyncSession, obj: CreateNoticeParam) -> None: """ - 创建系统通知公告 + 创建通知公告 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建通知公告参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateNoticeParam) -> int: + async def update(self, db: AsyncSession, pk: int, obj: UpdateNoticeParam) -> int: """ - 更新系统通知公告 + 更新通知公告 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: 通知公告 ID + :param obj: 更新通知公告参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ - 删除系统通知公告 + 删除通知公告 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 通知公告 ID 列表 :return: """ return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) diff --git a/backend/plugin/notice/model/notice.py b/backend/plugin/notice/model/notice.py index 2f5801832..35519cb02 100644 --- a/backend/plugin/notice/model/notice.py +++ b/backend/plugin/notice/model/notice.py @@ -8,7 +8,7 @@ class Notice(Base): - """系统通知公告""" + """系统通知公告表""" __tablename__ = 'sys_notice' diff --git a/backend/plugin/notice/schema/notice.py b/backend/plugin/notice/schema/notice.py index a5b3161a0..eb68192a3 100644 --- a/backend/plugin/notice/schema/notice.py +++ b/backend/plugin/notice/schema/notice.py @@ -9,25 +9,29 @@ class NoticeSchemaBase(SchemaBase): - title: str - type: int - author: str - source: str - status: StatusType = Field(default=StatusType.enable) - content: str + """通知公告基础模型""" + + title: str = Field(description='标题') + type: int = Field(description='类型(0:通知、1:公告)') + author: str = Field(description='作者') + source: str = Field(description='信息来源') + status: StatusType = Field(StatusType.enable, description='状态(0:隐藏、1:显示)') + content: str = Field(description='内容') class CreateNoticeParam(NoticeSchemaBase): - pass + """创建通知公告参数""" class UpdateNoticeParam(NoticeSchemaBase): - pass + """更新通知公告参数""" class GetNoticeDetail(NoticeSchemaBase): + """通知公告详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='通知公告 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(None, description='更新时间') diff --git a/backend/plugin/notice/service/notice_service.py b/backend/plugin/notice/service/notice_service.py index 5e8e8dc24..55f4c1181 100644 --- a/backend/plugin/notice/service/notice_service.py +++ b/backend/plugin/notice/service/notice_service.py @@ -12,8 +12,16 @@ class NoticeService: + """通知公告服务类""" + @staticmethod async def get(*, pk: int) -> Notice: + """ + 获取通知公告 + + :param pk: 通知公告 ID + :return: + """ async with async_db_session() as db: notice = await notice_dao.get(db, pk) if not notice: @@ -22,21 +30,36 @@ async def get(*, pk: int) -> Notice: @staticmethod async def get_select() -> Select: + """获取通知公告查询对象""" return await notice_dao.get_list() @staticmethod async def get_all() -> Sequence[Notice]: + """获取所有通知公告""" async with async_db_session() as db: notices = await notice_dao.get_all(db) return notices @staticmethod async def create(*, obj: CreateNoticeParam) -> None: + """ + 创建通知公告 + + :param obj: 创建通知公告参数 + :return: + """ async with async_db_session.begin() as db: await notice_dao.create(db, obj) @staticmethod async def update(*, pk: int, obj: UpdateNoticeParam) -> int: + """ + 更新通知公告 + + :param pk: 通知公告 ID + :param obj: 更新通知公告参数 + :return: + """ async with async_db_session.begin() as db: notice = await notice_dao.get(db, pk) if not notice: @@ -46,6 +69,12 @@ async def update(*, pk: int, obj: UpdateNoticeParam) -> int: @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除通知公告 + + :param pk: 通知公告 ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await notice_dao.delete(db, pk) return count diff --git a/backend/plugin/tools.py b/backend/plugin/tools.py index 842a321b1..2a71ccf74 100644 --- a/backend/plugin/tools.py +++ b/backend/plugin/tools.py @@ -6,6 +6,8 @@ import sys import warnings +from typing import Any + import rtoml from fastapi import APIRouter @@ -17,156 +19,196 @@ class PluginInjectError(Exception): - pass + """插件注入错误""" def get_plugins() -> list[str]: - """获取插件""" + """获取插件列表""" plugin_packages = [] + # 遍历插件目录 for item in os.listdir(PLUGIN_DIR): item_path = os.path.join(PLUGIN_DIR, item) - if os.path.isdir(item_path): - if '__init__.py' in os.listdir(item_path): - plugin_packages.append(item) + # 检查是否为目录且包含 __init__.py 文件 + if os.path.isdir(item_path) and '__init__.py' in os.listdir(item_path): + plugin_packages.append(item) return plugin_packages -def get_plugin_models() -> list: +def get_plugin_models() -> list[type]: """获取插件所有模型类""" classes = [] + + # 获取所有插件 plugins = get_plugins() + for plugin in plugins: + # 导入插件的模型模块 module_path = f'backend.plugin.{plugin}.model' module = import_module_cached(module_path) + + # 获取模块中的所有类 for name, obj in inspect.getmembers(module): if inspect.isclass(obj): classes.append(obj) + return classes -def plugin_router_inject() -> None: +def load_plugin_config(plugin: str) -> dict[str, Any]: """ - 插件路由注入 + 加载插件配置 + :param plugin: 插件名称 :return: """ - plugins = get_plugins() - for plugin in plugins: - toml_path = os.path.join(PLUGIN_DIR, plugin, 'plugin.toml') - if not os.path.exists(toml_path): - raise PluginInjectError(f'插件 {plugin} 缺少 plugin.toml 配置文件,请检查插件是否合法') - - # 获取 plugin.toml 配置 - with open(toml_path, 'r', encoding='utf-8') as f: - data = rtoml.load(f) - api = data.get('api', {}) - - # 非独立 app - if api: - app_include = data.get('app', {}).get('include', '') - if not app_include: - raise PluginInjectError(f'非独立 app 插件 {plugin} 配置文件存在错误,请检查') - - # 插件中 API 路由文件的路径 - plugin_api_path = os.path.join(PLUGIN_DIR, plugin, 'api') - if not os.path.exists(plugin_api_path): - raise PluginInjectError(f'插件 {plugin} 缺少 api 目录,请检查插件文件是否完整') - - # 将插件路由注入到对应模块的路由中 - for root, _, api_files in os.walk(plugin_api_path): - for file in api_files: - if file.endswith('.py') and file != '__init__.py': - # 解析插件路由配置 - prefix = data.get('api', {}).get(f'{file[:-3]}', {}).get('prefix', '') - tags = data.get('api', {}).get(f'{file[:-3]}', {}).get('tags', []) - - # 获取插件路由模块 - file_path = os.path.join(root, file) - path_to_module_str = os.path.relpath(file_path, PLUGIN_DIR).replace(os.sep, '.')[:-3] - module_path = f'backend.plugin.{path_to_module_str}' - try: - module = import_module_cached(module_path) - except PluginInjectError as e: - raise PluginInjectError(f'导入非独立 app 插件 {plugin} 模块 {module_path} 失败:{e}') from e - plugin_router = getattr(module, 'router', None) - if not plugin_router: - warnings.warn( - f'非独立 app 插件 {plugin} 模块 {module_path} 中没有有效的 router,' - '请检查插件文件是否完整', - FutureWarning, - ) - continue - - # 获取源程序路由模块 - relative_path = os.path.relpath(root, plugin_api_path) - target_module_path = f'backend.app.{app_include}.api.{relative_path.replace(os.sep, ".")}' - try: - target_module = import_module_cached(target_module_path) - except PluginInjectError as e: - raise PluginInjectError(f'导入源程序模块 {target_module_path} 失败:{e}') from e - target_router = getattr(target_module, 'router', None) - if not target_router or not isinstance(target_router, APIRouter): - raise PluginInjectError( - f'非独立 app 插件 {plugin} 模块 {module_path} 中没有有效的 router,' - '请检查插件文件是否完整' - ) - - # 将插件路由注入到目标 router 中 - target_router.include_router( - router=plugin_router, - prefix=prefix, - tags=[tags] if tags else [], - ) - # 独立 app - else: - # 将插件中的路由直接注入到总路由中 - module_path = f'backend.plugin.{plugin}.api.router' + toml_path = os.path.join(PLUGIN_DIR, plugin, 'plugin.toml') + if not os.path.exists(toml_path): + raise PluginInjectError(f'插件 {plugin} 缺少 plugin.toml 配置文件,请检查插件是否合法') + + with open(toml_path, 'r', encoding='utf-8') as f: + return rtoml.load(f) + + +def inject_extra_router(plugin: str, data: dict[str, Any]) -> None: + """ + 扩展级插件路由注入 + + :param plugin: 插件名称 + :param data: 插件配置数据 + :return: + """ + app_include = data.get('app', {}).get('include', '') + if not app_include: + raise PluginInjectError(f'扩展级插件 {plugin} 配置文件存在错误,请检查') + + plugin_api_path = os.path.join(PLUGIN_DIR, plugin, 'api') + if not os.path.exists(plugin_api_path): + raise PluginInjectError(f'插件 {plugin} 缺少 api 目录,请检查插件文件是否完整') + + for root, _, api_files in os.walk(plugin_api_path): + for file in api_files: + if not (file.endswith('.py') and file != '__init__.py'): + continue + + # 解析插件路由配置 + file_config = data.get('api', {}).get(f'{file[:-3]}', {}) + prefix = file_config.get('prefix', '') + tags = file_config.get('tags', []) + + # 获取插件路由模块 + file_path = os.path.join(root, file) + path_to_module_str = os.path.relpath(file_path, PLUGIN_DIR).replace(os.sep, '.')[:-3] + module_path = f'backend.plugin.{path_to_module_str}' + try: module = import_module_cached(module_path) - except PluginInjectError as e: - raise PluginInjectError(f'导入独立 app 插件 {plugin} 模块 {module_path} 失败:{e}') from e - routers = data.get('app', {}).get('router', []) - if not routers or not isinstance(routers, list): - raise PluginInjectError(f'独立 app 插件 {plugin} 配置文件存在错误,请检查') - for router in routers: - plugin_router = getattr(module, router, None) - if not plugin_router or not isinstance(plugin_router, APIRouter): - raise PluginInjectError( - f'独立 app 插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整' + plugin_router = getattr(module, 'router', None) + if not plugin_router: + warnings.warn( + f'扩展级插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整', + FutureWarning, ) - target_module_path = 'backend.app.router' + continue + + # 获取目标 app 路由 + relative_path = os.path.relpath(root, plugin_api_path) + target_module_path = f'backend.app.{app_include}.api.{relative_path.replace(os.sep, ".")}' target_module = import_module_cached(target_module_path) - target_router = getattr(target_module, 'router') + target_router = getattr(target_module, 'router', None) + + if not target_router or not isinstance(target_router, APIRouter): + raise PluginInjectError( + f'扩展级插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整' + ) - # 将插件路由注入到目标 router 中 - target_router.include_router(plugin_router) + # 将插件路由注入到目标路由中 + target_router.include_router( + router=plugin_router, + prefix=prefix, + tags=[tags] if tags else [], + ) + except Exception as e: + raise PluginInjectError(f'扩展级插件 {plugin} 路由注入失败:{str(e)}') from e + + +def inject_app_router(plugin: str, data: dict[str, Any]) -> None: + """ + 应用级插件路由注入 + + :param plugin: 插件名称 + :param data: 插件配置数据 + :return: + """ + module_path = f'backend.plugin.{plugin}.api.router' + try: + module = import_module_cached(module_path) + routers = data.get('app', {}).get('router', []) + if not routers or not isinstance(routers, list): + raise PluginInjectError(f'应用级插件 {plugin} 配置文件存在错误,请检查') + + # 获取目标路由 + target_module = import_module_cached('backend.app.router') + target_router = getattr(target_module, 'router') + + for router in routers: + plugin_router = getattr(module, router, None) + if not plugin_router or not isinstance(plugin_router, APIRouter): + raise PluginInjectError( + f'应用级插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整' + ) + + # 将插件路由注入到目标路由中 + target_router.include_router(plugin_router) + except Exception as e: + raise PluginInjectError(f'应用级插件 {plugin} 路由注入失败:{str(e)}') from e + + +def plugin_router_inject() -> None: + """插件路由注入""" + for plugin in get_plugins(): + data = load_plugin_config(plugin) + # 基于插件 plugin.toml 配置文件,判断插件类型 + if data.get('api'): + inject_extra_router(plugin, data) + else: + inject_app_router(plugin, data) + + +def _install_plugin_requirements(plugin: str, requirements_file: str) -> None: + """ + 安装单个插件的依赖 + + :param plugin: 插件名称 + :param requirements_file: 依赖文件路径 + :return: + """ + try: + ensurepip_install = [sys.executable, '-m', 'ensurepip', '--upgrade'] + pip_install = [sys.executable, '-m', 'pip', 'install', '-r', requirements_file] + if settings.PLUGIN_PIP_CHINA: + pip_install.extend(['-i', settings.PLUGIN_PIP_INDEX_URL]) + subprocess.check_call(ensurepip_install) + subprocess.check_call(pip_install) + except subprocess.CalledProcessError as e: + raise PluginInjectError(f'插件 {plugin} 依赖安装失败:{e.stderr}') from e def install_requirements() -> None: """安装插件依赖""" - plugins = get_plugins() - for plugin in plugins: + for plugin in get_plugins(): requirements_file = os.path.join(PLUGIN_DIR, plugin, 'requirements.txt') - if not os.path.exists(requirements_file): - continue - else: - try: - ensurepip_install = [sys.executable, '-m', 'ensurepip', '--upgrade'] - pip_install = [sys.executable, '-m', 'pip', 'install', '-r', requirements_file] - if settings.PLUGIN_PIP_CHINA: - pip_install.extend(['-i', settings.PLUGIN_PIP_INDEX_URL]) - subprocess.check_call(ensurepip_install) - subprocess.check_call(pip_install) - except subprocess.CalledProcessError as e: - raise PluginInjectError(f'插件 {plugin} 依赖安装失败:{e.stderr}') from e + if os.path.exists(requirements_file): + _install_plugin_requirements(plugin, requirements_file) async def install_requirements_async() -> None: """ - 异步安装插件依赖(由于 Windows 平台限制,无法实现完美的全异步方案),详情: + 异步安装插件依赖 + + 由于 Windows 平台限制,无法实现完美的全异步方案,详情: https://stackoverflow.com/questions/44633458/why-am-i-getting-notimplementederror-with-async-and-await-on-windows """ await run_in_threadpool(install_requirements) diff --git a/backend/templates/py/api.jinja b/backend/templates/py/api.jinja index 9c5798096..226f5208a 100644 --- a/backend/templates/py/api.jinja +++ b/backend/templates/py/api.jinja @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- from typing import Annotated +from fastapi import APIRouter, Depends, Path, Query + from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Get{{ schema_name }}Detail, Update{{ schema_name }}Param from backend.app.{{ app_name }}.service.{{ table_name_en }}_service import {{ table_name_en }}_service from backend.common.pagination import DependsPagination, PageData, paging_data @@ -10,26 +12,25 @@ from backend.common.security.jwt import DependsJwtAuth from backend.common.security.permission import RequestPermission from backend.common.security.rbac import DependsRBAC from backend.database.db import CurrentSession -from fastapi import APIRouter, Depends, Path, Query router = APIRouter() @router.get('/{pk}', summary='获取{{ table_simple_name_zh }}详情', dependencies=[DependsJwtAuth]) -async def get_{{ table_name_en }}(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[Get{{ schema_name }}Detail]: +async def get_{{ table_name_en }}(pk: Annotated[int, Path(description='{{ table_simple_name_zh }} ID')]) -> ResponseSchemaModel[Get{{ schema_name }}Detail]: {{ table_name_en }} = await {{ table_name_en }}_service.get(pk=pk) return response_base.success(data={{ table_name_en }}) @router.get( '', - summary='(模糊条件)分页获取所有{{ table_simple_name_zh }}', + summary='分页获取所有{{ table_simple_name_zh }}', dependencies=[ DependsJwtAuth, DependsPagination, ], ) -async def get_pagination_{{ table_name_en }}(db: CurrentSession) -> ResponseSchemaModel[PageData[Get{{ schema_name }}Detail]]: +async def get_pagination_{{ table_name_en }}s(db: CurrentSession) -> ResponseSchemaModel[PageData[Get{{ schema_name }}Detail]]: {{ table_name_en }}_select = await {{ table_name_en }}_service.get_select() page_data = await paging_data(db, {{ table_name_en }}_select) return response_base.success(data=page_data) @@ -56,7 +57,7 @@ async def create_{{ table_name_en }}(obj: Create{{ schema_name }}Param) -> Respo DependsRBAC, ], ) -async def update_{{ table_name_en }}(pk: Annotated[int, Path(...)], obj: Update{{ schema_name }}Param) -> ResponseModel: +async def update_{{ table_name_en }}(pk: Annotated[int, Path(description='{{ table_simple_name_zh }} ID')], obj: Update{{ schema_name }}Param) -> ResponseModel: count = await {{ table_name_en }}_service.update(pk=pk, obj=obj) if count > 0: return response_base.success() @@ -65,13 +66,13 @@ async def update_{{ table_name_en }}(pk: Annotated[int, Path(...)], obj: Update{ @router.delete( '', - summary='(批量)删除{{ table_simple_name_zh }}', + summary='批量删除{{ table_simple_name_zh }}', dependencies=[ Depends(RequestPermission('{{ permission }}:del')), DependsRBAC, ], ) -async def delete_{{ table_name_en }}(pk: Annotated[list[int], Query(...)]) -> ResponseModel: +async def delete_{{ table_name_en }}(pk: Annotated[list[int], Query(description='{{ table_simple_name_zh }} ID 列表')]) -> ResponseModel: count = await {{ table_name_en }}_service.delete(pk=pk) if count > 0: return response_base.success() diff --git a/backend/templates/py/crud.jinja b/backend/templates/py/crud.jinja index 73d32f103..bf32c74a8 100644 --- a/backend/templates/py/crud.jinja +++ b/backend/templates/py/crud.jinja @@ -2,71 +2,68 @@ # -*- coding: utf-8 -*- from typing import Sequence -from backend.app.{{ app_name }}.model import {{ table_name_class }} -from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param from sqlalchemy import Select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus import CRUDPlus +from backend.app.{{ app_name }}.model import {{ table_name_class }} +from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param + class CRUD{{ table_name_class }}(CRUDPlus[{{ schema_name }}]): async def get(self, db: AsyncSession, pk: int) -> {{ table_name_class }} | None: """ 获取{{ table_name_zh }} - :param db: - :param pk: + :param db: 数据库会话 + :param pk: {{ table_simple_name_zh }} ID :return: """ return await self.select_model(db, pk) async def get_list(self) -> Select: - """ - 获取{{ table_name_zh }}列表 - - :return: - """ + """获取{{ table_name_zh }}列表""" return await self.select_order('created_time', 'desc') async def get_all(self, db: AsyncSession) -> Sequence[{{ table_name_class }}]: """ 获取所有{{ table_name_zh }} - :param db: + :param db: 数据库会话 :return: """ return await self.select_models(db) - async def create(self, db: AsyncSession, obj_in: Create{{ schema_name }}Param) -> None: + async def create(self, db: AsyncSession, obj: Create{{ schema_name }}Param) -> None: """ 创建{{ table_name_zh }} - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj: 创建{{ table_simple_name_zh }} 参数 :return: """ - await self.create_model(db, obj_in) + await self.create_model(db, obj) - async def update(self, db: AsyncSession, pk: int, obj_in: Update{{ schema_name }}Param) -> int: + async def update(self, db: AsyncSession, pk: int, obj: Update{{ schema_name }}Param) -> int: """ 更新{{ table_name_zh }} - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param pk: {{ table_simple_name_zh }} ID + :param obj: 更新 {{ table_simple_name_zh }} 参数 :return: """ - return await self.update_model(db, pk, obj_in) + return await self.update_model(db, pk, obj) async def delete(self, db: AsyncSession, pk: list[int]) -> int: """ 删除{{ table_name_zh }} - :param db: - :param pk: + :param db: 数据库会话 + :param pk: {{ table_simple_name_zh }} ID :return: """ - return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) + return await self.delete_model_by_column(db, allow_multiple=True, id__in=pk) {{ table_name_en }}_dao: CRUD{{ table_name_class }} = CRUD{{ table_name_class }}({{ table_name_class }}) diff --git a/backend/templates/py/model.jinja b/backend/templates/py/model.jinja index 6d3f7aa26..85f77dd8a 100644 --- a/backend/templates/py/model.jinja +++ b/backend/templates/py/model.jinja @@ -5,7 +5,6 @@ from uuid import UUID import sqlalchemy as sa -from backend.common.model import {% if default_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key {% if database_type == 'mysql' -%} from sqlalchemy.dialects import mysql {% else -%} @@ -13,6 +12,8 @@ from sqlalchemy.dialects import postgresql {% endif -%} from sqlalchemy.orm import Mapped, mapped_column +from backend.common.model import {% if default_datetime_column %}Base{% else %}MappedBase{% endif %}, id_key + class {{ table_name_class }}({% if default_datetime_column %}Base{% else %}MappedBase{% endif %}): """{{ table_name_zh }}""" diff --git a/backend/templates/py/schema.jinja b/backend/templates/py/schema.jinja index 6a1eb20b9..c807e6146 100644 --- a/backend/templates/py/schema.jinja +++ b/backend/templates/py/schema.jinja @@ -2,27 +2,30 @@ # -*- coding: utf-8 -*- from datetime import datetime -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from backend.common.schema import SchemaBase class {{ schema_name }}SchemaBase(SchemaBase): + """{{ table_simple_name_zh }}基础模型""" {% for model in models %} - {{ model.name }}: {% if model.nullable %}{{ model.pd_type }} | None = None{% else %}{{ model.pd_type }}{% endif %} + {{ model.name }}: {% if model.nullable %}{{ model.pd_type }} | None = Field(None, description='{{ model.comment }}'){% else %}{{ model.pd_type }} = Field(description='{{ model.comment }}'){% endif %} {% endfor %} class Create{{ schema_name }}Param({{ schema_name }}SchemaBase): - pass + """创建{{ table_simple_name_zh }}参数""" class Update{{ schema_name }}Param({{ schema_name }}SchemaBase): - pass + """更新{{ table_simple_name_zh }}参数""" class Get{{ schema_name }}Detail({{ schema_name }}SchemaBase): + """{{ table_simple_name_zh }}详情""" + model_config = ConfigDict(from_attributes=True) id: int diff --git a/backend/templates/py/service.jinja b/backend/templates/py/service.jinja index 576b9c683..c8cb39c6f 100644 --- a/backend/templates/py/service.jinja +++ b/backend/templates/py/service.jinja @@ -2,17 +2,24 @@ # -*- coding: utf-8 -*- from typing import Sequence +from sqlalchemy import Select + from backend.app.{{ app_name }}.crud.crud_{{ table_name_en }} import {{ table_name_en }}_dao from backend.app.{{ app_name }}.model import {{ table_name_class }} from backend.app.{{ app_name }}.schema.{{ table_name_en }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param from backend.common.exception import errors from backend.database.db import async_db_session -from sqlalchemy import Select class {{ table_name_class }}Service: @staticmethod async def get(*, pk: int) -> {{ table_name_class }}: + """ + 获取{{ table_simple_name_zh }} + + :param pk: {{ table_simple_name_zh }} ID + :return: + """ async with async_db_session() as db: {{ table_name_en }} = await {{ table_name_en }}_dao.get(db, pk) if not {{ table_name_en }}: @@ -21,27 +28,48 @@ class {{ table_name_class }}Service: @staticmethod async def get_select() -> Select: + """获取{{ table_simple_name_zh }}查询对象""" return await {{ table_name_en }}_dao.get_list() @staticmethod async def get_all() -> Sequence[{{ table_name_class }}]: + """获取所有{{ table_simple_name_zh }}""" async with async_db_session() as db: {{ table_name_en }}s = await {{ table_name_en }}_dao.get_all(db) return {{ table_name_en }}s @staticmethod async def create(*, obj: Create{{ schema_name }}Param) -> None: + """ + 创建{{ table_simple_name_zh }} + + :param obj: 创建{{ table_simple_name_zh }}参数 + :return: + """ async with async_db_session.begin() as db: await {{ table_name_en }}_dao.create(db, obj) @staticmethod async def update(*, pk: int, obj: Update{{ schema_name }}Param) -> int: + """ + 更新{{ table_simple_name_zh }} + + :param pk: {{ table_simple_name_zh }} ID + :param obj: 更新{{ table_simple_name_zh }}参数 + :return: + """ async with async_db_session.begin() as db: count = await {{ table_name_en }}_dao.update(db, pk, obj) return count @staticmethod async def delete(*, pk: list[int]) -> int: + """ + 删除{{ table_simple_name_zh }} + + :param pk: {{ table_simple_name_zh }} ID 列表 + :return: + """ async with async_db_session.begin() as db: count = await {{ table_name_en }}_dao.delete(db, pk) return count diff --git a/backend/utils/build_tree.py b/backend/utils/build_tree.py index 4a948ad10..c227a37c2 100644 --- a/backend/utils/build_tree.py +++ b/backend/utils/build_tree.py @@ -7,7 +7,12 @@ def get_tree_nodes(row: Sequence[RowData]) -> list[dict[str, Any]]: - """获取所有树形结构节点""" + """ + 获取所有树形结构节点 + + :param row: 原始数据行序列 + :return: + """ tree_nodes = select_list_serialize(row) tree_nodes.sort(key=lambda x: x['sort']) return tree_nodes @@ -17,10 +22,10 @@ def traversal_to_tree(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]: """ 通过遍历算法构造树形结构 - :param nodes: + :param nodes: 树节点列表 :return: """ - tree = [] + tree: list[dict[str, Any]] = [] node_dict = {node['id']: node for node in nodes} for node in nodes: @@ -45,16 +50,16 @@ def recursive_to_tree(nodes: list[dict[str, Any]], *, parent_id: int | None = No """ 通过递归算法构造树形结构(性能影响较大) - :param nodes: - :param parent_id: + :param nodes: 树节点列表 + :param parent_id: 父节点 ID,默认为 None 表示根节点 :return: """ - tree = [] + tree: list[dict[str, Any]] = [] for node in nodes: if node['parent_id'] == parent_id: - child_node = recursive_to_tree(nodes, parent_id=node['id']) - if child_node: - node['children'] = child_node + child_nodes = recursive_to_tree(nodes, parent_id=node['id']) + if child_nodes: + node['children'] = child_nodes tree.append(node) return tree @@ -65,9 +70,9 @@ def get_tree_data( """ 获取树形结构数据 - :param row: - :param build_type: - :param parent_id: + :param row: 原始数据行序列 + :param build_type: 构建树形结构的算法类型,默认为遍历算法 + :param parent_id: 父节点 ID,仅在递归算法中使用 :return: """ nodes = get_tree_nodes(row) diff --git a/backend/utils/demo_site.py b/backend/utils/demo_site.py index 4792b1331..039577c03 100644 --- a/backend/utils/demo_site.py +++ b/backend/utils/demo_site.py @@ -6,9 +6,13 @@ from backend.core.conf import settings -async def demo_site(request: Request): - """演示站点""" +async def demo_site(request: Request) -> None: + """ + 演示站点 + :param request: FastAPI 请求对象 + :return: + """ method = request.method path = request.url.path if ( diff --git a/backend/utils/encrypt.py b/backend/utils/encrypt.py index 91f6db7c8..855b026c4 100644 --- a/backend/utils/encrypt.py +++ b/backend/utils/encrypt.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import hashlib import os from typing import Any @@ -13,9 +14,14 @@ class AESCipher: - def __init__(self, key: bytes | str): + """AES 加密器""" + + def __init__(self, key: bytes | str) -> None: """ + 初始化 AES 加密器 + :param key: 密钥,16/24/32 bytes 或 16 进制字符串 + :return: """ self.key = key if isinstance(key, bytes) else bytes.fromhex(key) @@ -40,7 +46,7 @@ def decrypt(self, ciphertext: bytes | str) -> str: """ AES 解密 - :param ciphertext: 解密前的密文, bytes 或 16 进制字符串 + :param ciphertext: 解密前的密文,bytes 或 16 进制字符串 :return: """ ciphertext = ciphertext if isinstance(ciphertext, bytes) else bytes.fromhex(ciphertext) @@ -55,6 +61,8 @@ def decrypt(self, ciphertext: bytes | str) -> str: class Md5Cipher: + """MD5 加密器""" + @staticmethod def encrypt(plaintext: bytes | str) -> str: """ @@ -63,8 +71,6 @@ def encrypt(plaintext: bytes | str) -> str: :param plaintext: 加密前的明文 :return: """ - import hashlib - md5 = hashlib.md5() if not isinstance(plaintext, bytes): plaintext = str(plaintext).encode('utf-8') @@ -73,15 +79,20 @@ def encrypt(plaintext: bytes | str) -> str: class ItsDCipher: - def __init__(self, key: bytes | str): + """ItsDangerous 加密器""" + + def __init__(self, key: bytes | str) -> None: """ + 初始化 ItsDangerous 加密器 + :param key: 密钥,16/24/32 bytes 或 16 进制字符串 + :return: """ self.key = key if isinstance(key, bytes) else bytes.fromhex(key) def encrypt(self, plaintext: Any) -> str: """ - ItsDangerous 加密 (可能失败,如果 plaintext 无法序列化,则会加密为 MD5) + ItsDangerous 加密 :param plaintext: 加密前的明文 :return: @@ -96,7 +107,7 @@ def encrypt(self, plaintext: Any) -> str: def decrypt(self, ciphertext: str) -> Any: """ - ItsDangerous 解密 (可能失败,如果 ciphertext 无法反序列化,则解密失败, 返回原始密文) + ItsDangerous 解密 :param ciphertext: 解密前的密文 :return: diff --git a/backend/utils/file_ops.py b/backend/utils/file_ops.py index 0e88fc57f..04cd5aa81 100644 --- a/backend/utils/file_ops.py +++ b/backend/utils/file_ops.py @@ -14,11 +14,11 @@ from backend.utils.timezone import timezone -def build_filename(file: UploadFile): +def build_filename(file: UploadFile) -> str: """ 构建文件名 - :param file: + :param file: FastAPI 上传文件对象 :return: """ timestamp = int(timezone.now().timestamp()) @@ -32,14 +32,15 @@ def file_verify(file: UploadFile, file_type: FileType) -> None: """ 文件验证 - :param file: - :param file_type: + :param file: FastAPI 上传文件对象 + :param file_type: 文件类型枚举 :return: """ filename = file.filename file_ext = filename.split('.')[-1].lower() if not file_ext: raise errors.ForbiddenError(msg='未知的文件类型') + if file_type == FileType.image: if file_ext not in settings.UPLOAD_IMAGE_EXT_INCLUDE: raise errors.ForbiddenError(msg='此图片格式暂不支持') @@ -52,11 +53,11 @@ def file_verify(file: UploadFile, file_type: FileType) -> None: raise errors.ForbiddenError(msg='视频超出最大限制,请重新选择') -async def upload_file(file: UploadFile): +async def upload_file(file: UploadFile) -> str: """ 上传文件 - :param file: + :param file: FastAPI 上传文件对象 :return: """ filename = build_filename(file) diff --git a/backend/utils/gen_template.py b/backend/utils/gen_template.py index 056a2a0b6..c1c0fb78e 100644 --- a/backend/utils/gen_template.py +++ b/backend/utils/gen_template.py @@ -12,7 +12,8 @@ class GenTemplate: - def __init__(self): + def __init__(self) -> None: + """初始化模板生成器""" self.env = Environment( loader=FileSystemLoader(JINJA2_TEMPLATE_DIR), autoescape=select_autoescape(enabled_extensions=['jinja']), @@ -25,18 +26,17 @@ def __init__(self): def get_template(self, jinja_file: str) -> Template: """ - 获取模版文件 + 获取模板文件 - :param jinja_file: + :param jinja_file: Jinja2 模板文件 :return: """ - return self.env.get_template(jinja_file) @staticmethod def get_template_paths() -> list[str]: """ - 获取模版文件路径 + 获取模板文件路径列表 :return: """ @@ -53,26 +53,25 @@ def get_code_gen_paths(business: GenBusiness) -> list[str]: """ 获取代码生成路径列表 - :param business: + :param business: 代码生成业务对象 :return: """ app_name = business.app_name module_name = business.table_name_en - target_files = [ + return [ f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/api/{business.api_version}/{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/crud/crud_{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/model/{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/schema/{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/service/{module_name}_service.py', ] - return target_files def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str: """ 获取代码生成路径 - :param tpl_path: - :param business: + :param tpl_path: 模板文件路径 + :param business: 代码生成业务对象 :return: """ target_files = self.get_code_gen_paths(business) @@ -80,12 +79,12 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str: return code_gen_path_mapping[tpl_path] @staticmethod - def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict: + def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict[str, str | Sequence[GenModel]]: """ - 获取模版变量 + 获取模板变量 - :param business: - :param models: + :param business: 代码生成业务对象 + :param models: 代码生成模型对象列表 :return: """ return { diff --git a/backend/utils/health_check.py b/backend/utils/health_check.py index 2ca4d7c2d..1a68a1509 100644 --- a/backend/utils/health_check.py +++ b/backend/utils/health_check.py @@ -12,7 +12,7 @@ def ensure_unique_route_names(app: FastAPI) -> None: """ 检查路由名称是否唯一 - :param app: + :param app: FastAPI 应用实例 :return: """ temp_routes = set() @@ -23,13 +23,13 @@ def ensure_unique_route_names(app: FastAPI) -> None: temp_routes.add(route.name) -async def http_limit_callback(request: Request, response: Response, expire: int): +async def http_limit_callback(request: Request, response: Response, expire: int) -> None: """ 请求限制时的默认回调函数 - :param request: - :param response: - :param expire: 剩余毫秒 + :param request: FastAPI 请求对象 + :param response: FastAPI 响应对象 + :param expire: 剩余毫秒数 :return: """ expires = ceil(expire / 1000) diff --git a/backend/utils/import_parse.py b/backend/utils/import_parse.py index 9c50a0377..55a3e26ef 100644 --- a/backend/utils/import_parse.py +++ b/backend/utils/import_parse.py @@ -3,36 +3,36 @@ import importlib from functools import lru_cache -from typing import Any +from typing import Any, Type, TypeVar from backend.common.exception import errors from backend.common.log import log +T = TypeVar('T') + @lru_cache(maxsize=512) def import_module_cached(module_path: str) -> Any: """ 缓存导入模块 - :param module_path: + :param module_path: 模块路径 :return: """ return importlib.import_module(module_path) -def dynamic_import_data_model(module_path: str) -> Any: +def dynamic_import_data_model(module_path: str) -> Type[T]: """ 动态导入数据模型 - :param module_path: + :param module_path: 模块路径,格式为 'module_path.class_name' :return: """ - module_path, class_or_func = module_path.rsplit('.', 1) - try: + module_path, class_name = module_path.rsplit('.', 1) module = import_module_cached(module_path) - ins = getattr(module, class_or_func) + return getattr(module, class_name) except (ImportError, AttributeError) as e: - log.error(e) + log.error(f'动态导入数据模型失败:{e}') raise errors.ServerError(msg='数据模型列动态解析失败,请联系系统超级管理员') - return ins diff --git a/backend/utils/openapi.py b/backend/utils/openapi.py index fc53e82cb..b5db5c130 100644 --- a/backend/utils/openapi.py +++ b/backend/utils/openapi.py @@ -6,9 +6,9 @@ def simplify_operation_ids(app: FastAPI) -> None: """ - 简化操作 ID,以便生成的客户端具有更简单的 api 函数名称 + 简化操作 ID,以便生成的客户端具有更简单的 API 函数名称 - :param app: + :param app: FastAPI 应用实例 :return: """ for route in app.routes: diff --git a/backend/utils/re_verify.py b/backend/utils/re_verify.py index b85f3f9ff..b2d8eb226 100644 --- a/backend/utils/re_verify.py +++ b/backend/utils/re_verify.py @@ -3,41 +3,45 @@ import re -def search_string(pattern, text) -> bool: +def search_string(pattern: str, text: str) -> bool: """ 全字段正则匹配 - :param pattern: - :param text: + :param pattern: 正则表达式模式 + :param text: 待匹配的文本 :return: """ - result = re.search(pattern, text) - if result: - return True - else: + if not pattern or not text: return False + result = re.search(pattern, text) + return result is not None + -def match_string(pattern, text) -> bool: +def match_string(pattern: str, text: str) -> bool: """ 从字段开头正则匹配 - :param pattern: - :param text: + :param pattern: 正则表达式模式 + :param text: 待匹配的文本 :return: """ - result = re.match(pattern, text) - if result: - return True - else: + if not pattern or not text: return False + result = re.match(pattern, text) + return result is not None + def is_phone(text: str) -> bool: """ - 检查手机号码 + 检查手机号码格式 - :param text: + :param text: 待检查的手机号码 :return: """ - return match_string(r'^1[3-9]\d{9}$', text) + if not text: + return False + + phone_pattern = r'^1[3-9]\d{9}$' + return match_string(phone_pattern, text) diff --git a/backend/utils/redis_info.py b/backend/utils/redis_info.py index aab12eb74..da4129226 100644 --- a/backend/utils/redis_info.py +++ b/backend/utils/redis_info.py @@ -6,27 +6,48 @@ class RedisInfo: @staticmethod - async def get_info(): + async def get_info() -> dict[str, str]: + """获取 Redis 服务器信息""" + + # 获取原始信息 info = await redis_client.info() - fmt_info = {} + + # 格式化信息 + fmt_info: dict[str, str] = {} for key, value in info.items(): if isinstance(value, dict): - value = ','.join({f'{k}={v}' for k, v in value.items()}) + # 将字典格式化为字符串 + fmt_info[key] = ','.join(f'{k}={v}' for k, v in value.items()) else: - value = str(value) - fmt_info[key] = value + fmt_info[key] = str(value) + + # 添加数据库大小信息 db_size = await redis_client.dbsize() - fmt_info.update({'keys_num': db_size}) - fmt_uptime = server_info.fmt_seconds(fmt_info.get('uptime_in_seconds', 0)) - fmt_info.update({'uptime_in_seconds': fmt_uptime}) + fmt_info['keys_num'] = str(db_size) + + # 格式化运行时间 + uptime = int(fmt_info.get('uptime_in_seconds', '0')) + fmt_info['uptime_in_seconds'] = server_info.fmt_seconds(uptime) + return fmt_info @staticmethod - async def get_stats(): - stats_list = [] + async def get_stats() -> list[dict[str, str]]: + """获取 Redis 命令统计信息""" + + # 获取命令统计信息 command_stats = await redis_client.info('commandstats') - for k, v in command_stats.items(): - stats_list.append({'name': k.split('_')[-1], 'value': str(v.get('calls', ''))}) + + # 格式化统计信息 + stats_list: list[dict[str, str]] = [] + for key, value in command_stats.items(): + if not isinstance(value, dict): + continue + + command_name = key.split('_')[-1] + call_count = str(value.get('calls', '0')) + stats_list.append({'name': command_name, 'value': call_count}) + return stats_list diff --git a/backend/utils/request_parse.py b/backend/utils/request_parse.py index 2ec142c1b..e0bc382a7 100644 --- a/backend/utils/request_parse.py +++ b/backend/utils/request_parse.py @@ -15,28 +15,32 @@ def get_request_ip(request: Request) -> str: - """获取请求的 ip 地址""" + """ + 获取请求的 IP 地址 + + :param request: FastAPI 请求对象 + :return: + """ real = request.headers.get('X-Real-IP') if real: - ip = real - else: - forwarded = request.headers.get('X-Forwarded-For') - if forwarded: - ip = forwarded.split(',')[0] - else: - ip = request.client.host + return real + + forwarded = request.headers.get('X-Forwarded-For') + if forwarded: + return forwarded.split(',')[0] + # 忽略 pytest - if ip == 'testclient': - ip = '127.0.0.1' - return ip + if request.client.host == 'testclient': + return '127.0.0.1' + return request.client.host async def get_location_online(ip: str, user_agent: str) -> dict | None: """ - 在线获取 ip 地址属地,无法保证可用性,准确率较高 + 在线获取 IP 地址属地,无法保证可用性,准确率较高 - :param ip: - :param user_agent: + :param ip: IP 地址 + :param user_agent: 用户代理字符串 :return: """ async with httpx.AsyncClient(timeout=3) as client: @@ -47,16 +51,16 @@ async def get_location_online(ip: str, user_agent: str) -> dict | None: if response.status_code == 200: return response.json() except Exception as e: - log.error(f'在线获取 ip 地址属地失败,错误信息:{e}') + log.error(f'在线获取 IP 地址属地失败,错误信息:{e}') return None @sync_to_async def get_location_offline(ip: str) -> dict | None: """ - 离线获取 ip 地址属地,无法保证准确率,100%可用 + 离线获取 IP 地址属地,无法保证准确率,100% 可用 - :param ip: + :param ip: IP 地址 :return: """ try: @@ -71,23 +75,30 @@ def get_location_offline(ip: str) -> dict | None: 'city': data[3] if data[3] != '0' else None, } except Exception as e: - log.error(f'离线获取 ip 地址属地失败,错误信息:{e}') + log.error(f'离线获取 IP 地址属地失败,错误信息:{e}') return None async def parse_ip_info(request: Request) -> IpInfo: + """ + 解析请求的 IP 信息 + + :param request: FastAPI 请求对象 + :return: + """ country, region, city = None, None, None ip = get_request_ip(request) location = await redis_client.get(f'{settings.IP_LOCATION_REDIS_PREFIX}:{ip}') if location: country, region, city = location.split('|') return IpInfo(ip=ip, country=country, region=region, city=city) + + location_info = None if settings.IP_LOCATION_PARSE == 'online': location_info = await get_location_online(ip, request.headers.get('User-Agent')) elif settings.IP_LOCATION_PARSE == 'offline': location_info = await get_location_offline(ip) - else: - location_info = None + if location_info: country = location_info.get('country') region = location_info.get('regionName') @@ -101,6 +112,12 @@ async def parse_ip_info(request: Request) -> IpInfo: def parse_user_agent_info(request: Request) -> UserAgentInfo: + """ + 解析请求的用户代理信息 + + :param request: FastAPI 请求对象 + :return: + """ user_agent = request.headers.get('User-Agent') _user_agent = parse(user_agent) os = _user_agent.get_os() diff --git a/backend/utils/serializers.py b/backend/utils/serializers.py index 9d9cd6e27..ad2552e85 100644 --- a/backend/utils/serializers.py +++ b/backend/utils/serializers.py @@ -14,43 +14,38 @@ R = TypeVar('R', bound=RowData) -def select_columns_serialize(row: R) -> dict: +def select_columns_serialize(row: R) -> dict[str, Any]: """ - Serialize SQLAlchemy select table columns, does not contain relational columns + 序列化 SQLAlchemy 查询表的列,不包含关联列 - :param row: + :param row: SQLAlchemy 查询结果行 :return: """ result = {} for column in row.__table__.columns.keys(): - v = getattr(row, column) - if isinstance(v, Decimal): - v = decimal_encoder(v) - result[column] = v + value = getattr(row, column) + if isinstance(value, Decimal): + value = decimal_encoder(value) + result[column] = value return result def select_list_serialize(row: Sequence[R]) -> list[dict[str, Any]]: """ - Serialize SQLAlchemy select list + 序列化 SQLAlchemy 查询列表 - :param row: + :param row: SQLAlchemy 查询结果列表 :return: """ - result = [select_columns_serialize(_) for _ in row] - return result + return [select_columns_serialize(item) for item in row] -def select_as_dict(row: R, use_alias: bool = False) -> dict: +def select_as_dict(row: R, use_alias: bool = False) -> dict[str, Any]: """ - Converting SQLAlchemy select to dict, which can contain relational data, - depends on the properties of the select object itself - - If set use_alias is True, the column name will be returned as alias, - If alias doesn't exist in columns, we don't recommend setting it to True + 将 SQLAlchemy 查询结果转换为字典,可以包含关联数据 - :param row: - :param use_alias: + :param row: SQLAlchemy 查询结果行 + :param use_alias: 是否使用别名作为列名 :return: """ if not use_alias: @@ -70,7 +65,7 @@ def select_as_dict(row: R, use_alias: bool = False) -> dict: class MsgSpecJSONResponse(JSONResponse): """ - JSON response using the high-performance msgspec library to serialize data to JSON. + 使用高性能的 msgspec 库将数据序列化为 JSON 的响应类 """ def render(self, content: Any) -> bytes: diff --git a/backend/utils/server_info.py b/backend/utils/server_info.py index 2b2fd516a..161986d10 100644 --- a/backend/utils/server_info.py +++ b/backend/utils/server_info.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import os import platform import socket @@ -5,7 +7,6 @@ from datetime import datetime, timedelta from datetime import timezone as tz -from typing import List import psutil @@ -14,8 +15,13 @@ class ServerInfo: @staticmethod - def format_bytes(size) -> str: - """格式化字节""" + def format_bytes(size: int | float) -> str: + """ + 格式化字节大小 + + :param size: 字节大小 + :return: + """ factor = 1024 for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: if abs(size) < factor: @@ -25,64 +31,64 @@ def format_bytes(size) -> str: @staticmethod def fmt_seconds(seconds: int) -> str: + """ + 格式化秒数为可读的时间字符串 + + :param seconds: 秒数 + :return: + """ days, rem = divmod(int(seconds), 86400) hours, rem = divmod(rem, 3600) minutes, seconds = divmod(rem, 60) + parts = [] if days: - parts.append('{} 天'.format(days)) + parts.append(f'{days} 天') if hours: - parts.append('{} 小时'.format(hours)) + parts.append(f'{hours} 小时') if minutes: - parts.append('{} 分钟'.format(minutes)) + parts.append(f'{minutes} 分钟') if seconds: - parts.append('{} 秒'.format(seconds)) - if len(parts) == 0: - return '0 秒' - else: - return ' '.join(parts) + parts.append(f'{seconds} 秒') + + return ' '.join(parts) if parts else '0 秒' @staticmethod def fmt_timedelta(td: timedelta) -> str: - """格式化时间差""" + """ + 格式化时间差 + + :param td: 时间差对象 + :return: + """ total_seconds = round(td.total_seconds()) return ServerInfo.fmt_seconds(total_seconds) @staticmethod - def get_cpu_info() -> dict: + def get_cpu_info() -> dict[str, float | int]: """获取 CPU 信息""" cpu_info = {'usage': round(psutil.cpu_percent(percpu=False), 2)} # % - # 检查是否是 Apple M系列芯片 - if platform.system() == 'Darwin' and 'arm' in platform.machine().lower(): - cpu_info['max_freq'] = 0 - cpu_info['min_freq'] = 0 - cpu_info['current_freq'] = 0 - else: - try: - # CPU 频率信息,最大、最小和当前频率 - cpu_freq = psutil.cpu_freq() - cpu_info['max_freq'] = round(cpu_freq.max, 2) # MHz - cpu_info['min_freq'] = round(cpu_freq.min, 2) # MHz - cpu_info['current_freq'] = round(cpu_freq.current, 2) # MHz - except FileNotFoundError: - # 处理无法获取频率的情况 - cpu_info['max_freq'] = 0 - cpu_info['min_freq'] = 0 - cpu_info['current_freq'] = 0 - except AttributeError: - # 处理属性不存在的情况(更安全的做法) - cpu_info['max_freq'] = 0 - cpu_info['min_freq'] = 0 - cpu_info['current_freq'] = 0 + try: + # CPU 频率信息,最大、最小和当前频率 + cpu_freq = psutil.cpu_freq() + cpu_info.update({ + 'max_freq': round(cpu_freq.max, 2), # MHz + 'min_freq': round(cpu_freq.min, 2), # MHz + 'current_freq': round(cpu_freq.current, 2), # MHz + }) + except Exception: + cpu_info.update({'max_freq': 0, 'min_freq': 0, 'current_freq': 0}) # CPU 逻辑核心数,物理核心数 - cpu_info['logical_num'] = psutil.cpu_count(logical=True) - cpu_info['physical_num'] = psutil.cpu_count(logical=False) + cpu_info.update({ + 'logical_num': psutil.cpu_count(logical=True), + 'physical_num': psutil.cpu_count(logical=False), + }) return cpu_info @staticmethod - def get_mem_info() -> dict: + def get_mem_info() -> dict[str, float]: """获取内存信息""" mem = psutil.virtual_memory() return { @@ -93,7 +99,7 @@ def get_mem_info() -> dict: } @staticmethod - def get_sys_info() -> dict: + def get_sys_info() -> dict[str, str]: """获取服务器信息""" try: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sk: @@ -101,6 +107,7 @@ def get_sys_info() -> dict: ip = sk.getsockname()[0] except socket.gaierror: ip = '127.0.0.1' + return { 'name': socket.gethostname(), 'ip': ip, @@ -109,7 +116,7 @@ def get_sys_info() -> dict: } @staticmethod - def get_disk_info() -> List[dict]: + def get_disk_info() -> list[dict[str, str]]: """获取磁盘信息""" disk_info = [] for disk in psutil.disk_partitions(): @@ -126,11 +133,12 @@ def get_disk_info() -> List[dict]: return disk_info @staticmethod - def get_service_info(): + def get_service_info() -> dict[str, str | datetime]: """获取服务信息""" process = psutil.Process(os.getpid()) mem_info = process.memory_info() start_time = timezone.f_datetime(datetime.utcfromtimestamp(process.create_time()).replace(tzinfo=tz.utc)) + return { 'name': 'Python3', 'version': platform.python_version(), @@ -140,7 +148,7 @@ def get_service_info(): 'mem_rss': ServerInfo.format_bytes(mem_info.rss), # 常驻内存, 即当前进程实际使用的物理内存 'mem_free': ServerInfo.format_bytes(mem_info.vms - mem_info.rss), # 空闲内存 'startup': start_time, - 'elapsed': f'{ServerInfo.fmt_timedelta(timezone.now() - start_time)}', + 'elapsed': ServerInfo.fmt_timedelta(timezone.now() - start_time), } diff --git a/backend/utils/timezone.py b/backend/utils/timezone.py index 6836d5bb9..a2d587a82 100644 --- a/backend/utils/timezone.py +++ b/backend/utils/timezone.py @@ -9,32 +9,34 @@ class TimeZone: - def __init__(self, tz: str = settings.DATETIME_TIMEZONE): - self.tz_info = zoneinfo.ZoneInfo(tz) - - def now(self) -> datetime: + def __init__(self, tz: str = settings.DATETIME_TIMEZONE) -> None: """ - 获取时区时间 + 初始化时区转换器 + :param tz: 时区名称,默认为 settings.DATETIME_TIMEZONE :return: """ + self.tz_info = zoneinfo.ZoneInfo(tz) + + def now(self) -> datetime: + """获取当前时区时间""" return datetime.now(self.tz_info) def f_datetime(self, dt: datetime) -> datetime: """ - datetime 时间转时区时间 + 将 datetime 对象转换为当前时区时间 - :param dt: + :param dt: 需要转换的 datetime 对象 :return: """ return dt.astimezone(self.tz_info) def f_str(self, date_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime: """ - 时间字符串转时区时间 + 将时间字符串转换为当前时区的 datetime 对象 - :param date_str: - :param format_str: + :param date_str: 时间字符串 + :param format_str: 时间格式字符串,默认为 settings.DATETIME_FORMAT :return: """ return datetime.strptime(date_str, format_str).replace(tzinfo=self.tz_info) @@ -42,10 +44,10 @@ def f_str(self, date_str: str, format_str: str = settings.DATETIME_FORMAT) -> da @staticmethod def t_str(dt: datetime, format_str: str = settings.DATETIME_FORMAT) -> str: """ - 时间转时间字符串 + 将 datetime 对象转换为指定格式的时间字符串 - :param dt: - :param format_str: + :param dt: datetime 对象 + :param format_str: 时间格式字符串,默认为 settings.DATETIME_FORMAT :return: """ return dt.strftime(format_str) @@ -53,9 +55,9 @@ def t_str(dt: datetime, format_str: str = settings.DATETIME_FORMAT) -> str: @staticmethod def f_utc(dt: datetime) -> datetime: """ - 时区时间转 UTC(GMT)时区 + 将 datetime 对象转换为 UTC (GMT) 时区时间 - :param dt: + :param dt: 需要转换的 datetime 对象 :return: """ return dt.astimezone(datetime_timezone.utc) diff --git a/backend/utils/trace_id.py b/backend/utils/trace_id.py index a2df0451b..4f683b9b6 100644 --- a/backend/utils/trace_id.py +++ b/backend/utils/trace_id.py @@ -6,4 +6,10 @@ def get_request_trace_id(request: Request) -> str: + """ + 从请求头中获取追踪 ID + + :param request: FastAPI 请求对象 + :return: + """ return request.headers.get(settings.TRACE_ID_REQUEST_HEADER_KEY) or settings.LOG_CID_DEFAULT_VALUE diff --git a/backend/utils/type_conversion.py b/backend/utils/type_conversion.py index 96500df9b..1c6de58d4 100644 --- a/backend/utils/type_conversion.py +++ b/backend/utils/type_conversion.py @@ -6,9 +6,9 @@ def sql_type_to_sqlalchemy(typing: str) -> str: """ - Converts a sql type to a SQLAlchemy type. + 将 SQL 类型转换为 SQLAlchemy 类型 - :param typing: + :param typing: SQL 类型字符串 :return: """ if settings.DATABASE_TYPE == 'mysql': @@ -22,17 +22,16 @@ def sql_type_to_sqlalchemy(typing: str) -> str: def sql_type_to_pydantic(typing: str) -> str: """ - Converts a sql type to a pydantic type. + 将 SQL 类型转换为 Pydantic 类型 - :param typing: + :param typing: SQL 类型字符串 :return: """ try: if settings.DATABASE_TYPE == 'mysql': return GenModelMySQLColumnType[typing].value - else: - if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名 - return 'str' - return GenModelPostgreSQLColumnType[typing].value + if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名 + return 'str' + return GenModelPostgreSQLColumnType[typing].value except KeyError: return 'str'