Skip to content

Commit 7ad8d82

Browse files
committed
feat: Vector retrieval matches datasource
1 parent 632076d commit 7ad8d82

File tree

3 files changed

+51
-47
lines changed

3 files changed

+51
-47
lines changed

backend/apps/chat/task/llm.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -413,50 +413,50 @@ def select_datasource(self, _session: Session):
413413
if not ignore_auto_select:
414414
if settings.TABLE_EMBEDDING_ENABLED and (
415415
not self.current_assistant or (self.current_assistant and self.current_assistant.type != 1)):
416-
ds = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance,
416+
_ds_list = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance,
417417
self.chat_question.question, self.current_assistant)
418-
yield {'content': '{"id":' + str(ds.get('id')) + '}'}
419-
else:
420-
_ds_list_dict = []
421-
for _ds in _ds_list:
422-
_ds_list_dict.append(_ds)
423-
datasource_msg.append(
424-
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
425-
426-
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=_session,
427-
ai_modal_id=self.chat_question.ai_modal_id,
428-
ai_modal_name=self.chat_question.ai_modal_name,
429-
operate=OperationEnum.CHOOSE_DATASOURCE,
430-
record_id=self.record.id,
431-
full_message=[{'type': msg.type,
432-
'content': msg.content}
433-
for
434-
msg in datasource_msg])
435-
436-
token_usage = {}
437-
res = process_stream(self.llm.stream(datasource_msg), token_usage)
438-
for chunk in res:
439-
if chunk.get('content'):
440-
full_text += chunk.get('content')
441-
if chunk.get('reasoning_content'):
442-
full_thinking_text += chunk.get('reasoning_content')
443-
yield chunk
444-
datasource_msg.append(AIMessage(full_text))
445-
446-
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=_session,
447-
log=self.current_logs[
448-
OperationEnum.CHOOSE_DATASOURCE],
449-
full_message=[
450-
{'type': msg.type,
451-
'content': msg.content}
452-
for msg in datasource_msg],
453-
reasoning_content=full_thinking_text,
454-
token_usage=token_usage)
455-
456-
json_str = extract_nested_json(full_text)
457-
if json_str is None:
458-
raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}')
459-
ds = orjson.loads(json_str)
418+
# yield {'content': '{"id":' + str(ds.get('id')) + '}'}
419+
420+
_ds_list_dict = []
421+
for _ds in _ds_list:
422+
_ds_list_dict.append(_ds)
423+
datasource_msg.append(
424+
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
425+
426+
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = start_log(session=_session,
427+
ai_modal_id=self.chat_question.ai_modal_id,
428+
ai_modal_name=self.chat_question.ai_modal_name,
429+
operate=OperationEnum.CHOOSE_DATASOURCE,
430+
record_id=self.record.id,
431+
full_message=[{'type': msg.type,
432+
'content': msg.content}
433+
for
434+
msg in datasource_msg])
435+
436+
token_usage = {}
437+
res = process_stream(self.llm.stream(datasource_msg), token_usage)
438+
for chunk in res:
439+
if chunk.get('content'):
440+
full_text += chunk.get('content')
441+
if chunk.get('reasoning_content'):
442+
full_thinking_text += chunk.get('reasoning_content')
443+
yield chunk
444+
datasource_msg.append(AIMessage(full_text))
445+
446+
self.current_logs[OperationEnum.CHOOSE_DATASOURCE] = end_log(session=_session,
447+
log=self.current_logs[
448+
OperationEnum.CHOOSE_DATASOURCE],
449+
full_message=[
450+
{'type': msg.type,
451+
'content': msg.content}
452+
for msg in datasource_msg],
453+
reasoning_content=full_thinking_text,
454+
token_usage=token_usage)
455+
456+
json_str = extract_nested_json(full_text)
457+
if json_str is None:
458+
raise SingleMessageError(f'Cannot parse datasource from answer: {full_text}')
459+
ds = orjson.loads(json_str)
460460

461461
_error: Exception | None = None
462462
_datasource: int | None = None

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from apps.datasource.embedding.utils import cosine_similarity
1010
from apps.datasource.models.datasource import CoreDatasource
1111
from apps.system.crud.assistant import AssistantOutDs
12+
from common.core.config import settings
1213
from common.core.deps import CurrentAssistant
1314
from common.core.deps import SessionDep, CurrentUser
1415
from common.utils.utils import SQLBotLogUtil
@@ -45,8 +46,9 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
4546
[{"id": ele.get("id"), "name": ele.get("ds").name,
4647
"cosine_similarity": ele.get("cosine_similarity")}
4748
for ele in _list]))
48-
ds = _list[0].get('ds')
49-
return {"id": ds.id, "name": ds.name, "description": ds.description}
49+
ds_l = _list[:settings.DS_EMBEDDING_COUNT]
50+
return [{"id": obj.get('ds').id, "name": obj.get('ds').name, "description": obj.get('ds').description}
51+
for obj in ds_l]
5052
except Exception:
5153
traceback.print_exc()
5254
else:
@@ -81,8 +83,9 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
8183
[{"id": ele.get("id"), "name": ele.get("ds").name,
8284
"cosine_similarity": ele.get("cosine_similarity")}
8385
for ele in _list]))
84-
ds = _list[0].get('ds')
85-
return {"id": ds.id, "name": ds.name, "description": ds.description}
86+
ds_l = _list[:settings.DS_EMBEDDING_COUNT]
87+
return [{"id": obj.get('ds').id, "name": obj.get('ds').name, "description": obj.get('ds').description}
88+
for obj in ds_l]
8689
except Exception:
8790
traceback.print_exc()
8891
return _list

backend/common/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn | str:
109109

110110
TABLE_EMBEDDING_ENABLED: bool = True
111111
TABLE_EMBEDDING_COUNT: int = 10
112+
DS_EMBEDDING_COUNT: int = 10
112113

113114

114115
settings = Settings() # type: ignore

0 commit comments

Comments
 (0)