 import uuid
+from typing import cast

 from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
 from sqlalchemy import select, func, delete
 from app.core.logging import get_logger
 from app.db.models.data_synthesis import (
     save_synthesis_task,
-    DataSynthesisInstance,
+    DataSynthInstance,
     DataSynthesisFileInstance,
     DataSynthesisChunkInstance,
     SynthesisData,
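For context on the rename, the attributes this router reads further down (`synth_type`, `synth_config`, `total_synth_data`, the counters and audit columns) suggest the renamed model roughly looks like the sketch below. This is an inferred, hypothetical shape for illustration only; the real definition lives in `app.db.models.data_synthesis` and may differ in column types and table name.

```python
# Hypothetical sketch of DataSynthInstance, inferred from the attributes the
# router accesses; NOT the actual model from app.db.models.data_synthesis.
from sqlalchemy import String, Integer, JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DataSynthInstance(Base):
    __tablename__ = "t_data_synthesis_instances"  # assumed table name

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    name: Mapped[str] = mapped_column(String(255))
    description: Mapped[str | None] = mapped_column(String(1024), nullable=True)
    status: Mapped[str | None] = mapped_column(String(32), nullable=True)
    synth_type: Mapped[str] = mapped_column(String(64))
    # JSON blob holding model_id, text_split_config, synthesis_config,
    # source_file_id, result_data_location (per the reads below).
    synth_config: Mapped[dict] = mapped_column(JSON, default=dict)
    progress: Mapped[int] = mapped_column(Integer, default=0)
    total_files: Mapped[int] = mapped_column(Integer, default=0)
    processed_files: Mapped[int] = mapped_column(Integer, default=0)
    total_chunks: Mapped[int] = mapped_column(Integer, default=0)
    processed_chunks: Mapped[int] = mapped_column(Integer, default=0)
    total_synth_data: Mapped[int] = mapped_column(Integer, default=0)
    # plus created_at / updated_at / created_by / updated_by audit columns
```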
@@ -65,32 +66,64 @@ async def create_synthesis_task(
     synthesis_task = await save_synthesis_task(db, request)

     # Save the existing DatasetFiles records into t_data_synthesis_file_instances
+    synth_files = []
     for f in dataset_files:
         file_instance = DataSynthesisFileInstance(
             id=str(uuid.uuid4()),  # use a new UUID as the file-task primary key to avoid clashing with the DatasetFiles primary key
             synthesis_instance_id=synthesis_task.id,
             file_name=f.file_name,
             source_file_id=str(f.id),
-            target_file_location=synthesis_task.result_data_location or "",
             status="pending",
             total_chunks=0,
             processed_chunks=0,
             created_by="system",
             updated_by="system",
         )
-        db.add(file_instance)
+        synth_files.append(file_instance)

     if dataset_files:
+        db.add_all(synth_files)
         await db.commit()

     generation_service = GenerationService(db)
     # Process the task asynchronously: pass only the task ID; the background task reloads the task object with a fresh DB session
     background_tasks.add_task(generation_service.process_task, synthesis_task.id)

+    # Wrap the ORM object in a DataSynthesisTaskItem; new fields are restored from synth_config for compatibility
+    synth_cfg = getattr(synthesis_task, "synth_config", {}) or {}
+    text_split_cfg = synth_cfg.get("text_split_config") or {}
+    synthesis_cfg = synth_cfg.get("synthesis_config") or {}
+    source_file_ids = synth_cfg.get("source_file_id") or request.source_file_id or []
+    model_id = synth_cfg.get("model_id") or request.model_id
+    result_location = synth_cfg.get("result_data_location")
+
+    task_item = DataSynthesisTaskItem(
+        id=synthesis_task.id,
+        name=synthesis_task.name,
+        description=synthesis_task.description,
+        status=synthesis_task.status,
+        synthesis_type=synthesis_task.synth_type,
+        model_id=model_id,
+        progress=synthesis_task.progress,
+        result_data_location=result_location,
+        text_split_config=text_split_cfg,
+        synthesis_config=synthesis_cfg,
+        source_file_id=list(source_file_ids),
+        total_files=synthesis_task.total_files,
+        processed_files=synthesis_task.processed_files,
+        total_chunks=synthesis_task.total_chunks,
+        processed_chunks=synthesis_task.processed_chunks,
+        total_synthesis_data=synthesis_task.total_synth_data,
+        created_at=synthesis_task.created_at,
+        updated_at=synthesis_task.updated_at,
+        created_by=synthesis_task.created_by,
+        updated_by=synthesis_task.updated_by,
+    )
+
     return StandardResponse(
         code=200,
         message="success",
-        data=synthesis_task,
+        data=task_item,
     )

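The comment above `background_tasks.add_task` is the key design point of this hunk: the request-scoped `AsyncSession` is gone once the response is sent, so only `synthesis_task.id` is handed to the background job, which must reload the row itself. A minimal sketch of that pattern follows; `async_session_factory` and the engine URL are placeholder assumptions, and the real `GenerationService.process_task` may differ.

```python
# Sketch of the "pass the ID, reload in a fresh session" pattern described in
# the comment above. async_session_factory is an assumed name for the app's
# async_sessionmaker; the engine URL is a placeholder.
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.db.models.data_synthesis import DataSynthInstance

engine = create_async_engine("sqlite+aiosqlite:///:memory:")  # placeholder URL
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)


class GenerationServiceSketch:
    def __init__(self, db: AsyncSession):
        # Request-scoped session; only valid while the request is alive.
        self.db = db

    async def process_task(self, task_id: str) -> None:
        # The request session may already be closed here, so open a new one
        # and reload the task by its primary key.
        async with async_session_factory() as session:
            task = await session.get(DataSynthInstance, task_id)
            if task is None:
                return
            task.status = "running"
            await session.commit()
            # ... long-running synthesis work would continue with `session` ...
```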
@@ -100,7 +133,7 @@ async def get_synthesis_task(
     db: AsyncSession = Depends(get_db)
 ):
     """Get the details of a data synthesis task."""
-    result = await db.get(DataSynthesisInstance, task_id)
+    result = await db.get(DataSynthInstance, task_id)
     if not result:
         raise HTTPException(status_code=404, detail="Synthesis task not found")

@@ -121,16 +154,16 @@ async def list_synthesis_tasks(
     db: AsyncSession = Depends(get_db)
 ):
     """List all data synthesis tasks with pagination, sorted by creation time (newest first) by default."""
-    query = select(DataSynthesisInstance)
+    query = select(DataSynthInstance)
     if synthesis_type:
-        query = query.filter(DataSynthesisInstance.synthesis_type == synthesis_type)
+        query = query.filter(DataSynthInstance.synth_type == synthesis_type)
     if status:
-        query = query.filter(DataSynthesisInstance.status == status)
+        query = query.filter(DataSynthInstance.status == status)
     if name:
-        query = query.filter(DataSynthesisInstance.name.like(f"%{name}%"))
+        query = query.filter(DataSynthInstance.name.like(f"%{name}%"))

     # Sort by creation time in descending order by default
-    query = query.order_by(DataSynthesisInstance.created_at.desc())
+    query = query.order_by(DataSynthInstance.created_at.desc())

     count_q = select(func.count()).select_from(query.subquery())
     total = (await db.execute(count_q)).scalar_one()
@@ -143,31 +176,39 @@ async def list_synthesis_tasks(
     result = await db.execute(query.offset((page - 1) * page_size).limit(page_size))
     rows = result.scalars().all()

-    task_items = [
-        DataSynthesisTaskItem(
-            id=row.id,
-            name=row.name,
-            description=row.description,
-            status=row.status,
-            synthesis_type=row.synthesis_type,
-            model_id=row.model_id,
-            progress=row.progress,
-            result_data_location=row.result_data_location,
-            text_split_config=row.text_split_config,
-            synthesis_config=row.synthesis_config,
-            source_file_id=row.source_file_id,
-            total_files=row.total_files,
-            processed_files=row.processed_files,
-            total_chunks=row.total_chunks,
-            processed_chunks=row.processed_chunks,
-            total_synthesis_data=row.total_synthesis_data,
-            created_at=row.created_at,
-            updated_at=row.updated_at,
-            created_by=row.created_by,
-            updated_by=row.updated_by,
+    task_items: list[DataSynthesisTaskItem] = []
+    for row in rows:
+        synth_cfg = getattr(row, "synth_config", {}) or {}
+        text_split_cfg = synth_cfg.get("text_split_config") or {}
+        synthesis_cfg = synth_cfg.get("synthesis_config") or {}
+        source_file_ids = synth_cfg.get("source_file_id") or []
+        model_id = synth_cfg.get("model_id")
+        result_location = synth_cfg.get("result_data_location")
+
+        task_items.append(
+            DataSynthesisTaskItem(
+                id=str(row.id),
+                name=str(row.name),
+                description=cast(str | None, row.description),
+                status=cast(str | None, row.status),
+                synthesis_type=str(row.synth_type),
+                model_id=model_id or "",
+                progress=int(cast(int, row.progress)),
+                result_data_location=result_location,
+                text_split_config=text_split_cfg,
+                synthesis_config=synthesis_cfg,
+                source_file_id=list(source_file_ids),
+                total_files=int(cast(int, row.total_files)),
+                processed_files=int(cast(int, row.processed_files)),
+                total_chunks=int(cast(int, row.total_chunks)),
+                processed_chunks=int(cast(int, row.processed_chunks)),
+                total_synthesis_data=int(cast(int, row.total_synth_data)),
+                created_at=row.created_at,
+                updated_at=row.updated_at,
+                created_by=row.created_by,
+                updated_by=row.updated_by,
+            )
         )
-        for row in rows
-    ]

     paged = PagedDataSynthesisTaskResponse(
         content=task_items,
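After this change, the same `synth_config` unpacking appears in both `create_synthesis_task` and `list_synthesis_tasks`. One way to keep the two in sync would be a small mapping helper like the sketch below. It is hypothetical, not part of this commit, and assumes the `DataSynthInstance` model and `DataSynthesisTaskItem` schema already imported at the top of this router.

```python
# Hypothetical helper centralizing the ORM -> DataSynthesisTaskItem mapping
# used by both endpoints above; not part of this commit. Assumes the
# DataSynthInstance / DataSynthesisTaskItem names from this router's imports.
def task_item_from_orm(row: "DataSynthInstance") -> "DataSynthesisTaskItem":
    synth_cfg = getattr(row, "synth_config", {}) or {}
    return DataSynthesisTaskItem(
        id=str(row.id),
        name=str(row.name),
        description=row.description,
        status=row.status,
        synthesis_type=str(row.synth_type),
        model_id=synth_cfg.get("model_id") or "",
        progress=int(row.progress or 0),
        result_data_location=synth_cfg.get("result_data_location"),
        text_split_config=synth_cfg.get("text_split_config") or {},
        synthesis_config=synth_cfg.get("synthesis_config") or {},
        source_file_id=list(synth_cfg.get("source_file_id") or []),
        total_files=int(row.total_files or 0),
        processed_files=int(row.processed_files or 0),
        total_chunks=int(row.total_chunks or 0),
        processed_chunks=int(row.processed_chunks or 0),
        total_synthesis_data=int(row.total_synth_data or 0),
        created_at=row.created_at,
        updated_at=row.updated_at,
        created_by=row.created_by,
        updated_by=row.updated_by,
    )
```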
@@ -190,7 +231,7 @@ async def delete_synthesis_task(
     db: AsyncSession = Depends(get_db)
 ):
     """Delete a data synthesis task."""
-    task = await db.get(DataSynthesisInstance, task_id)
+    task = await db.get(DataSynthInstance, task_id)
     if not task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")

@@ -241,7 +282,7 @@ async def delete_synthesis_file_task(
 ):
     """Delete a file task from a data synthesis task and refresh the file/chunk counts on the task record."""
     # Fetch the task and the file-task record first
-    task = await db.get(DataSynthesisInstance, task_id)
+    task = await db.get(DataSynthInstance, task_id)
     if not task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")

@@ -306,7 +347,7 @@ async def list_synthesis_file_tasks(
 ):
     """Get a paginated list of file tasks under a data synthesis task."""
     # Verify that the task exists first
-    task = await db.get(DataSynthesisInstance, task_id)
+    task = await db.get(DataSynthInstance, task_id)
     if not task:
         raise HTTPException(status_code=404, detail="Synthesis task not found")

@@ -523,7 +564,7 @@ async def delete_synthesis_data_by_chunk(
     result = await db.execute(
         delete(SynthesisData).where(SynthesisData.chunk_instance_id == chunk_id)
     )
-    deleted = result.rowcount or 0
+    deleted = int(getattr(result, "rowcount", 0) or 0)

     await db.commit()

@@ -542,7 +583,7 @@ async def batch_delete_synthesis_data(
     result = await db.execute(
         delete(SynthesisData).where(SynthesisData.id.in_(request.ids))
     )
-    deleted = result.rowcount or 0
+    deleted = int(getattr(result, "rowcount", 0) or 0)
     await db.commit()

     return StandardResponse(code=200, message="success", data=deleted)
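Both delete endpoints now read the affected-row count defensively: `getattr` covers result objects that do not expose `rowcount`, and `or 0` maps a `None` report to zero. Since the same guard appears twice, it could live in one small helper like the sketch below (hypothetical, not part of this commit).

```python
# Hypothetical helper mirroring the guarded rowcount read used in both delete
# endpoints above; not part of this commit.
from sqlalchemy.engine import Result


def deleted_row_count(result: Result) -> int:
    # Some result types lack rowcount and some drivers report None; fall back to 0.
    return int(getattr(result, "rowcount", 0) or 0)
```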