Skip to content

Commit 816580a

Browse files
committed
[refactor] move all registry to a container class to avoid import before dataflow init issue
1 parent 43e57bd commit 816580a

18 files changed

+221
-161
lines changed

backend/app/api/v1/endpoints/datasets.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,51 @@
11
import os
22
from fastapi import APIRouter, HTTPException
33
from app.schemas.dataset import DatasetIn, DatasetOut
4-
from app.services.dataset_registry import DatasetRegistry, _DATASET_REGISTRY
4+
from app.core.container import container
55
from app.api.v1.resp import ok, created
66
from app.api.v1.envelope import ApiResponse
77
from app.api.v1.errors import *
88

99

1010
router = APIRouter(tags=["datasets"])
11-
_registry = _DATASET_REGISTRY
1211

1312
@router.get("/", response_model=ApiResponse[list[DatasetOut]], operation_id="list_datasets", summary="返回目前所有注册的数据集列表,包含每个数据集的条目数和文件大小")
1413
def list_datasets():
1514
"""返回所有数据集列表,每个数据集包含条目数(num_samples)和文件大小(file_size)信息"""
16-
return ok(_registry.list())
15+
return ok(container.dataset_registry.list())
1716

1817
@router.post("/", response_model=ApiResponse[DatasetOut], operation_id="register_dataset", summary="注册一个新的数据集或更新已有数据集的信息,根据路径作为唯一主键")
1918
def register_dataset(payload: DatasetIn):
2019
try:
21-
ds = _registry.add_or_update(payload.model_dump(mode="json")) # to dict
20+
ds = container.dataset_registry.add_or_update(payload.model_dump(mode="json")) # to dict
2221
except Exception as e:
2322
raise HTTPException(400, f"Failed to register dataset: {e}")
2423
return created(ds)
2524

2625
@router.get("/{ds_id}", response_model=ApiResponse[DatasetOut], operation_id="get_dataset", summary="根据数据集 ID 获取数据集信息")
2726
def get_dataset(ds_id: str):
28-
ds = _registry.get(ds_id)
27+
ds = container.dataset_registry.get(ds_id)
2928
if not ds:
3029
raise HTTPException(404, "Dataset not found")
3130
return ok(ds)
3231

3332
@router.delete("/{ds_id}", response_model=ApiResponse[dict], operation_id="delete_dataset", summary="根据数据集 ID 删除数据集")
3433
def delete_dataset(ds_id: str):
35-
ds = _registry.get(ds_id)
34+
ds = container.dataset_registry.get(ds_id)
3635
if not ds:
3736
raise HTTPException(404, "Dataset not found")
38-
_registry.remove(ds_id)
37+
container.dataset_registry.remove(ds_id)
3938
return ok(message="Dataset deleted")
4039

4140

4241
# getting sample data for visualization
43-
from app.services.visualize_dataset import VisualizeDatasetService
44-
_visualize_service = VisualizeDatasetService()
4542
@router.get("/pandas_type_sample/{ds_id}", response_model=ApiResponse[str], operation_id="get_pandas_data", summary="获取指定数据集的 Pandas 类型样本数据,用于前端展示预览,可以通过start和end参数控制获取多少数据")
4643
def get_pandas_data(ds_id: str, start: int = 0, end: int = 5):
4744
try:
48-
ds = _registry.get(ds_id)
45+
ds = container.dataset_registry.get(ds_id)
4946
if not ds:
5047
raise HTTPException(404, "Dataset not found")
51-
return ok(_visualize_service.get_pandas_read_function(ds, start, end))
48+
return ok(container.dataset_visualize_service.get_pandas_read_function(ds, start, end))
5249
except Exception as e:
5350
raise HTTPException(500, f"Failed to get pandas data: {e}")
5451

@@ -57,10 +54,10 @@ def get_pandas_data(ds_id: str, start: int = 0, end: int = 5):
5754
@router.get("/file_type_sample/{ds_id}", operation_id="get_file_type_data", summary="获取指定数据集的文件类型样本数据,用于前端展示下载,可以是图片、文本等")
5855
def get_file_type_data(ds_id: str):
5956
try:
60-
ds = _registry.get(ds_id)
57+
ds = container.dataset_registry.get(ds_id)
6158
if not ds:
6259
raise HTTPException(404, "Dataset not found")
63-
file_path, media_type = _visualize_service.get_other_visualization_data(ds)
60+
file_path, media_type = container.dataset_visualize_service.get_other_visualization_data(ds)
6461
except Exception as e:
6562
raise HTTPException(500, f"Failed to get file type data: {e}")
6663

@@ -82,7 +79,7 @@ def get_dataset_preview(ds_id: str, num_lines: int = 5):
8279
预览内容的列表,每个元素是一个字典
8380
"""
8481
try:
85-
preview_data = _registry.preview(ds_id, num_lines)
82+
preview_data = container.dataset_registry.preview(ds_id, num_lines)
8683
return ok(preview_data)
8784
except FileNotFoundError:
8885
raise HTTPException(404, "Dataset not found")
@@ -100,7 +97,7 @@ def get_dataset_columns(ds_id: str):
10097
列名列表,如果不支持则返回空列表
10198
"""
10299
try:
103-
columns_data = _registry.get_columns(ds_id)
100+
columns_data = container.dataset_registry.get_columns(ds_id)
104101
return ok(columns_data)
105102
except FileNotFoundError:
106103
raise HTTPException(404, "Dataset not found")

backend/app/api/v1/endpoints/operators.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from app.api.v1.envelope import ApiResponse
1616

1717
# --- 2. 导入服务层 ---
18-
from app.services.operator_registry import _op_registry, OPS_JSON_PATH
18+
from app.services.operator_registry import OPS_JSON_PATH
19+
from app.core.container import container
1920

2021
router = APIRouter(tags=["operators"])
2122

@@ -28,7 +29,7 @@
2829
def list_operators():
2930
"""返回所有注册的算子列表(简化版)。"""
3031
try:
31-
op_list = _op_registry.get_op_list()
32+
op_list = container.operator_registry.get_op_list()
3233
return ok(op_list)
3334
except Exception as e:
3435
log.error(f"获取算子列表失败: {e}")
@@ -49,7 +50,7 @@ def list_operators_details():
4950
try:
5051
if not OPS_JSON_PATH.exists():
5152
log.info("ops.json 缓存文件未找到,自动触发一次算子扫描并生成缓存...")
52-
ops_data = _op_registry.dump_ops_to_json()
53+
ops_data = container.operator_registry.dump_ops_to_json()
5354
else:
5455
with open(OPS_JSON_PATH, "r", encoding="utf-8") as f:
5556
ops_data = json.load(f)
@@ -80,7 +81,7 @@ def get_operator_detail_by_name(op_name: str):
8081
# 确保缓存存在
8182
if not OPS_JSON_PATH.exists():
8283
log.info("ops.json 缓存文件未找到,自动触发一次算子扫描并生成缓存...")
83-
ops_data = _op_registry.dump_ops_to_json()
84+
ops_data = container.operator_registry.dump_ops_to_json()
8485
else:
8586
with open(OPS_JSON_PATH, "r", encoding="utf-8") as f:
8687
ops_data = json.load(f)

backend/app/api/v1/endpoints/pipelines.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
PipelineExecutionRequest,
77
PipelineExecutionResult
88
)
9-
from app.services.pipeline_registry import PipelineRegistry, _PIPELINE_REGISTRY
9+
from app.core.container import container
1010
from app.services.dataflow_engine import dataflow_engine
1111
from app.api.v1.resp import ok, created
1212
from app.api.v1.envelope import ApiResponse
@@ -23,7 +23,7 @@
2323
def list_pipelines(request: Request):
2424
try:
2525
logger.info(f"Request: {request.method} {request.url.path}")
26-
pipelines = _PIPELINE_REGISTRY.list_pipelines()
26+
pipelines = container.pipeline_registry.list_pipelines()
2727
logger.info(f"Successfully listed {len(pipelines)} pipelines")
2828
return ok(pipelines)
2929
except Exception as e:
@@ -38,9 +38,9 @@ def create_pipeline(request: Request, payload: PipelineIn):
3838

3939
operators = pipeline_in_data.get("config", {}).get("operators", [])
4040
for op in operators:
41-
op["params"] = _PIPELINE_REGISTRY.parse_frontend_params(op.get("params", []))
41+
op["params"] = container.pipeline_registry.parse_frontend_params(op.get("params", []))
4242

43-
pipeline = _PIPELINE_REGISTRY.create_pipeline(pipeline_in_data)
43+
pipeline = container.pipeline_registry.create_pipeline(pipeline_in_data)
4444
return created(pipeline)
4545
except ValueError as e:
4646
logger.error(f"Invalid pipeline configuration: {str(e)}", exc_info=True)
@@ -51,7 +51,7 @@ def create_pipeline(request: Request, payload: PipelineIn):
5151

5252
@router.get("/{pipeline_id}", response_model=ApiResponse[PipelineOut], operation_id="get_pipeline", summary="根据ID获取Pipeline详情")
5353
def get_pipeline(pipeline_id: str):
54-
pipeline = _PIPELINE_REGISTRY.get_pipeline(pipeline_id)
54+
pipeline = container.pipeline_registry.get_pipeline(pipeline_id)
5555
if not pipeline:
5656
raise HTTPException(404, f"Pipeline with id {pipeline_id} not found")
5757
return ok(pipeline)
@@ -63,9 +63,9 @@ def update_pipeline(pipeline_id: str, payload: PipelineIn):
6363

6464
# operators = pipeline_in_data.get("config", {}).get("operators", [])
6565
# for op in operators:
66-
# op["params"] = _PIPELINE_REGISTRY.parse_frontend_params(op.get("params", []))
66+
# op["params"] = container.pipeline_registry.parse_frontend_params(op.get("params", []))
6767

68-
updated_pipeline = _PIPELINE_REGISTRY.update_pipeline(pipeline_id, pipeline_in_data)
68+
updated_pipeline = container.pipeline_registry.update_pipeline(pipeline_id, pipeline_in_data)
6969
return ok(updated_pipeline)
7070
except ValueError as e:
7171
logger.error(f"Failed to update pipeline: {str(e)}")
@@ -77,7 +77,7 @@ def update_pipeline(pipeline_id: str, payload: PipelineIn):
7777
@router.delete("/{pipeline_id}", response_model=ApiResponse[Dict], operation_id="delete_pipeline", summary="删除指定的Pipeline")
7878
def delete_pipeline(pipeline_id: str):
7979
try:
80-
success = _PIPELINE_REGISTRY.delete_pipeline(pipeline_id)
80+
success = container.pipeline_registry.delete_pipeline(pipeline_id)
8181
if not success:
8282
raise HTTPException(404, f"Pipeline with id {pipeline_id} not found")
8383
return ok(message=f"Pipeline {pipeline_id} deleted successfully")
@@ -93,10 +93,10 @@ async def execute_pipeline(request: Request, pipeline_id):
9393
try:
9494
logger.info(f"Request: {request.method} {request.url.path}")
9595

96-
pipeline_config = _PIPELINE_REGISTRY.get_pipeline(pipeline_id)
96+
pipeline_config = container.pipeline_registry.get_pipeline(pipeline_id)
9797

9898
# 调用服务层开始执行
99-
execution_id, pipeline_config, initial_result = _PIPELINE_REGISTRY.start_execution(
99+
execution_id, pipeline_config, initial_result = container.pipeline_registry.start_execution(
100100
pipeline_id=pipeline_id,
101101
config=pipeline_config
102102
)
@@ -111,15 +111,15 @@ async def execute_pipeline(request: Request, pipeline_id):
111111

112112
@router.get("/execution/{execution_id}", response_model=ApiResponse[PipelineExecutionResult], operation_id="get_execution_result", summary="获取Pipeline执行结果")
113113
def get_execution_result(execution_id: str):
114-
result = _PIPELINE_REGISTRY.get_execution_result(execution_id)
114+
result = container.pipeline_registry.get_execution_result(execution_id)
115115
if not result:
116116
raise HTTPException(404, f"Execution with id {execution_id} not found")
117117
return ok(result)
118118

119119
@router.get("/executions", response_model=ApiResponse[List[PipelineExecutionResult]], operation_id="list_executions", summary="列出所有Pipeline执行记录")
120120
def list_executions():
121121
try:
122-
executions = _PIPELINE_REGISTRY.list_executions()
122+
executions = container.pipeline_registry.list_executions()
123123
return ok(executions)
124124
except Exception as e:
125125
logger.error(f"Failed to list executions: {e}")

backend/app/api/v1/endpoints/prompts.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11

22
from fastapi import APIRouter, HTTPException
33
from app.schemas.prompt import GetPromptSchema, PromptSourceOut, OperatorPromptMapOut, PromptInfoMapOut
4-
from app.services.prompt_registry import _PROMPT_REGISTRY
4+
# from app.services.prompt_registry import _PROMPT_REGISTRY
5+
from app.core.container import container
56
from app.api.v1.resp import ok
67
from app.api.v1.envelope import ApiResponse
78

@@ -13,7 +14,7 @@
1314
summary="查看所有算子及其对应的 Prompt 列表"
1415
)
1516
def get_operator_prompt_mapping():
16-
result = _PROMPT_REGISTRY.list_operator_prompts()
17+
result = container.prompt_registry.list_operator_prompts()
1718
return ok(result)
1819

1920
@router.get(
@@ -22,15 +23,15 @@ def get_operator_prompt_mapping():
2223
summary="查看所有 prompt 的信息(operator, class string, category)"
2324
)
2425
def get_prompt_info():
25-
return ok(_PROMPT_REGISTRY.list_prompt_info())
26+
return ok(container.prompt_registry.list_prompt_info())
2627

2728
@router.get(
2829
"/{operator_name}",
2930
response_model=ApiResponse[GetPromptSchema],
3031
summary="根据算子名称获取对应的 Prompt 列表"
3132
)
3233
def get_prompts(operator_name: str):
33-
result = _PROMPT_REGISTRY.get_prompts(operator_name)
34+
result = container.prompt_registry.get_prompts(operator_name)
3435
if not result:
3536
raise HTTPException(404, "Operator not found")
3637
return ok(result)
@@ -41,7 +42,7 @@ def get_prompts(operator_name: str):
4142
summary="根据 Prompt 名称返回 Prompt 类的源码"
4243
)
4344
def get_prompt_source(prompt_name: str):
44-
result = _PROMPT_REGISTRY.get_prompt_source(prompt_name)
45+
result = container.prompt_registry.get_prompt_source(prompt_name)
4546
if not result:
4647
raise HTTPException(status_code=404, detail="Prompt not found")
4748

backend/app/api/v1/endpoints/serving.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
from typing import List, Dict, Any
44
from fastapi import APIRouter, HTTPException
5-
from app.services.prompt_registry import _PROMPT_REGISTRY
5+
from app.core.container import container
66
from app.api.v1.resp import ok
77
from app.api.v1.envelope import ApiResponse
88

@@ -15,7 +15,7 @@
1515
ServingTestSchema,
1616
ServingUpdateSchema
1717
)
18-
from app.services.serving_registry import _SERVING_REGISTRY, SERVING_CLS_REGISTRY
18+
from app.services.serving_registry import SERVING_CLS_REGISTRY
1919

2020
router = APIRouter(tags=["serving"])
2121

@@ -25,7 +25,7 @@
2525
)
2626
def list_serving_instances():
2727
try:
28-
serving_list = _SERVING_REGISTRY._get_all()
28+
serving_list = container.serving_registry._get_all()
2929
if not serving_list:
3030
result = []
3131
else:
@@ -52,7 +52,7 @@ def list_serving_classes():
5252
返回所有注册的 Serving 类及其初始化参数信息 (名称、类型、默认值)。
5353
"""
5454
try:
55-
classes_info = copy.deepcopy(_SERVING_REGISTRY.get_serving_classes())
55+
classes_info = copy.deepcopy(container.serving_registry.get_serving_classes())
5656
api_llm_info = [x for x in classes_info if x['cls_name'] == 'APILLMServing_request']
5757
for item in api_llm_info:
5858
item['params'] = [p for p in item['params'] if p['name'] != 'key_name_of_api_key']
@@ -78,7 +78,7 @@ def get_serving_detail(id: str):
7878
根据 Serving 实例的 ID,获取其详细信息。
7979
"""
8080
try:
81-
serving_data = _SERVING_REGISTRY._get(id)
81+
serving_data = container.serving_registry._get(id)
8282
print(type(serving_data))
8383
if not serving_data:
8484
raise HTTPException(status_code=404, detail=f"Serving instance with id {id} not found")
@@ -118,7 +118,7 @@ def update_serving_instance(id: str, body: ServingUpdateSchema):
118118

119119
params_list.append(p_dict)
120120

121-
success = _SERVING_REGISTRY._update(
121+
success = container.serving_registry._update(
122122
id,
123123
name=body.name,
124124
params=params_list
@@ -142,7 +142,7 @@ def delete_serving_instance(id: str):
142142
删除指定的 Serving 实例。
143143
"""
144144
try:
145-
success = _SERVING_REGISTRY._delete(id)
145+
success = container.serving_registry._delete(id)
146146
if not success:
147147
raise HTTPException(status_code=404, detail=f"Serving instance with id {id} not found")
148148
return ok({'id': id})
@@ -166,7 +166,7 @@ def create_serving_instance(
166166
try:
167167
# Get class default params info
168168
cls_info = None
169-
all_classes = _SERVING_REGISTRY.get_serving_classes()
169+
all_classes = container.serving_registry.get_serving_classes()
170170
for c in all_classes:
171171
if c['cls_name'] == cls_name:
172172
cls_info = c
@@ -197,7 +197,7 @@ def create_serving_instance(
197197

198198
new_params = list(final_params_map.values())
199199

200-
new_id = _SERVING_REGISTRY._set(name, cls_name, new_params)
200+
new_id = container.serving_registry._set(name, cls_name, new_params)
201201
return ok({
202202
'id': new_id
203203
})
@@ -217,7 +217,7 @@ def test_serving_instance(id: str, body: ServingTestSchema):
217217
"""
218218
try:
219219
prompt: str = body.prompt or "Hello, which model are you?"
220-
serving_info = _SERVING_REGISTRY._get(id)
220+
serving_info = container.serving_registry._get(id)
221221
params_dict = {}
222222

223223
## This part of code is only for APILLMServing_request

0 commit comments

Comments
 (0)