Skip to content

Commit 2ec96d9

Browse files
committed
feat(DataSynthesis): refactor data synthesis models and update task handling logic
1 parent 4aaf0fd commit 2ec96d9

File tree

9 files changed

+300
-183
lines changed

9 files changed

+300
-183
lines changed

runtime/datamate-python/app/db/models/data_synthesis.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,105 @@
11
import uuid
2-
from xml.etree.ElementTree import tostring
32

4-
from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, ForeignKey, func
5-
from sqlalchemy.orm import relationship
3+
from sqlalchemy import Column, String, Text, Integer, JSON, TIMESTAMP, func
64

75
from app.db.session import Base
86
from app.module.generation.schema.generation import CreateSynthesisTaskRequest
97

108

119
async def save_synthesis_task(db_session, synthesis_task: CreateSynthesisTaskRequest):
12-
"""保存数据合成任务。"""
13-
# 转换为模型实例
10+
"""保存数据合成任务。
11+
12+
注意:当前 MySQL 表 `t_data_synth_instances` 结构中只包含 synth_type / synth_config 等字段,
13+
没有 model_id、text_split_config、source_file_id、result_data_location 等列,因此这里只保存
14+
与表结构一致的字段,其他信息由上层逻辑或其它表负责管理。
15+
"""
1416
gid = str(uuid.uuid4())
15-
synthesis_task_instance = DataSynthesisInstance(
17+
18+
# 兼容旧请求结构:从请求对象中提取必要字段,
19+
# - 合成类型:synthesis_type -> synth_type
20+
# - 合成配置:text_split_config + synthesis_config 合并后写入 synth_config
21+
synth_config = {
22+
"text_split_config": synthesis_task.text_split_config.model_dump()
23+
if synthesis_task.text_split_config
24+
else None,
25+
"synthesis_config": synthesis_task.synthesis_config.model_dump()
26+
if synthesis_task.synthesis_config
27+
else None,
28+
"model_id": synthesis_task.model_id,
29+
"source_file_id": list(synthesis_task.source_file_id or []),
30+
}
31+
32+
synth_task_instance = DataSynthInstance(
1633
id=gid,
1734
name=synthesis_task.name,
1835
description=synthesis_task.description,
1936
status="pending",
20-
model_id=synthesis_task.model_id,
21-
synthesis_type=synthesis_task.synthesis_type.value,
37+
synth_type=synthesis_task.synthesis_type.value,
2238
progress=0,
23-
result_data_location=f"/dataset/synthesis_results/{gid}/",
24-
text_split_config=synthesis_task.text_split_config.model_dump(),
25-
synthesis_config=synthesis_task.synthesis_config.model_dump(),
26-
source_file_id=synthesis_task.source_file_id,
27-
total_files=len(synthesis_task.source_file_id),
39+
synth_config=synth_config,
40+
total_files=len(synthesis_task.source_file_id or []),
2841
processed_files=0,
2942
total_chunks=0,
3043
processed_chunks=0,
31-
total_synthesis_data=0,
44+
total_synth_data=0,
3245
created_at=func.now(),
3346
updated_at=func.now(),
3447
created_by="system",
35-
updated_by="system"
48+
updated_by="system",
3649
)
37-
db_session.add(synthesis_task_instance)
50+
db_session.add(synth_task_instance)
3851
await db_session.commit()
39-
await db_session.refresh(synthesis_task_instance)
40-
return synthesis_task_instance
52+
await db_session.refresh(synth_task_instance)
53+
return synth_task_instance
4154

4255

43-
class DataSynthesisInstance(Base):
44-
"""数据合成任务表,对应表 t_data_synthesis_instances
56+
class DataSynthInstance(Base):
57+
"""数据合成任务表,对应表 t_data_synth_instances
4558
46-
create table if not exists t_data_synthesis_instances
59+
create table if not exists t_data_synth_instances
4760
(
4861
id VARCHAR(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci PRIMARY KEY COMMENT 'UUID',
4962
name VARCHAR(255) NOT NULL COMMENT '任务名称',
5063
description TEXT COMMENT '任务描述',
5164
status VARCHAR(20) COMMENT '任务状态',
52-
synthesis_type VARCHAR(20) NOT NULL COMMENT '合成类型',
53-
model_id VARCHAR(255) NOT NULL COMMENT '模型ID',
65+
synth_type VARCHAR(20) NOT NULL COMMENT '合成类型',
5466
progress INT DEFAULT 0 COMMENT '任务进度(百分比)',
55-
result_data_location VARCHAR(1000) COMMENT '结果数据存储位置',
56-
text_split_config JSON NOT NULL COMMENT '文本切片配置',
57-
synthesis_config JSON NOT NULL COMMENT '合成配置',
58-
source_file_id JSON NOT NULL COMMENT '原始文件ID列表',
67+
synth_config JSON NOT NULL COMMENT '合成配置',
5968
total_files INT DEFAULT 0 COMMENT '总文件数',
6069
processed_files INT DEFAULT 0 COMMENT '已处理文件数',
6170
total_chunks INT DEFAULT 0 COMMENT '总文本块数',
6271
processed_chunks INT DEFAULT 0 COMMENT '已处理文本块数',
63-
total_synthesis_data INT DEFAULT 0 COMMENT '总合成数据量',
72+
total_synth_data INT DEFAULT 0 COMMENT '总合成数据量',
6473
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
6574
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
6675
created_by VARCHAR(255) COMMENT '创建者',
6776
updated_by VARCHAR(255) COMMENT '更新者'
6877
) COMMENT='数据合成任务表(UUID 主键)';
6978
"""
7079

71-
__tablename__ = "t_data_synthesis_instances"
80+
__tablename__ = "t_data_synth_instances"
7281

7382
id = Column(String(36), primary_key=True, index=True, comment="UUID")
7483
name = Column(String(255), nullable=False, comment="任务名称")
7584
description = Column(Text, nullable=True, comment="任务描述")
7685
status = Column(String(20), nullable=True, comment="任务状态")
77-
synthesis_type = Column(String(20), nullable=False, comment="合成类型")
78-
model_id = Column(String(255), nullable=False, comment="模型ID")
86+
# 与数据库字段保持一致:synth_type / synth_config
87+
synth_type = Column(String(20), nullable=False, comment="合成类型")
7988
progress = Column(Integer, nullable=False, default=0, comment="任务进度(百分比)")
80-
result_data_location = Column(String(1000), nullable=True, comment="结果数据存储位置")
81-
text_split_config = Column(JSON, nullable=False, comment="文本切片配置")
82-
synthesis_config = Column(JSON, nullable=False, comment="合成配置")
83-
source_file_id = Column(JSON, nullable=False, comment="原始文件ID列表")
89+
synth_config = Column(JSON, nullable=False, comment="合成配置")
8490
total_files = Column(Integer, nullable=False, default=0, comment="总文件数")
8591
processed_files = Column(Integer, nullable=False, default=0, comment="已处理文件数")
8692
total_chunks = Column(Integer, nullable=False, default=0, comment="总文本块数")
8793
processed_chunks = Column(Integer, nullable=False, default=0, comment="已处理文本块数")
88-
total_synthesis_data = Column(Integer, nullable=False, default=0, comment="总合成数据量")
89-
90-
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), nullable=True, comment="创建时间")
91-
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), nullable=True, comment="更新时间")
94+
total_synth_data = Column(Integer, nullable=False, default=0, comment="总合成数据量")
95+
created_at = Column(TIMESTAMP, nullable=False, default=func.now(), comment="创建时间")
96+
updated_at = Column(
97+
TIMESTAMP,
98+
nullable=False,
99+
default=func.now(),
100+
onupdate=func.now(),
101+
comment="更新时间",
102+
)
92103
created_by = Column(String(255), nullable=True, comment="创建者")
93104
updated_by = Column(String(255), nullable=True, comment="更新者")
94105

runtime/datamate-python/app/module/evaluation/service/evaluation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from app.db.session import AsyncSessionLocal
1414
from app.module.evaluation.schema.evaluation import SourceType
1515
from app.module.shared.schema import TaskStatus
16-
from app.module.shared.util.model_chat import call_openai_style_model, _extract_json_substring
16+
from app.module.shared.util.model_chat import call_openai_style_model, extract_json_substring
1717
from app.module.evaluation.schema.prompt import get_prompt
1818
from app.module.shared.util.structured_file import StructuredFileHandlerFactory
1919
from app.module.system.service.common_service import get_model_by_id
@@ -73,7 +73,7 @@ async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asy
7373
call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
7474
prompt_text,
7575
)
76-
resp_text = _extract_json_substring(resp_text)
76+
resp_text = extract_json_substring(resp_text)
7777
try:
7878
json.loads(resp_text)
7979
except Exception as e:

runtime/datamate-python/app/module/generation/interface/generation_api.py

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from typing import cast
23

34
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
45
from sqlalchemy import select, func, delete
@@ -7,7 +8,7 @@
78
from app.core.logging import get_logger
89
from app.db.models.data_synthesis import (
910
save_synthesis_task,
10-
DataSynthesisInstance,
11+
DataSynthInstance,
1112
DataSynthesisFileInstance,
1213
DataSynthesisChunkInstance,
1314
SynthesisData,
@@ -65,32 +66,64 @@ async def create_synthesis_task(
6566
synthesis_task = await save_synthesis_task(db, request)
6667

6768
# 将已有的 DatasetFiles 记录保存到 t_data_synthesis_file_instances
69+
synth_files = []
6870
for f in dataset_files:
6971
file_instance = DataSynthesisFileInstance(
7072
id=str(uuid.uuid4()), # 使用新的 UUID 作为文件任务记录的主键,避免与 DatasetFiles 主键冲突
7173
synthesis_instance_id=synthesis_task.id,
7274
file_name=f.file_name,
7375
source_file_id=str(f.id),
74-
target_file_location=synthesis_task.result_data_location or "",
7576
status="pending",
7677
total_chunks=0,
7778
processed_chunks=0,
7879
created_by="system",
7980
updated_by="system",
8081
)
81-
db.add(file_instance)
82+
synth_files.append(file_instance)
8283

8384
if dataset_files:
85+
db.add_all(synth_files)
8486
await db.commit()
8587

8688
generation_service = GenerationService(db)
8789
# 异步处理任务:只传任务 ID,后台任务中使用新的 DB 会话重新加载任务对象
8890
background_tasks.add_task(generation_service.process_task, synthesis_task.id)
8991

92+
# 将 ORM 对象包装成 DataSynthesisTaskItem,兼容新字段从 synth_config 还原
93+
synth_cfg = getattr(synthesis_task, "synth_config", {}) or {}
94+
text_split_cfg = synth_cfg.get("text_split_config") or {}
95+
synthesis_cfg = synth_cfg.get("synthesis_config") or {}
96+
source_file_ids = synth_cfg.get("source_file_id") or request.source_file_id or []
97+
model_id = synth_cfg.get("model_id") or request.model_id
98+
result_location = synth_cfg.get("result_data_location")
99+
100+
task_item = DataSynthesisTaskItem(
101+
id=synthesis_task.id,
102+
name=synthesis_task.name,
103+
description=synthesis_task.description,
104+
status=synthesis_task.status,
105+
synthesis_type=synthesis_task.synth_type,
106+
model_id=model_id,
107+
progress=synthesis_task.progress,
108+
result_data_location=result_location,
109+
text_split_config=text_split_cfg,
110+
synthesis_config=synthesis_cfg,
111+
source_file_id=list(source_file_ids),
112+
total_files=synthesis_task.total_files,
113+
processed_files=synthesis_task.processed_files,
114+
total_chunks=synthesis_task.total_chunks,
115+
processed_chunks=synthesis_task.processed_chunks,
116+
total_synthesis_data=synthesis_task.total_synth_data,
117+
created_at=synthesis_task.created_at,
118+
updated_at=synthesis_task.updated_at,
119+
created_by=synthesis_task.created_by,
120+
updated_by=synthesis_task.updated_by,
121+
)
122+
90123
return StandardResponse(
91124
code=200,
92125
message="success",
93-
data=synthesis_task,
126+
data=task_item,
94127
)
95128

96129

@@ -100,7 +133,7 @@ async def get_synthesis_task(
100133
db: AsyncSession = Depends(get_db)
101134
):
102135
"""获取数据合成任务详情"""
103-
result = await db.get(DataSynthesisInstance, task_id)
136+
result = await db.get(DataSynthInstance, task_id)
104137
if not result:
105138
raise HTTPException(status_code=404, detail="Synthesis task not found")
106139

@@ -121,16 +154,16 @@ async def list_synthesis_tasks(
121154
db: AsyncSession = Depends(get_db)
122155
):
123156
"""分页列出所有数据合成任务,默认按创建时间倒序"""
124-
query = select(DataSynthesisInstance)
157+
query = select(DataSynthInstance)
125158
if synthesis_type:
126-
query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type)
159+
query = query.filter(DataSynthInstance.synth_type == synthesis_type)
127160
if status:
128-
query = query.filter(DataSynthesisInstance.status == status)
161+
query = query.filter(DataSynthInstance.status == status)
129162
if name:
130-
query = query.filter(DataSynthesisInstance.name.like(f"%{name}%"))
163+
query = query.filter(DataSynthInstance.name.like(f"%{name}%"))
131164

132165
# 默认按创建时间倒序排列
133-
query = query.order_by(DataSynthesisInstance.created_at.desc())
166+
query = query.order_by(DataSynthInstance.created_at.desc())
134167

135168
count_q = select(func.count()).select_from(query.subquery())
136169
total = (await db.execute(count_q)).scalar_one()
@@ -143,31 +176,39 @@ async def list_synthesis_tasks(
143176
result = await db.execute(query.offset((page - 1) * page_size).limit(page_size))
144177
rows = result.scalars().all()
145178

146-
task_items = [
147-
DataSynthesisTaskItem(
148-
id=row.id,
149-
name=row.name,
150-
description=row.description,
151-
status=row.status,
152-
synthesis_type=row.synthesis_type,
153-
model_id=row.model_id,
154-
progress=row.progress,
155-
result_data_location=row.result_data_location,
156-
text_split_config=row.text_split_config,
157-
synthesis_config=row.synthesis_config,
158-
source_file_id=row.source_file_id,
159-
total_files=row.total_files,
160-
processed_files=row.processed_files,
161-
total_chunks=row.total_chunks,
162-
processed_chunks=row.processed_chunks,
163-
total_synthesis_data=row.total_synthesis_data,
164-
created_at=row.created_at,
165-
updated_at=row.updated_at,
166-
created_by=row.created_by,
167-
updated_by=row.updated_by,
179+
task_items: list[DataSynthesisTaskItem] = []
180+
for row in rows:
181+
synth_cfg = getattr(row, "synth_config", {}) or {}
182+
text_split_cfg = synth_cfg.get("text_split_config") or {}
183+
synthesis_cfg = synth_cfg.get("synthesis_config") or {}
184+
source_file_ids = synth_cfg.get("source_file_id") or []
185+
model_id = synth_cfg.get("model_id")
186+
result_location = synth_cfg.get("result_data_location")
187+
188+
task_items.append(
189+
DataSynthesisTaskItem(
190+
id=str(row.id),
191+
name=str(row.name),
192+
description=cast(str | None, row.description),
193+
status=cast(str | None, row.status),
194+
synthesis_type=str(row.synth_type),
195+
model_id=model_id or "",
196+
progress=int(cast(int, row.progress)),
197+
result_data_location=result_location,
198+
text_split_config=text_split_cfg,
199+
synthesis_config=synthesis_cfg,
200+
source_file_id=list(source_file_ids),
201+
total_files=int(cast(int, row.total_files)),
202+
processed_files=int(cast(int, row.processed_files)),
203+
total_chunks=int(cast(int, row.total_chunks)),
204+
processed_chunks=int(cast(int, row.processed_chunks)),
205+
total_synthesis_data=int(cast(int, row.total_synth_data)),
206+
created_at=row.created_at,
207+
updated_at=row.updated_at,
208+
created_by=row.created_by,
209+
updated_by=row.updated_by,
210+
)
168211
)
169-
for row in rows
170-
]
171212

172213
paged = PagedDataSynthesisTaskResponse(
173214
content=task_items,
@@ -190,7 +231,7 @@ async def delete_synthesis_task(
190231
db: AsyncSession = Depends(get_db)
191232
):
192233
"""删除数据合成任务"""
193-
task = await db.get(DataSynthesisInstance, task_id)
234+
task = await db.get(DataSynthInstance, task_id)
194235
if not task:
195236
raise HTTPException(status_code=404, detail="Synthesis task not found")
196237

@@ -241,7 +282,7 @@ async def delete_synthesis_file_task(
241282
):
242283
"""删除数据合成任务中的文件任务,同时刷新任务表中的文件/切片数量"""
243284
# 先获取任务和文件任务记录
244-
task = await db.get(DataSynthesisInstance, task_id)
285+
task = await db.get(DataSynthInstance, task_id)
245286
if not task:
246287
raise HTTPException(status_code=404, detail="Synthesis task not found")
247288

@@ -306,7 +347,7 @@ async def list_synthesis_file_tasks(
306347
):
307348
"""分页获取某个数据合成任务下的文件任务列表"""
308349
# 先校验任务是否存在
309-
task = await db.get(DataSynthesisInstance, task_id)
350+
task = await db.get(DataSynthInstance, task_id)
310351
if not task:
311352
raise HTTPException(status_code=404, detail="Synthesis task not found")
312353

@@ -523,7 +564,7 @@ async def delete_synthesis_data_by_chunk(
523564
result = await db.execute(
524565
delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id)
525566
)
526-
deleted = result.rowcount or 0
567+
deleted = int(getattr(result, "rowcount", 0) or 0)
527568

528569
await db.commit()
529570

@@ -542,7 +583,7 @@ async def batch_delete_synthesis_data(
542583
result = await db.execute(
543584
delete(SynthesisData).where(SynthesisData.id.in_(request.ids))
544585
)
545-
deleted = result.rowcount or 0
586+
deleted = int(getattr(result, "rowcount", 0) or 0)
546587
await db.commit()
547588

548589
return StandardResponse(code=200, message="success", data=deleted)

0 commit comments

Comments
 (0)