Skip to content

Commit 6cc9bc6

Browse files
committed
Update code generation and task return types
1 parent 3cb9a57 commit 6cc9bc6

File tree

8 files changed

+41
-38
lines changed

8 files changed

+41
-38
lines changed

backend/app/generator/api/v1/gen.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from backend.app.generator.conf import generator_settings
99
from backend.app.generator.schema.gen import ImportParam
1010
from backend.app.generator.service.gen_service import gen_service
11-
from backend.common.response.response_schema import ResponseModel, response_base
11+
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
1212
from backend.common.security.jwt import DependsJwtAuth
1313
from backend.common.security.permission import RequestPermission
1414
from backend.common.security.rbac import DependsRBAC
@@ -17,7 +17,9 @@
1717

1818

1919
@router.get('/tables', summary='获取数据库表')
20-
async def get_all_tables(table_schema: Annotated[str, Query(..., description='数据库名')] = 'fba') -> ResponseModel:
20+
async def get_all_tables(
21+
table_schema: Annotated[str, Query(..., description='数据库名')] = 'fba',
22+
) -> ResponseSchemaModel[list[str]]:
2123
data = await gen_service.get_tables(table_schema=table_schema)
2224
return response_base.success(data=data)
2325

@@ -36,13 +38,13 @@ async def import_table(obj: ImportParam) -> ResponseModel:
3638

3739

3840
@router.get('/preview/{pk}', summary='生成代码预览', dependencies=[DependsJwtAuth])
39-
async def preview_code(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseModel:
41+
async def preview_code(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseSchemaModel[dict[str, bytes]]:
4042
data = await gen_service.preview(pk=pk)
4143
return response_base.success(data=data)
4244

4345

4446
@router.get('/generate/{pk}/path', summary='获取代码生成路径', dependencies=[DependsJwtAuth])
45-
async def generate_path(pk: Annotated[int, Path(..., description='业务ID')]):
47+
async def generate_path(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseSchemaModel[list[str]]:
4648
data = await gen_service.get_generate_path(pk=pk)
4749
return response_base.success(data=data)
4850

backend/app/generator/api/v1/gen_business.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
from fastapi import APIRouter, Depends, Path
66

7+
from backend.app.generator.model import GenBusiness, GenModel
78
from backend.app.generator.schema.gen_business import (
89
CreateGenBusinessParam,
910
GetGenBusinessDetail,
1011
UpdateGenBusinessParam,
1112
)
1213
from backend.app.generator.service.gen_business_service import gen_business_service
1314
from backend.app.generator.service.gen_model_service import gen_model_service
14-
from backend.common.response.response_schema import ResponseModel, response_base
15+
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
1516
from backend.common.security.jwt import DependsJwtAuth
1617
from backend.common.security.permission import RequestPermission
1718
from backend.common.security.rbac import DependsRBAC
@@ -21,21 +22,21 @@
2122

2223

2324
@router.get('/all', summary='获取所有代码生成业务', dependencies=[DependsJwtAuth])
24-
async def get_all_businesses() -> ResponseModel:
25+
async def get_all_businesses() -> ResponseSchemaModel[list[GenBusiness]]:
2526
businesses = await gen_business_service.get_all()
2627
data = select_list_serialize(businesses)
2728
return response_base.success(data=data)
2829

2930

3031
@router.get('/{pk}', summary='获取代码生成业务详情', dependencies=[DependsJwtAuth])
31-
async def get_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
32+
async def get_business(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetGenBusinessDetail]:
3233
business = await gen_business_service.get(pk=pk)
3334
data = GetGenBusinessDetail(**select_as_dict(business))
3435
return response_base.success(data=data)
3536

3637

3738
@router.get('/{pk}/models', summary='获取代码生成业务所有模型', dependencies=[DependsJwtAuth])
38-
async def get_business_all_models(pk: Annotated[int, Path(...)]) -> ResponseModel:
39+
async def get_business_all_models(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[list[GenModel]]:
3940
models = await gen_model_service.get_by_business(business_id=pk)
4041
data = select_list_serialize(models)
4142
return response_base.success(data=data)

backend/app/generator/api/v1/gen_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from backend.app.generator.schema.gen_model import CreateGenModelParam, GetGenModelDetail, UpdateGenModelParam
88
from backend.app.generator.service.gen_model_service import gen_model_service
9-
from backend.common.response.response_schema import ResponseModel, response_base
9+
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
1010
from backend.common.security.jwt import DependsJwtAuth
1111
from backend.common.security.permission import RequestPermission
1212
from backend.common.security.rbac import DependsRBAC
@@ -16,13 +16,13 @@
1616

1717

1818
@router.get('/types', summary='获取代码生成模型列类型', dependencies=[DependsJwtAuth])
19-
async def get_model_types() -> ResponseModel:
19+
async def get_model_types() -> ResponseSchemaModel[list[str]]:
2020
model_types = await gen_model_service.get_types()
2121
return response_base.success(data=model_types)
2222

2323

2424
@router.get('/{pk}', summary='获取代码生成模型详情', dependencies=[DependsJwtAuth])
25-
async def get_model(pk: Annotated[int, Path(...)]) -> ResponseModel:
25+
async def get_model(pk: Annotated[int, Path(...)]) -> ResponseSchemaModel[GetGenModelDetail]:
2626
model = await gen_model_service.get(pk=pk)
2727
data = GetGenModelDetail(**select_as_dict(model))
2828
return response_base.success(data=data)

backend/app/generator/service/gen_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def import_business_and_model(*, obj: ImportParam) -> None:
6969
await gen_model_dao.create(db, CreateGenModelParam(**model_data), pd_type=pd_type)
7070

7171
@staticmethod
72-
async def render_tpl_code(*, business: GenBusiness) -> dict:
72+
async def render_tpl_code(*, business: GenBusiness) -> dict[str, str]:
7373
gen_models = await gen_model_service.get_by_business(business_id=business.id)
7474
if not gen_models:
7575
raise errors.NotFoundError(msg='代码生成模型表为空')
@@ -79,7 +79,7 @@ async def render_tpl_code(*, business: GenBusiness) -> dict:
7979
tpl_code_map[tpl_path] = await gen_template.get_template(tpl_path).render_async(**gen_vars)
8080
return tpl_code_map
8181

82-
async def preview(self, *, pk: int) -> dict:
82+
async def preview(self, *, pk: int) -> dict[str, bytes]:
8383
async with async_db_session() as db:
8484
business = await gen_business_dao.get(db, pk)
8585
if not business:
@@ -91,7 +91,7 @@ async def preview(self, *, pk: int) -> dict:
9191
}
9292

9393
@staticmethod
94-
async def get_generate_path(*, pk: int) -> list:
94+
async def get_generate_path(*, pk: int) -> list[str]:
9595
async with async_db_session() as db:
9696
business = await gen_business_dao.get(db, pk)
9797
if not business:

backend/app/task/api/v1/task.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
from fastapi import APIRouter, Depends, Path
66

7-
from backend.app.task.schema.task import RunParam
7+
from backend.app.task.schema.task import RunParam, TaskResult
88
from backend.app.task.service.task_service import task_service
9-
from backend.common.response.response_schema import ResponseModel, response_base
9+
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
1010
from backend.common.security.jwt import DependsJwtAuth
1111
from backend.common.security.permission import RequestPermission
1212
from backend.common.security.rbac import DependsRBAC
@@ -15,7 +15,7 @@
1515

1616

1717
@router.get('', summary='获取可执行任务', dependencies=[DependsJwtAuth])
18-
async def get_all_tasks() -> ResponseModel:
18+
async def get_all_tasks() -> ResponseSchemaModel[list[str]]:
1919
tasks = await task_service.get_list()
2020
return response_base.success(data=tasks)
2121

@@ -27,7 +27,7 @@ async def get_all_tasks() -> ResponseModel:
2727
description='此接口被视为作废,建议使用 flower 查看任务详情',
2828
dependencies=[DependsJwtAuth],
2929
)
30-
async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel:
30+
async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> ResponseSchemaModel[TaskResult]:
3131
status = task_service.get_detail(tid=tid)
3232
return response_base.success(data=status)
3333

@@ -53,6 +53,6 @@ async def revoke_task(tid: Annotated[str, Path(description='任务ID')]) -> Resp
5353
DependsRBAC,
5454
],
5555
)
56-
async def run_task(obj: RunParam) -> ResponseModel:
56+
async def run_task(obj: RunParam) -> ResponseSchemaModel[str]:
5757
task = task_service.run(obj=obj)
5858
return response_base.success(data=task)

backend/app/task/schema/task.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,15 @@ class RunParam(SchemaBase):
99
name: str = Field(description='任务名称')
1010
args: list | None = Field(default=None, description='任务函数位置参数')
1111
kwargs: dict | None = Field(default=None, description='任务函数关键字参数')
12+
13+
14+
class TaskResult(SchemaBase):
15+
result: str
16+
traceback: str
17+
status: str
18+
name: str
19+
args: list | None
20+
kwargs: dict | None
21+
worker: str
22+
retries: int | None
23+
queue: str | None

backend/app/task/service/task_service.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
43
from celery.exceptions import NotRegistered
54
from celery.result import AsyncResult
65
from starlette.concurrency import run_in_threadpool
76

87
from backend.app.task.celery import celery_app
9-
from backend.app.task.schema.task import RunParam
10-
from backend.common.dataclasses import TaskResult
8+
from backend.app.task.schema.task import RunParam, TaskResult
9+
from backend.common.exception import errors
1110
from backend.common.exception.errors import NotFoundError
1211

1312

1413
class TaskService:
1514
@staticmethod
16-
async def get_list():
15+
async def get_list() -> list[str]:
1716
registered_tasks = await run_in_threadpool(celery_app.control.inspect().registered)
17+
if not registered_tasks:
18+
raise errors.ForbiddenError(msg='celery 服务未启动')
1819
tasks = list(registered_tasks.values())[0]
1920
return tasks
2021

2122
@staticmethod
22-
def get_detail(*, tid: str):
23+
def get_detail(*, tid: str) -> TaskResult:
2324
try:
2425
result = AsyncResult(id=tid, app=celery_app)
2526
except NotRegistered:
@@ -45,7 +46,7 @@ def revoke(*, tid: str):
4546
result.revoke(terminate=True)
4647

4748
@staticmethod
48-
def run(*, obj: RunParam):
49+
def run(*, obj: RunParam) -> str:
4950
task: AsyncResult = celery_app.send_task(name=obj.name, args=obj.args, kwargs=obj.kwargs)
5051
return task.task_id
5152

backend/common/dataclasses.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,3 @@ class AccessToken:
5252
class RefreshToken:
5353
refresh_token: str
5454
refresh_token_expire_time: datetime
55-
56-
57-
@dataclasses.dataclass
58-
class TaskResult:
59-
result: str
60-
traceback: str
61-
status: str
62-
name: str
63-
args: list | None
64-
kwargs: dict | None
65-
worker: str
66-
retries: int | None
67-
queue: str | None

0 commit comments

Comments
 (0)