Skip to content

Commit d7462da

Browse files
committed
feat: import Sample SQL
1 parent e321abd commit d7462da

File tree

1 file changed

+62
-116
lines changed

1 file changed

+62
-116
lines changed

backend/apps/data_training/curd/data_training.py

Lines changed: 62 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid:
146146

147147

148148
def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
149+
"""
150+
创建单个数据训练记录
151+
"""
149152
# 基本验证
150153
if not info.question or not info.question.strip():
151154
raise Exception(trans("i18n_data_training.question_cannot_be_empty"))
@@ -154,45 +157,56 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
154157
raise Exception(trans("i18n_data_training.description_cannot_be_empty"))
155158

156159
create_time = datetime.datetime.now()
160+
161+
# 检查数据源和高级应用不能同时为空
157162
if info.datasource is None and info.advanced_application is None:
158163
if oid == 1:
159164
raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
160165
else:
161166
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
162167

163-
parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid,
164-
datasource=info.datasource, enabled=info.enabled,
165-
advanced_application=info.advanced_application)
166-
167-
stmt = select(DataTraining.id).where(and_(DataTraining.question == info.question, DataTraining.oid == oid))
168+
# 检查重复记录
169+
stmt = select(DataTraining.id).where(
170+
and_(DataTraining.question == info.question.strip(), DataTraining.oid == oid)
171+
)
168172

169173
if info.datasource is not None and info.advanced_application is not None:
170174
stmt = stmt.where(
171-
or_(DataTraining.datasource == info.datasource,
172-
DataTraining.advanced_application == info.advanced_application))
175+
or_(
176+
DataTraining.datasource == info.datasource,
177+
DataTraining.advanced_application == info.advanced_application
178+
)
179+
)
173180
elif info.datasource is not None and info.advanced_application is None:
174-
stmt = stmt.where(and_(DataTraining.datasource == info.datasource))
181+
stmt = stmt.where(DataTraining.datasource == info.datasource)
175182
elif info.datasource is None and info.advanced_application is not None:
176-
stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application))
183+
stmt = stmt.where(DataTraining.advanced_application == info.advanced_application)
177184

178185
exists = session.query(stmt.exists()).scalar()
179186

180187
if exists:
181188
raise Exception(trans("i18n_data_training.exists_in_db"))
182189

183-
result = DataTraining(**parent.model_dump())
190+
# 创建记录
191+
data_training = DataTraining(
192+
question=info.question.strip(),
193+
description=info.description.strip(),
194+
oid=oid,
195+
datasource=info.datasource,
196+
advanced_application=info.advanced_application,
197+
create_time=create_time,
198+
enabled=info.enabled if info.enabled is not None else True
199+
)
184200

185-
session.add(parent)
201+
session.add(data_training)
186202
session.flush()
187-
session.refresh(parent)
188-
189-
result.id = parent.id
203+
session.refresh(data_training)
190204
session.commit()
191205

192-
# embedding
193-
run_save_data_training_embeddings([result.id])
206+
# 处理embedding
207+
run_save_data_training_embeddings([data_training.id])
194208

195-
return result.id
209+
return data_training.id
196210

197211

198212
def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
@@ -250,14 +264,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
250264

251265
def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo], oid: int, trans: Trans):
252266
"""
253-
批量创建数据训练记录
254-
Args:
255-
session: 数据库会话
256-
info_list: DataTrainingInfo对象列表
257-
oid: 组织ID
258-
trans: 翻译对象
259-
Returns:
260-
dict: 包含成功数量、失败记录和统计信息的结果字典
267+
批量创建数据训练记录(复用单条插入逻辑)
261268
"""
262269
if not info_list:
263270
return {
@@ -268,48 +275,45 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
268275
'deduplicated_count': 0
269276
}
270277

271-
create_time = datetime.datetime.now()
272-
failed_records = [] # 存储失败的数据和原因
278+
failed_records = []
273279
success_count = 0
274-
inserted_ids = [] # 存储成功插入的ID
280+
inserted_ids = []
275281

276282
# 第一步:数据去重
277283
unique_records = {}
278-
duplicate_records = [] # 存储重复的数据
284+
duplicate_records = []
279285

280286
for info in info_list:
281-
# 创建唯一标识:问题 + 数据源名称 + 高级应用名称
287+
# 创建唯一标识
282288
unique_key = (
283289
info.question.strip().lower() if info.question else "",
284290
info.datasource_name.strip().lower() if info.datasource_name else "",
285291
info.advanced_application_name.strip().lower() if info.advanced_application_name else ""
286292
)
287293

288294
if unique_key in unique_records:
289-
# 如果是重复数据,记录到重复列表中
290295
duplicate_records.append(info)
291296
else:
292297
unique_records[unique_key] = info
293298

294299
# 将去重后的数据转换为列表
295300
deduplicated_list = list(unique_records.values())
296301

297-
# 预加载数据源名称到ID的映射(CoreDatasource需要判断oid)
302+
# 预加载数据源和高级应用名称到ID的映射
298303
datasource_name_to_id = {}
299304
datasource_stmt = select(CoreDatasource.id, CoreDatasource.name).where(CoreDatasource.oid == oid)
300305
datasource_result = session.execute(datasource_stmt).all()
301306
for ds in datasource_result:
302307
datasource_name_to_id[ds.name.strip()] = ds.id
303308

304-
# 只有在oid=1时才预加载高级应用名称到ID的映射
305309
assistant_name_to_id = {}
306310
if oid == 1:
307311
assistant_stmt = select(AssistantModel.id, AssistantModel.name).where(AssistantModel.type == 1)
308312
assistant_result = session.execute(assistant_stmt).all()
309313
for assistant in assistant_result:
310314
assistant_name_to_id[assistant.name.strip()] = assistant.id
311315

312-
# 验证和准备数据
316+
# 验证和转换数据
313317
valid_records = []
314318
for info in deduplicated_list:
315319
error_messages = []
@@ -321,15 +325,15 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
321325
if not info.description or not info.description.strip():
322326
error_messages.append(trans("i18n_data_training.description_cannot_be_empty"))
323327

324-
# 数据源验证
328+
# 数据源验证和转换
325329
datasource_id = None
326330
if info.datasource_name and info.datasource_name.strip():
327331
if info.datasource_name.strip() in datasource_name_to_id:
328332
datasource_id = datasource_name_to_id[info.datasource_name.strip()]
329333
else:
330334
error_messages.append(trans("i18n_data_training.datasource_not_found").format(info.datasource_name))
331335

332-
# 高级应用验证(只有在oid=1时才需要)
336+
# 高级应用验证和转换
333337
advanced_application_id = None
334338
if oid == 1 and info.advanced_application_name and info.advanced_application_name.strip():
335339
if info.advanced_application_name.strip() in assistant_name_to_id:
@@ -346,101 +350,43 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
346350
if not datasource_id:
347351
error_messages.append(trans("i18n_data_training.datasource_cannot_be_none"))
348352

349-
# 如果有错误,添加到失败列表
350353
if error_messages:
351-
# 返回原始的info对象,不包含转换后的ID
352354
failed_records.append({
353-
'data': info, # 直接返回原始传入的数据
355+
'data': info,
354356
'errors': error_messages
355357
})
356358
continue
357359

358-
# 检查数据库中是否已存在重复记录
359-
stmt = select(DataTraining.id).where(
360-
and_(
361-
DataTraining.question == info.question.strip(),
362-
DataTraining.oid == oid
363-
)
360+
# 创建处理后的DataTrainingInfo对象
361+
processed_info = DataTrainingInfo(
362+
question=info.question.strip(),
363+
description=info.description.strip(),
364+
datasource=datasource_id,
365+
datasource_name=info.datasource_name,
366+
advanced_application=advanced_application_id,
367+
advanced_application_name=info.advanced_application_name,
368+
enabled=info.enabled if info.enabled is not None else True
364369
)
365370

366-
# 根据oid决定重复检查条件
367-
if oid == 1:
368-
if datasource_id is not None and advanced_application_id is not None:
369-
stmt = stmt.where(
370-
or_(
371-
DataTraining.datasource == datasource_id,
372-
DataTraining.advanced_application == advanced_application_id
373-
)
374-
)
375-
elif datasource_id is not None:
376-
stmt = stmt.where(DataTraining.datasource == datasource_id)
377-
elif advanced_application_id is not None:
378-
stmt = stmt.where(DataTraining.advanced_application == advanced_application_id)
379-
else:
380-
# oid != 1时,只检查数据源
381-
if datasource_id is not None:
382-
stmt = stmt.where(DataTraining.datasource == datasource_id)
383-
384-
exists = session.query(stmt.exists()).scalar()
385-
386-
if exists:
387-
# 返回原始的info对象
388-
failed_records.append({
389-
'data': info, # 直接返回原始传入的数据
390-
'errors': [trans("i18n_data_training.exists_in_db")]
391-
})
392-
continue
393-
394-
# 验证通过,添加到有效记录
395-
valid_records.append({
396-
'info': info,
397-
'datasource_id': datasource_id,
398-
'advanced_application_id': advanced_application_id
399-
})
371+
valid_records.append(processed_info)
400372

401-
# 批量插入有效记录
373+
# 使用事务处理有效记录
402374
if valid_records:
403-
data_training_objects = []
404-
for record in valid_records:
405-
info = record['info']
406-
data_training = DataTraining(
407-
question=info.question.strip(),
408-
description=info.description.strip(),
409-
oid=oid,
410-
datasource=record['datasource_id'],
411-
advanced_application=record['advanced_application_id'] if oid == 1 else None, # 只有oid=1才设置高级应用
412-
create_time=create_time,
413-
enabled=info.enabled if info.enabled is not None else True
414-
)
415-
data_training_objects.append(data_training)
416-
417-
try:
418-
# 批量插入
419-
session.bulk_save_objects(data_training_objects, return_defaults=True)
420-
session.commit()
375+
for info in valid_records:
376+
try:
377+
# 直接复用create_training方法
378+
training_id = create_training(session, info, oid, trans)
379+
inserted_ids.append(training_id)
380+
success_count += 1
421381

422-
# 获取插入的ID
423-
for obj in data_training_objects:
424-
if obj.id is not None: # 确保ID已经被赋值
425-
inserted_ids.append(obj.id)
426-
success_count += 1
427-
428-
except Exception as e:
429-
session.rollback()
430-
# 将所有的有效记录标记为失败
431-
for record in valid_records:
432-
# 返回原始的info对象
382+
except Exception as e:
383+
# 如果单条插入失败,回滚当前记录
384+
session.rollback()
433385
failed_records.append({
434-
'data': record['info'], # 直接返回原始传入的数据
386+
'data': info,
435387
'errors': [str(e)]
436388
})
437-
success_count = 0
438-
439-
# 批量处理embedding
440-
if success_count > 0 and inserted_ids:
441-
run_save_data_training_embeddings(inserted_ids)
442389

443-
# 返回结果,包含去重统计信息
444390
return {
445391
'success_count': success_count,
446392
'failed_records': failed_records,

0 commit comments

Comments
 (0)