Skip to content

Commit e5b9dc9

Browse files
committed
feat: terminology settings add datasource
#127
1 parent f49fd3a commit e5b9dc9

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

backend/apps/chat/task/llm.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,9 @@ def generate_analysis(self):
241241
self.chat_question.data = orjson.dumps(data.get('data')).decode()
242242
analysis_msg: List[Union[BaseMessage, dict[str, Any]]] = []
243243

244+
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
244245
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
245-
self.current_user.oid)
246+
self.current_user.oid, ds_id)
246247

247248
analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
248249
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
@@ -504,7 +505,8 @@ def select_datasource(self):
504505
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
505506
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
506507

507-
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid)
508+
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, oid,
509+
ds_id)
508510
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id,
509511
oid)
510512

@@ -897,7 +899,7 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
897899
oid = self.ds.oid if isinstance(self.ds, CoreDatasource) else 1
898900
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
899901
self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question,
900-
oid)
902+
oid, ds_id)
901903
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
902904
ds_id, oid)
903905

backend/apps/terminology/curd/terminology.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import logging
33
import traceback
4-
from typing import List, Optional
4+
from typing import List, Optional, Any
55
from xml.dom.minidom import parseString
66

77
import dicttoxml
@@ -367,17 +367,22 @@ def save_embeddings(session: Session, ids: List[int]):
367367
embedding_sql = f"""
368368
SELECT id, pid, word, similarity
369369
FROM
370-
(SELECT id, pid, word, oid,
370+
(SELECT id, pid, word, oid, specific_ds, datasource_ids,
371371
( 1 - (embedding <=> :embedding_array) ) AS similarity
372372
FROM terminology AS child
373373
) TEMP
374-
WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} and oid = :oid
374+
WHERE similarity > {settings.EMBEDDING_TERMINOLOGY_SIMILARITY} AND oid = :oid
375+
AND (
376+
(:datasource IS NULL AND (specific_ds = false OR specific_ds IS NULL))
377+
OR
378+
(:datasource IS NOT NULL AND ((specific_ds = false OR specific_ds IS NULL) OR (specific_ds = true AND datasource_ids IS NOT NULL AND datasource_ids @> jsonb_build_array(:datasource))))
379+
)
375380
ORDER BY similarity DESC
376381
LIMIT {settings.EMBEDDING_TERMINOLOGY_TOP_COUNT}
377382
"""
378383

379384

380-
def select_terminology_by_word(session: SessionDep, word: str, oid: int):
385+
def select_terminology_by_word(session: SessionDep, word: str, oid: int, datasource: int = None):
381386
if word.strip() == "":
382387
return []
383388

@@ -394,7 +399,26 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int):
394399
)
395400
)
396401

397-
results = session.execute(stmt, {'sentence': word}).fetchall()
402+
if datasource is not None:
403+
stmt = stmt.where(
404+
or_(
405+
or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)),
406+
and_(
407+
Terminology.specific_ds == True,
408+
Terminology.datasource_ids.isnot(None),
409+
text("datasource_ids @> jsonb_build_array(:datasource)")
410+
)
411+
)
412+
)
413+
else:
414+
stmt = stmt.where(or_(Terminology.specific_ds == False, Terminology.specific_ds.is_(None)))
415+
416+
# 执行查询
417+
params: dict[str, Any] = {'sentence': word}
418+
if datasource is not None:
419+
params['datasource'] = datasource
420+
421+
results = session.execute(stmt, params).fetchall()
398422

399423
for row in results:
400424
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
@@ -405,7 +429,8 @@ def select_terminology_by_word(session: SessionDep, word: str, oid: int):
405429

406430
embedding = model.embed_query(word)
407431

408-
results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid})
432+
results = session.execute(text(embedding_sql), {'embedding_array': str(embedding), 'oid': oid,
433+
'datasource': datasource}).fetchall()
409434

410435
for row in results:
411436
_list.append(Terminology(id=row.id, word=row.word, pid=row.pid))
@@ -481,10 +506,11 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'terminologies') -> str:
481506
return pretty_xml
482507

483508

484-
def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1) -> str:
509+
def get_terminology_template(session: SessionDep, question: str, oid: Optional[int] = 1,
510+
datasource: Optional[int] = None) -> str:
485511
if not oid:
486512
oid = 1
487-
_results = select_terminology_by_word(session, question, oid)
513+
_results = select_terminology_by_word(session, question, oid, datasource)
488514
if _results and len(_results) > 0:
489515
terminology = to_xml_string(_results)
490516
template = get_base_terminology_template().format(terminologies=terminology)

0 commit comments

Comments
 (0)