@@ -146,6 +146,9 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid:
146146
147147
148148def create_training (session : SessionDep , info : DataTrainingInfo , oid : int , trans : Trans ):
149+ """
150+ 创建单个数据训练记录
151+ """
149152 # 基本验证
150153 if not info .question or not info .question .strip ():
151154 raise Exception (trans ("i18n_data_training.question_cannot_be_empty" ))
@@ -154,45 +157,56 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
154157 raise Exception (trans ("i18n_data_training.description_cannot_be_empty" ))
155158
156159 create_time = datetime .datetime .now ()
160+
161+ # 检查数据源和高级应用不能同时为空
157162 if info .datasource is None and info .advanced_application is None :
158163 if oid == 1 :
159164 raise Exception (trans ("i18n_data_training.datasource_assistant_cannot_be_none" ))
160165 else :
161166 raise Exception (trans ("i18n_data_training.datasource_cannot_be_none" ))
162167
163- parent = DataTraining (question = info .question , create_time = create_time , description = info .description , oid = oid ,
164- datasource = info .datasource , enabled = info .enabled ,
165- advanced_application = info .advanced_application )
166-
167- stmt = select (DataTraining .id ).where (and_ (DataTraining .question == info .question , DataTraining .oid == oid ))
168+ # 检查重复记录
169+ stmt = select (DataTraining .id ).where (
170+ and_ (DataTraining .question == info .question .strip (), DataTraining .oid == oid )
171+ )
168172
169173 if info .datasource is not None and info .advanced_application is not None :
170174 stmt = stmt .where (
171- or_ (DataTraining .datasource == info .datasource ,
172- DataTraining .advanced_application == info .advanced_application ))
175+ or_ (
176+ DataTraining .datasource == info .datasource ,
177+ DataTraining .advanced_application == info .advanced_application
178+ )
179+ )
173180 elif info .datasource is not None and info .advanced_application is None :
174- stmt = stmt .where (and_ ( DataTraining .datasource == info .datasource ) )
181+ stmt = stmt .where (DataTraining .datasource == info .datasource )
175182 elif info .datasource is None and info .advanced_application is not None :
176- stmt = stmt .where (and_ ( DataTraining .advanced_application == info .advanced_application ) )
183+ stmt = stmt .where (DataTraining .advanced_application == info .advanced_application )
177184
178185 exists = session .query (stmt .exists ()).scalar ()
179186
180187 if exists :
181188 raise Exception (trans ("i18n_data_training.exists_in_db" ))
182189
183- result = DataTraining (** parent .model_dump ())
190+ # 创建记录
191+ data_training = DataTraining (
192+ question = info .question .strip (),
193+ description = info .description .strip (),
194+ oid = oid ,
195+ datasource = info .datasource ,
196+ advanced_application = info .advanced_application ,
197+ create_time = create_time ,
198+ enabled = info .enabled if info .enabled is not None else True
199+ )
184200
185- session .add (parent )
201+ session .add (data_training )
186202 session .flush ()
187- session .refresh (parent )
188-
189- result .id = parent .id
203+ session .refresh (data_training )
190204 session .commit ()
191205
192- # embedding
193- run_save_data_training_embeddings ([result .id ])
206+ # 处理embedding
207+ run_save_data_training_embeddings ([data_training .id ])
194208
195- return result .id
209+ return data_training .id
196210
197211
198212def update_training (session : SessionDep , info : DataTrainingInfo , oid : int , trans : Trans ):
@@ -250,14 +264,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
250264
251265def batch_create_training (session : SessionDep , info_list : List [DataTrainingInfo ], oid : int , trans : Trans ):
252266 """
253- 批量创建数据训练记录
254- Args:
255- session: 数据库会话
256- info_list: DataTrainingInfo对象列表
257- oid: 组织ID
258- trans: 翻译对象
259- Returns:
260- dict: 包含成功数量、失败记录和统计信息的结果字典
267+ 批量创建数据训练记录(复用单条插入逻辑)
261268 """
262269 if not info_list :
263270 return {
@@ -268,48 +275,45 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
268275 'deduplicated_count' : 0
269276 }
270277
271- create_time = datetime .datetime .now ()
272- failed_records = [] # 存储失败的数据和原因
278+ failed_records = []
273279 success_count = 0
274- inserted_ids = [] # 存储成功插入的ID
280+ inserted_ids = []
275281
276282 # 第一步:数据去重
277283 unique_records = {}
278- duplicate_records = [] # 存储重复的数据
284+ duplicate_records = []
279285
280286 for info in info_list :
281- # 创建唯一标识:问题 + 数据源名称 + 高级应用名称
287+ # 创建唯一标识
282288 unique_key = (
283289 info .question .strip ().lower () if info .question else "" ,
284290 info .datasource_name .strip ().lower () if info .datasource_name else "" ,
285291 info .advanced_application_name .strip ().lower () if info .advanced_application_name else ""
286292 )
287293
288294 if unique_key in unique_records :
289- # 如果是重复数据,记录到重复列表中
290295 duplicate_records .append (info )
291296 else :
292297 unique_records [unique_key ] = info
293298
294299 # 将去重后的数据转换为列表
295300 deduplicated_list = list (unique_records .values ())
296301
297- # 预加载数据源名称到ID的映射(CoreDatasource需要判断oid)
302+ # 预加载数据源和高级应用名称到ID的映射
298303 datasource_name_to_id = {}
299304 datasource_stmt = select (CoreDatasource .id , CoreDatasource .name ).where (CoreDatasource .oid == oid )
300305 datasource_result = session .execute (datasource_stmt ).all ()
301306 for ds in datasource_result :
302307 datasource_name_to_id [ds .name .strip ()] = ds .id
303308
304- # 只有在oid=1时才预加载高级应用名称到ID的映射
305309 assistant_name_to_id = {}
306310 if oid == 1 :
307311 assistant_stmt = select (AssistantModel .id , AssistantModel .name ).where (AssistantModel .type == 1 )
308312 assistant_result = session .execute (assistant_stmt ).all ()
309313 for assistant in assistant_result :
310314 assistant_name_to_id [assistant .name .strip ()] = assistant .id
311315
312- # 验证和准备数据
316+ # 验证和转换数据
313317 valid_records = []
314318 for info in deduplicated_list :
315319 error_messages = []
@@ -321,15 +325,15 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
321325 if not info .description or not info .description .strip ():
322326 error_messages .append (trans ("i18n_data_training.description_cannot_be_empty" ))
323327
324- # 数据源验证
328+ # 数据源验证和转换
325329 datasource_id = None
326330 if info .datasource_name and info .datasource_name .strip ():
327331 if info .datasource_name .strip () in datasource_name_to_id :
328332 datasource_id = datasource_name_to_id [info .datasource_name .strip ()]
329333 else :
330334 error_messages .append (trans ("i18n_data_training.datasource_not_found" ).format (info .datasource_name ))
331335
332- # 高级应用验证(只有在oid=1时才需要)
336+ # 高级应用验证和转换
333337 advanced_application_id = None
334338 if oid == 1 and info .advanced_application_name and info .advanced_application_name .strip ():
335339 if info .advanced_application_name .strip () in assistant_name_to_id :
@@ -346,101 +350,43 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
346350 if not datasource_id :
347351 error_messages .append (trans ("i18n_data_training.datasource_cannot_be_none" ))
348352
349- # 如果有错误,添加到失败列表
350353 if error_messages :
351- # 返回原始的info对象,不包含转换后的ID
352354 failed_records .append ({
353- 'data' : info , # 直接返回原始传入的数据
355+ 'data' : info ,
354356 'errors' : error_messages
355357 })
356358 continue
357359
358- # 检查数据库中是否已存在重复记录
359- stmt = select (DataTraining .id ).where (
360- and_ (
361- DataTraining .question == info .question .strip (),
362- DataTraining .oid == oid
363- )
360+ # 创建处理后的DataTrainingInfo对象
361+ processed_info = DataTrainingInfo (
362+ question = info .question .strip (),
363+ description = info .description .strip (),
364+ datasource = datasource_id ,
365+ datasource_name = info .datasource_name ,
366+ advanced_application = advanced_application_id ,
367+ advanced_application_name = info .advanced_application_name ,
368+ enabled = info .enabled if info .enabled is not None else True
364369 )
365370
366- # 根据oid决定重复检查条件
367- if oid == 1 :
368- if datasource_id is not None and advanced_application_id is not None :
369- stmt = stmt .where (
370- or_ (
371- DataTraining .datasource == datasource_id ,
372- DataTraining .advanced_application == advanced_application_id
373- )
374- )
375- elif datasource_id is not None :
376- stmt = stmt .where (DataTraining .datasource == datasource_id )
377- elif advanced_application_id is not None :
378- stmt = stmt .where (DataTraining .advanced_application == advanced_application_id )
379- else :
380- # oid != 1时,只检查数据源
381- if datasource_id is not None :
382- stmt = stmt .where (DataTraining .datasource == datasource_id )
383-
384- exists = session .query (stmt .exists ()).scalar ()
385-
386- if exists :
387- # 返回原始的info对象
388- failed_records .append ({
389- 'data' : info , # 直接返回原始传入的数据
390- 'errors' : [trans ("i18n_data_training.exists_in_db" )]
391- })
392- continue
393-
394- # 验证通过,添加到有效记录
395- valid_records .append ({
396- 'info' : info ,
397- 'datasource_id' : datasource_id ,
398- 'advanced_application_id' : advanced_application_id
399- })
371+ valid_records .append (processed_info )
400372
401- # 批量插入有效记录
373+ # 使用事务处理有效记录
402374 if valid_records :
403- data_training_objects = []
404- for record in valid_records :
405- info = record ['info' ]
406- data_training = DataTraining (
407- question = info .question .strip (),
408- description = info .description .strip (),
409- oid = oid ,
410- datasource = record ['datasource_id' ],
411- advanced_application = record ['advanced_application_id' ] if oid == 1 else None , # 只有oid=1才设置高级应用
412- create_time = create_time ,
413- enabled = info .enabled if info .enabled is not None else True
414- )
415- data_training_objects .append (data_training )
416-
417- try :
418- # 批量插入
419- session .bulk_save_objects (data_training_objects , return_defaults = True )
420- session .commit ()
375+ for info in valid_records :
376+ try :
377+ # 直接复用create_training方法
378+ training_id = create_training (session , info , oid , trans )
379+ inserted_ids .append (training_id )
380+ success_count += 1
421381
422- # 获取插入的ID
423- for obj in data_training_objects :
424- if obj .id is not None : # 确保ID已经被赋值
425- inserted_ids .append (obj .id )
426- success_count += 1
427-
428- except Exception as e :
429- session .rollback ()
430- # 将所有的有效记录标记为失败
431- for record in valid_records :
432- # 返回原始的info对象
382+ except Exception as e :
383+ # 如果单条插入失败,回滚当前记录
384+ session .rollback ()
433385 failed_records .append ({
434- 'data' : record [ ' info' ], # 直接返回原始传入的数据
386+ 'data' : info ,
435387 'errors' : [str (e )]
436388 })
437- success_count = 0
438-
439- # 批量处理embedding
440- if success_count > 0 and inserted_ids :
441- run_save_data_training_embeddings (inserted_ids )
442389
443- # 返回结果,包含去重统计信息
444390 return {
445391 'success_count' : success_count ,
446392 'failed_records' : failed_records ,
0 commit comments