Skip to content

Commit 4d53a02

Browse files
committed
feat: import Sample SQL / Terminologies
1 parent 25d2fe1 commit 4d53a02

File tree

5 files changed

+191
-47
lines changed

5 files changed

+191
-47
lines changed

backend/apps/data_training/curd/data_training.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,11 @@ def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid:
145145
return _list
146146

147147

148-
def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
148+
def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans, skip_embedding: bool = False):
149149
"""
150150
创建单个数据训练记录
151+
Args:
152+
skip_embedding: 是否跳过embedding处理(用于批量插入)
151153
"""
152154
# 基本验证
153155
if not info.question or not info.question.strip():
@@ -203,8 +205,9 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
203205
session.refresh(data_training)
204206
session.commit()
205207

206-
# 处理embedding
207-
run_save_data_training_embeddings([data_training.id])
208+
# 处理embedding(批量插入时跳过)
209+
if not skip_embedding:
210+
run_save_data_training_embeddings([data_training.id])
208211

209212
return data_training.id
210213

@@ -247,11 +250,11 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
247250
raise Exception(trans("i18n_data_training.exists_in_db"))
248251

249252
stmt = update(DataTraining).where(and_(DataTraining.id == info.id)).values(
250-
question=info.question,
251-
description=info.description,
253+
question=info.question.strip(),
254+
description=info.description.strip(),
252255
datasource=info.datasource,
253-
enabled=info.enabled,
254256
advanced_application=info.advanced_application,
257+
enabled=info.enabled if info.enabled is not None else True
255258
)
256259
session.execute(stmt)
257260
session.commit()
@@ -374,8 +377,8 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
374377
if valid_records:
375378
for info in valid_records:
376379
try:
377-
# 直接复用create_training方法
378-
training_id = create_training(session, info, oid, trans)
380+
# 直接复用create_training方法,跳过embedding处理
381+
training_id = create_training(session, info, oid, trans, skip_embedding=True)
379382
inserted_ids.append(training_id)
380383
success_count += 1
381384

@@ -387,6 +390,15 @@ def batch_create_training(session: SessionDep, info_list: List[DataTrainingInfo]
387390
'errors': [str(e)]
388391
})
389392

393+
# 批量处理embedding(只在最后执行一次)
394+
if success_count > 0 and inserted_ids:
395+
try:
396+
run_save_data_training_embeddings(inserted_ids)
397+
except Exception as e:
398+
# 如果embedding处理失败,记录错误但不回滚数据
399+
print(f"Embedding processing failed: {str(e)}")
400+
# 可以选择将embedding失败的信息记录到日志或返回给调用方
401+
390402
return {
391403
'success_count': success_count,
392404
'failed_records': failed_records,

backend/apps/terminology/curd/terminology.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,12 @@ def get_all_terminology(session: SessionDep, name: Optional[str] = None, oid: Op
200200
return _list
201201

202202

203-
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
203+
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans,
204+
skip_embedding: bool = False):
204205
"""
205206
创建单个术语记录
207+
Args:
208+
skip_embedding: 是否跳过embedding处理(用于批量插入)
206209
"""
207210
# 基本验证
208211
if not info.word or not info.word.strip():
@@ -221,16 +224,16 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
221224
raise Exception(trans("i18n_terminology.datasource_cannot_be_none"))
222225

223226
parent = Terminology(
224-
word=info.word,
227+
word=info.word.strip(),
225228
create_time=create_time,
226-
description=info.description,
229+
description=info.description.strip(),
227230
oid=oid,
228231
specific_ds=specific_ds,
229232
enabled=info.enabled,
230233
datasource_ids=datasource_ids
231234
)
232235

233-
words = [info.word]
236+
words = [info.word.strip()]
234237
for child_word in info.other_words:
235238
# 先检查是否为空字符串
236239
if not child_word or child_word.strip() == "":
@@ -239,7 +242,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
239242
if child_word in words:
240243
raise Exception(trans("i18n_terminology.cannot_be_repeated"))
241244
else:
242-
words.append(child_word)
245+
words.append(child_word.strip())
243246

244247
# 基础查询条件(word 和 oid 必须满足)
245248
base_query = and_(
@@ -288,7 +291,7 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
288291
child_list.append(
289292
Terminology(
290293
pid=parent.id,
291-
word=other_word,
294+
word=other_word.strip(),
292295
create_time=create_time,
293296
oid=oid,
294297
enabled=info.enabled,
@@ -303,8 +306,9 @@ def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
303306

304307
session.commit()
305308

306-
# 处理embedding
307-
run_save_terminology_embeddings([parent.id])
309+
# 处理embedding(批量插入时跳过)
310+
if not skip_embedding:
311+
run_save_terminology_embeddings([parent.id])
308312

309313
return parent.id
310314

@@ -380,19 +384,9 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf
380384
# 基本验证
381385
if not info.word or not info.word.strip():
382386
error_messages.append(trans("i18n_terminology.word_cannot_be_empty"))
383-
failed_records.append({
384-
'data': info,
385-
'errors': error_messages
386-
})
387-
continue
388387

389388
if not info.description or not info.description.strip():
390389
error_messages.append(trans("i18n_terminology.description_cannot_be_empty"))
391-
failed_records.append({
392-
'data': info,
393-
'errors': error_messages
394-
})
395-
continue
396390

397391
# 根据specific_ds决定是否验证数据源
398392
specific_ds = info.specific_ds if info.specific_ds is not None else False
@@ -455,8 +449,8 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf
455449
if valid_records:
456450
for info in valid_records:
457451
try:
458-
# 直接复用create_terminology方法
459-
terminology_id = create_terminology(session, info, oid, trans)
452+
# 直接复用create_terminology方法,跳过embedding处理
453+
terminology_id = create_terminology(session, info, oid, trans, skip_embedding=True)
460454
inserted_ids.append(terminology_id)
461455
success_count += 1
462456

@@ -468,6 +462,15 @@ def batch_create_terminology(session: SessionDep, info_list: List[TerminologyInf
468462
'errors': [str(e)]
469463
})
470464

465+
# 批量处理embedding(只在最后执行一次)
466+
if success_count > 0 and inserted_ids:
467+
try:
468+
run_save_terminology_embeddings(inserted_ids)
469+
except Exception as e:
470+
# 如果embedding处理失败,记录错误但不回滚数据
471+
print(f"Terminology embedding processing failed: {str(e)}")
472+
# 可以选择将embedding失败的信息记录到日志或返回给调用方
473+
471474
return {
472475
'success_count': success_count,
473476
'failed_records': failed_records,
@@ -492,12 +495,12 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
492495
if not datasource_ids:
493496
raise Exception(trans("i18n_terminology.datasource_cannot_be_none"))
494497

495-
words = [info.word]
498+
words = [info.word.strip()]
496499
for child in info.other_words:
497500
if child in words:
498501
raise Exception(trans("i18n_terminology.cannot_be_repeated"))
499502
else:
500-
words.append(child)
503+
words.append(child.strip())
501504

502505
# 基础查询条件(word 和 oid 必须满足)
503506
base_query = and_(
@@ -539,8 +542,8 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
539542
raise Exception(trans("i18n_terminology.exists_in_db"))
540543

541544
stmt = update(Terminology).where(and_(Terminology.id == info.id)).values(
542-
word=info.word,
543-
description=info.description,
545+
word=info.word.strip(),
546+
description=info.description.strip(),
544547
specific_ds=specific_ds,
545548
datasource_ids=datasource_ids,
546549
enabled=info.enabled,
@@ -553,16 +556,27 @@ def update_terminology(session: SessionDep, info: TerminologyInfo, oid: int, tra
553556
session.commit()
554557

555558
create_time = datetime.datetime.now()
556-
_list: List[Terminology] = []
559+
# 插入子记录(其他词)
560+
child_list = []
557561
if info.other_words:
558562
for other_word in info.other_words:
559563
if other_word.strip() == "":
560564
continue
561-
_list.append(
562-
Terminology(pid=info.id, word=other_word, create_time=create_time, oid=oid,
563-
specific_ds=specific_ds, datasource_ids=datasource_ids, enabled=info.enabled))
564-
session.bulk_save_objects(_list)
565-
session.flush()
565+
child_list.append(
566+
Terminology(
567+
pid=info.id,
568+
word=other_word.strip(),
569+
create_time=create_time,
570+
oid=oid,
571+
enabled=info.enabled,
572+
specific_ds=specific_ds,
573+
datasource_ids=datasource_ids
574+
)
575+
)
576+
577+
if child_list:
578+
session.bulk_save_objects(child_list)
579+
session.flush()
566580
session.commit()
567581

568582
# embedding

backend/locales/zh-CN.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,12 @@
7474
"prompt_word_name": "提示词名称",
7575
"prompt_word_content": "提示词内容",
7676
"effective_data_sources": "生效数据源",
77-
"all_data_sources": "所有数据源"
77+
"all_data_sources": "所有数据源",
78+
"name_cannot_be_empty": "名称不能为空",
79+
"prompt_cannot_be_empty": "提示词内容不能为空",
80+
"type_cannot_be_empty": "类型不能为空",
81+
"datasource_not_found": "找不到数据源",
82+
"datasource_cannot_be_none": "数据源不能为空",
7883
},
7984
"i18n_excel_export": {
8085
"data_is_empty": "表单数据为空,无法导出数据"

backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies = [
3939
"pyyaml (>=6.0.2,<7.0.0)",
4040
"fastapi-mcp (>=0.3.4,<0.4.0)",
4141
"tabulate>=0.9.0",
42-
"sqlbot-xpack>=0.0.3.45,<1.0.0",
42+
"sqlbot-xpack>=0.0.3.46,<1.0.0",
4343
"fastapi-cache2>=0.2.2",
4444
"sqlparse>=0.5.3",
4545
"redis>=6.2.0",

0 commit comments

Comments
 (0)