Skip to content

Commit b97fb5f

Browse files
committed
improve: improve chat record select
1 parent 8aa745f commit b97fb5f

File tree

8 files changed

+258
-125
lines changed

8 files changed

+258
-125
lines changed

backend/apps/chat/api/chat.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import pandas as pd
77
from fastapi import APIRouter, HTTPException
88
from fastapi.responses import StreamingResponse
9+
from sqlalchemy import and_, select
910

1011
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
11-
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data
12+
delete_chat, get_chat_chart_data, get_chat_predict_data, get_chat_with_records_with_data, get_chat_record_by_id
1213
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, ExcelData
1314
from apps.chat.task.llm import LLMService
1415
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
@@ -23,48 +24,37 @@ async def chats(session: SessionDep, current_user: CurrentUser):
2324

2425
@router.get("/get/{chart_id}")
2526
async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant):
26-
try:
27+
def inner():
2728
return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user,
2829
current_assistant=current_assistant)
29-
except Exception as e:
30-
raise HTTPException(
31-
status_code=500,
32-
detail=str(e)
33-
)
30+
31+
return await asyncio.to_thread(inner)
3432

3533

3634
@router.get("/get/with_data/{chart_id}")
37-
async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant):
38-
try:
35+
async def get_chat_with_data(session: SessionDep, current_user: CurrentUser, chart_id: int,
36+
current_assistant: CurrentAssistant):
37+
def inner():
3938
return get_chat_with_records_with_data(chart_id=chart_id, session=session, current_user=current_user,
4039
current_assistant=current_assistant)
41-
except Exception as e:
42-
raise HTTPException(
43-
status_code=500,
44-
detail=str(e)
45-
)
40+
41+
return await asyncio.to_thread(inner)
4642

4743

4844
@router.get("/record/get/{chart_record_id}/data")
4945
async def chat_record_data(session: SessionDep, chart_record_id: int):
50-
try:
46+
def inner():
5147
return get_chat_chart_data(chart_record_id=chart_record_id, session=session)
52-
except Exception as e:
53-
raise HTTPException(
54-
status_code=500,
55-
detail=str(e)
56-
)
48+
49+
return await asyncio.to_thread(inner)
5750

5851

5952
@router.get("/record/get/{chart_record_id}/predict_data")
6053
async def chat_predict_data(session: SessionDep, chart_record_id: int):
61-
try:
54+
def inner():
6255
return get_chat_predict_data(chart_record_id=chart_record_id, session=session)
63-
except Exception as e:
64-
raise HTTPException(
65-
status_code=500,
66-
detail=str(e)
67-
)
56+
57+
return await asyncio.to_thread(inner)
6858

6959

7060
@router.post("/rename")
@@ -115,15 +105,16 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):
115105
async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
116106
current_assistant: CurrentAssistant):
117107
try:
118-
record = session.get(ChatRecord, chat_record_id)
108+
record = get_chat_record_by_id(session, chat_record_id)
109+
119110
if not record:
120111
raise HTTPException(
121112
status_code=400,
122113
detail=f"Chat record with id {chat_record_id} not found"
123114
)
124115
request_question = ChatQuestion(chat_id=record.chat_id, question=record.question if record.question else '')
125116

126-
llm_service = LLMService(current_user, request_question, current_assistant)
117+
llm_service = LLMService(current_user, request_question, current_assistant, True)
127118
llm_service.set_record(record)
128119
llm_service.run_recommend_questions_task_async()
129120
except Exception as e:
@@ -172,8 +163,17 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
172163
status_code=404,
173164
detail="Not Found"
174165
)
166+
record: ChatRecord | None = None
167+
168+
stmt = select(ChatRecord.id, ChatRecord.question, ChatRecord.chat_id, ChatRecord.datasource, ChatRecord.engine_type,
169+
ChatRecord.ai_modal_id, ChatRecord.create_by, ChatRecord.chart, ChatRecord.data).where(
170+
and_(ChatRecord.id == chat_record_id))
171+
result = session.execute(stmt)
172+
for r in result:
173+
record = ChatRecord(id=r.id, question=r.question, chat_id=r.chat_id, datasource=r.datasource,
174+
engine_type=r.engine_type, ai_modal_id=r.ai_modal_id, create_by=r.create_by, chart=r.chart,
175+
data=r.data)
175176

176-
record = session.query(ChatRecord).get(chat_record_id)
177177
if not record:
178178
raise HTTPException(
179179
status_code=400,

0 commit comments

Comments
 (0)