Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions backend/app/admin/api/v1/sys/casbin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DeletePolicyParam,
DeleteUserRoleParam,
GetPolicyListDetails,
UpdatePoliciesParam,
UpdatePolicyParam,
)
from backend.app.admin.service.casbin_service import casbin_service
Expand Down Expand Up @@ -92,8 +93,8 @@ async def create_policies(ps: list[CreatePolicyParam]) -> ResponseModel:
DependsRBAC,
],
)
async def update_policy(old: UpdatePolicyParam, new: UpdatePolicyParam) -> ResponseModel:
data = await casbin_service.update_policy(old=old, new=new)
async def update_policy(obj: UpdatePolicyParam) -> ResponseModel:
data = await casbin_service.update_policy(obj=obj)
return response_base.success(data=data)


Expand All @@ -105,8 +106,8 @@ async def update_policy(old: UpdatePolicyParam, new: UpdatePolicyParam) -> Respo
DependsRBAC,
],
)
async def update_policies(old: list[UpdatePolicyParam], new: list[UpdatePolicyParam]) -> ResponseModel:
data = await casbin_service.update_policies(old=old, new=new)
async def update_policies(obj: UpdatePoliciesParam) -> ResponseModel:
data = await casbin_service.update_policies(obj=obj)
return response_base.success(data=data)


Expand Down
2 changes: 1 addition & 1 deletion backend/app/admin/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class ApiSchemaBase(SchemaBase):
name: str
method: MethodType = Field(default=MethodType.GET, description='请求方法')
path: str = Field(..., description='api路径')
path: str = Field(description='api路径')
remark: str | None = None


Expand Down
24 changes: 15 additions & 9 deletions backend/app/admin/schema/casbin_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@


class CreatePolicyParam(SchemaBase):
sub: str = Field(..., description='用户uuid / 角色ID')
path: str = Field(..., description='api 路径')
sub: str = Field(description='用户uuid / 角色ID')
path: str = Field(description='api 路径')
method: MethodType = Field(default=MethodType.GET, description='请求方法')


class UpdatePolicyParam(CreatePolicyParam):
pass
class UpdatePolicyParam(SchemaBase):
old: CreatePolicyParam
new: CreatePolicyParam


class UpdatePoliciesParam(SchemaBase):
old: list[CreatePolicyParam]
new: list[CreatePolicyParam]


class DeletePolicyParam(CreatePolicyParam):
Expand All @@ -26,8 +32,8 @@ class DeleteAllPoliciesParam(SchemaBase):


class CreateUserRoleParam(SchemaBase):
uuid: str = Field(..., description='用户 uuid')
role: str = Field(..., description='角色')
uuid: str = Field(description='用户 uuid')
role: str = Field(description='角色')


class DeleteUserRoleParam(CreateUserRoleParam):
Expand All @@ -38,9 +44,9 @@ class GetPolicyListDetails(SchemaBase):
model_config = ConfigDict(from_attributes=True)

id: int
ptype: str = Field(..., description='规则类型, p / g')
v0: str = Field(..., description='用户 uuid / 角色')
v1: str = Field(..., description='api 路径 / 角色')
ptype: str = Field(description='规则类型, p / g')
v0: str = Field(description='用户 uuid / 角色')
v1: str = Field(description='api 路径 / 角色')
v2: str | None = None
v3: str | None = None
v4: str | None = None
Expand Down
8 changes: 4 additions & 4 deletions backend/app/admin/schema/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ class AuthLoginParam(AuthSchemaBase):

class RegisterUserParam(AuthSchemaBase):
nickname: str | None = None
email: EmailStr = Field(..., examples=['[email protected]'])
email: EmailStr = Field(examples=['[email protected]'])


class AddUserParam(AuthSchemaBase):
dept_id: int
roles: list[int]
nickname: str | None = None
email: EmailStr = Field(..., examples=['[email protected]'])
email: EmailStr = Field(examples=['[email protected]'])


class UserInfoSchemaBase(SchemaBase):
dept_id: int | None = None
username: str
nickname: str
email: EmailStr = Field(..., examples=['[email protected]'])
email: EmailStr = Field(examples=['[email protected]'])
phone: CustomPhoneNumber | None = None


Expand All @@ -49,7 +49,7 @@ class UpdateUserRoleParam(SchemaBase):


class AvatarParam(SchemaBase):
url: HttpUrl = Field(..., description='头像 http 地址')
url: HttpUrl = Field(description='头像 http 地址')


class GetUserInfoNoRelationDetail(UserInfoSchemaBase):
Expand Down
15 changes: 10 additions & 5 deletions backend/app/admin/service/casbin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DeleteAllPoliciesParam,
DeletePolicyParam,
DeleteUserRoleParam,
UpdatePoliciesParam,
UpdatePolicyParam,
)
from backend.common.exception import errors
Expand Down Expand Up @@ -49,19 +50,23 @@ async def create_policies(*, ps: list[CreatePolicyParam]) -> bool:
return data

@staticmethod
async def update_policy(*, old: UpdatePolicyParam, new: UpdatePolicyParam) -> bool:
async def update_policy(*, obj: UpdatePolicyParam) -> bool:
old_obj = obj.old
new_obj = obj.new
enforcer = await rbac.enforcer()
_p = enforcer.has_policy(old.sub, old.path, old.method)
_p = enforcer.has_policy(old_obj.sub, old_obj.path, old_obj.method)
if not _p:
raise errors.NotFoundError(msg='权限不存在')
data = await enforcer.update_policy([old.sub, old.path, old.method], [new.sub, new.path, new.method])
data = await enforcer.update_policy(
[old_obj.sub, old_obj.path, old_obj.method], [new_obj.sub, new_obj.path, new_obj.method]
)
return data

@staticmethod
async def update_policies(*, old: list[UpdatePolicyParam], new: list[UpdatePolicyParam]) -> bool:
async def update_policies(*, obj: UpdatePoliciesParam) -> bool:
enforcer = await rbac.enforcer()
data = await enforcer.update_policies(
[list(o.model_dump().values()) for o in old], [list(n.model_dump().values()) for n in new]
[list(o.model_dump().values()) for o in obj.old], [list(n.model_dump().values()) for n in obj.new]
)
return data

Expand Down
11 changes: 4 additions & 7 deletions backend/app/generator/api/v1/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Body, Depends, Path, Query
from fastapi import APIRouter, Depends, Path, Query
from fastapi.responses import StreamingResponse

from backend.app.generator.conf import generator_settings
from backend.app.generator.schema.gen import ImportParam
from backend.app.generator.service.gen_service import gen_service
from backend.common.response.response_schema import ResponseModel, response_base
from backend.common.security.jwt import DependsJwtAuth
Expand All @@ -29,12 +30,8 @@ async def get_all_tables(table_schema: Annotated[str, Query(..., description='
DependsRBAC,
],
)
async def import_table(
app: Annotated[str, Body(..., description='应用名称,用于代码生成到指定 app')],
table_name: Annotated[str, Body(..., description='数据库表名')],
table_schema: Annotated[str, Body(..., description='数据库名')] = 'fba',
) -> ResponseModel:
await gen_service.import_business_and_model(app=app, table_schema=table_schema, table_name=table_name)
async def import_table(obj: ImportParam) -> ResponseModel:
await gen_service.import_business_and_model(obj=obj)
return response_base.success()


Expand Down
11 changes: 11 additions & 0 deletions backend/app/generator/schema/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pydantic import Field

from backend.common.schema import SchemaBase


class ImportParam(SchemaBase):
app: str = Field(description='应用名称,用于代码生成到指定 app')
table_name: str = Field(description='数据库表名')
table_schema: str = Field(description='数据库名')
11 changes: 6 additions & 5 deletions backend/app/generator/service/gen_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from backend.app.generator.crud.crud_gen_business import gen_business_dao
from backend.app.generator.crud.crud_gen_model import gen_model_dao
from backend.app.generator.model import GenBusiness
from backend.app.generator.schema.gen import ImportParam
from backend.app.generator.schema.gen_business import CreateGenBusinessParam
from backend.app.generator.schema.gen_model import CreateGenModelParam
from backend.app.generator.service.gen_model_service import gen_model_service
Expand All @@ -32,17 +33,17 @@ async def get_tables(*, table_schema: str) -> Sequence[str]:
return await gen_dao.get_all_tables(db, table_schema)

@staticmethod
async def import_business_and_model(*, app: str, table_schema: str, table_name: str) -> None:
async def import_business_and_model(*, obj: ImportParam) -> None:
async with async_db_session.begin() as db:
table_info = await gen_dao.get_table(db, table_name)
table_info = await gen_dao.get_table(db, obj.table_name)
if not table_info:
raise errors.NotFoundError(msg='数据库表不存在')
business_info = await gen_business_dao.get_by_name(db, table_name)
business_info = await gen_business_dao.get_by_name(db, obj.table_name)
if business_info:
raise errors.ForbiddenError(msg='已存在相同数据库表业务')
table_name = table_info[0]
business_data = {
'app_name': app,
'app_name': obj.app,
'table_name_en': table_name,
'table_name_zh': table_info[1] or ' '.join(table_name.split('_')),
'table_simple_name_zh': table_info[1] or table_name.split('_')[-1],
Expand All @@ -51,7 +52,7 @@ async def import_business_and_model(*, app: str, table_schema: str, table_name:
new_business = GenBusiness(**CreateGenBusinessParam(**business_data).model_dump())
db.add(new_business)
await db.flush()
column_info = await gen_dao.get_all_columns(db, table_schema, table_name)
column_info = await gen_dao.get_all_columns(db, obj.table_schema, table_name)
for column in column_info:
column_type = column[-1].split('(')[0].upper()
pd_type = sql_type_to_pydantic(column_type)
Expand Down
11 changes: 4 additions & 7 deletions backend/app/task/api/v1/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Body, Depends, Path
from fastapi import APIRouter, Depends, Path

from backend.app.task.schema.task import RunParam
from backend.app.task.service.task_service import task_service
from backend.common.response.response_schema import ResponseModel, response_base
from backend.common.security.jwt import DependsJwtAuth
Expand Down Expand Up @@ -45,10 +46,6 @@ async def get_task_result(tid: Annotated[str, Path(description='任务ID')]) ->
DependsRBAC,
],
)
async def run_task(
name: Annotated[str, Path(description='任务名称')],
args: Annotated[list | None, Body(description='任务函数位置参数')] = None,
kwargs: Annotated[dict | None, Body(description='任务函数关键字参数')] = None,
) -> ResponseModel:
task = task_service.run(name=name, args=args, kwargs=kwargs)
async def run_task(obj: RunParam) -> ResponseModel:
task = task_service.run(name=obj.name, args=obj.args, kwargs=obj.kwargs)
return response_base.success(data=task)
2 changes: 2 additions & 0 deletions backend/app/task/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
11 changes: 11 additions & 0 deletions backend/app/task/schema/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pydantic import Field

from backend.common.schema import SchemaBase


class RunParam(SchemaBase):
name: str = Field(description='任务名称')
args: list | None = Field(default=None, description='任务函数位置参数')
kwargs: dict | None = Field(default=None, description='任务函数关键字参数')
4 changes: 3 additions & 1 deletion backend/utils/gen_template.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Sequence

from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
from pydantic.alias_generators import to_pascal, to_snake

Expand Down Expand Up @@ -77,7 +79,7 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
return code_gen_path_mapping[tpl_path]

@staticmethod
def get_vars(business: GenBusiness, models: list[GenModel]) -> dict:
def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict:
"""
获取模版变量

Expand Down