Skip to content

Commit fbcb147

Browse files
perf: Simple assistant
1 parent 7137a00 commit fbcb147

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

backend/apps/chat/api/chat.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
delete_chat, get_chat_chart_data, get_chat_predict_data
88
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
99
from apps.chat.task.llm import LLMService, run_task, run_analysis_or_predict_task, run_recommend_questions_task
10-
from common.core.deps import SessionDep, CurrentUser
10+
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
1111

1212
router = APIRouter(tags=["Data Q&A"], prefix="/chat")
1313

@@ -94,7 +94,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):
9494

9595

9696
@router.post("/recommend_questions/{chat_record_id}")
97-
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int):
97+
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, current_assistant: CurrentAssistant):
9898
try:
9999
record = session.query(ChatRecord).get(chat_record_id)
100100
if not record:
@@ -104,7 +104,7 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch
104104
)
105105
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '')
106106

107-
llm_service = LLMService(session, current_user, request_question)
107+
llm_service = LLMService(session, current_user, request_question, current_assistant)
108108
llm_service.set_record(record)
109109
except Exception as e:
110110
traceback.print_exc()
@@ -117,7 +117,7 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch
117117

118118

119119
@router.post("/question")
120-
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion):
120+
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, current_assistant: CurrentAssistant):
121121
"""Stream SQL analysis results
122122
123123
Args:
@@ -130,7 +130,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
130130
"""
131131

132132
try:
133-
llm_service = LLMService(session, current_user, request_question)
133+
llm_service = LLMService(session, current_user, request_question, current_assistant)
134134
llm_service.init_record()
135135
except Exception as e:
136136
traceback.print_exc()
@@ -143,7 +143,7 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
143143

144144

145145
@router.post("/record/{chat_record_id}/{action_type}")
146-
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str):
146+
async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str, current_assistant: CurrentAssistant):
147147
if action_type != 'analysis' and action_type != 'predict':
148148
raise HTTPException(
149149
status_code=404,
@@ -166,7 +166,7 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
166166
request_question = ChatQuestion(chat_id=record.chat_id, question='')
167167

168168
try:
169-
llm_service = LLMService(session, current_user, request_question)
169+
llm_service = LLMService(session, current_user, request_question, current_assistant)
170170
except Exception as e:
171171
traceback.print_exc()
172172
raise HTTPException(

backend/apps/chat/task/llm.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,11 +353,21 @@ def select_datasource(self):
353353
if self.current_assistant:
354354
_ds_list = get_assistant_ds(llm_service=self)
355355
else:
356-
_ds_list = self.session.exec(select(CoreDatasource).options(
357-
load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all()
356+
oid: str = self.current_user.oid
357+
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(CoreDatasource.oid == oid)
358+
_ds_list = [
359+
{
360+
"id": ds.id,
361+
"name": ds.name,
362+
"description": ds.description
363+
}
364+
for ds in self.session.exec(stmt)
365+
]
366+
""" _ds_list = self.session.exec(select(CoreDatasource).options(
367+
load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all() """
358368
_ds_list_dict = []
359369
for _ds in _ds_list:
360-
_ds_list_dict.append({'id': _ds[0].id, 'name': _ds[0].name, 'description': _ds[0].description})
370+
_ds_list_dict.append(_ds)
361371
datasource_msg.append(
362372
HumanMessage(self.chat_question.datasource_user_question(orjson.dumps(_ds_list_dict).decode())))
363373

backend/apps/system/crud/assistant.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,22 @@ def get_assistant_ds(llm_service) -> list[dict]:
3333
config: dict[any] = json.loads(configuration)
3434
oid: str = config['oid']
3535
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(CoreDatasource.oid == oid)
36-
private_list:list[int] = config['private_list']
36+
private_list:list[int] = config.get('private_list') or None
3737
if private_list:
3838
stmt.where(~CoreDatasource.id.in_(private_list))
39-
db_ds_list = session.exec(stmt).all()
39+
db_ds_list = session.exec(stmt)
40+
41+
result_list = [
42+
{
43+
"id": ds.id,
44+
"name": ds.name,
45+
"description": ds.description
46+
}
47+
for ds in db_ds_list
48+
]
49+
4050
# filter private ds if offline
41-
return db_ds_list
51+
return result_list
4252
out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant, llm_service.assistant_certificate)
4353
llm_service.out_ds_instance = out_ds_instance
4454
dslist = out_ds_instance.get_simple_ds_list()

0 commit comments

Comments
 (0)