Commit be9ed60

Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents: db34b7f + 996927b

File tree

5 files changed: +291 −219


backend/apps/chat/api/chat.py

Lines changed: 17 additions & 10 deletions
@@ -6,7 +6,7 @@
 from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
     delete_chat, get_chat_chart_data, get_chat_predict_data
 from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
-from apps.chat.task.llm import LLMService, run_task, run_analysis_or_predict_task, run_recommend_questions_task
+from apps.chat.task.llm import LLMService
 from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
 
 router = APIRouter(tags=["Data Q&A"], prefix="/chat")
@@ -20,7 +20,8 @@ async def chats(session: SessionDep, current_user: CurrentUser):
 @router.get("/get/{chart_id}")
 async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant):
     try:
-        return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user, current_assistant=current_assistant)
+        return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user,
+                                     current_assistant=current_assistant)
     except Exception as e:
         raise HTTPException(
             status_code=500,
@@ -81,7 +82,8 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
             status_code=500,
             detail=str(e)
         )
-
+
+
 @router.post("/assistant/start")
 async def start_chat(session: SessionDep, current_user: CurrentUser):
     try:
@@ -94,7 +96,8 @@ async def start_chat(session: SessionDep, current_user: CurrentUser):
 
 
 @router.post("/recommend_questions/{chat_record_id}")
-async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, current_assistant: CurrentAssistant):
+async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int,
+                              current_assistant: CurrentAssistant):
     try:
         record = session.get(ChatRecord, chat_record_id)
         if not record:
@@ -106,18 +109,20 @@ async def recommend_questions(session: SessionDep, current_user: CurrentUser, ch
 
         llm_service = LLMService(session, current_user, request_question, current_assistant)
         llm_service.set_record(record)
+        llm_service.run_recommend_questions_task_async()
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(
             status_code=500,
             detail=str(e)
         )
 
-    return StreamingResponse(run_recommend_questions_task(llm_service), media_type="text/event-stream")
+    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
 
 
 @router.post("/question")
-async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, current_assistant: CurrentAssistant):
+async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
+                     current_assistant: CurrentAssistant):
     """Stream SQL analysis results
 
     Args:
@@ -132,18 +137,20 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
     try:
         llm_service = LLMService(session, current_user, request_question, current_assistant)
         llm_service.init_record()
+        llm_service.run_task_async()
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(
             status_code=500,
             detail=str(e)
         )
 
-    return StreamingResponse(run_task(llm_service), media_type="text/event-stream")
+    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
 
 
 @router.post("/record/{chat_record_id}/{action_type}")
-async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str, current_assistant: CurrentAssistant):
+async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, chat_record_id: int, action_type: str,
+                              current_assistant: CurrentAssistant):
     if action_type != 'analysis' and action_type != 'predict':
         raise HTTPException(
             status_code=404,
@@ -167,12 +174,12 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
 
     try:
         llm_service = LLMService(session, current_user, request_question, current_assistant)
+        llm_service.run_analysis_or_predict_task_async(action_type, record)
    except Exception as e:
         traceback.print_exc()
         raise HTTPException(
             status_code=500,
             detail=str(e)
         )
 
-    return StreamingResponse(run_analysis_or_predict_task(llm_service, action_type, record),
-                             media_type="text/event-stream")
+    return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
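
The diff replaces the module-level runners (run_task, run_analysis_or_predict_task, run_recommend_questions_task) with methods on LLMService: each endpoint first calls a *_async method to start the work, then streams llm_service.await_result() back via StreamingResponse. The sketch below shows one way such a pair can be wired together under those assumptions; the class name DemoLLMService, the queue/thread internals, and the /demo/question route are hypothetical illustrations, not the actual implementation in apps/chat/task/llm.py.

```python
# Minimal sketch of the "start task, then stream await_result()" pattern,
# assuming a background worker feeding a queue. Not the real LLMService.
import asyncio
import json
import threading
from queue import Queue

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


class DemoLLMService:
    """Hypothetical stand-in: one background task, one SSE result stream."""

    _DONE = object()  # sentinel marking the end of the stream

    def __init__(self, question: str):
        self.question = question
        self._queue: Queue = Queue()

    def run_task_async(self) -> None:
        # Start the (blocking) work on a worker thread so the request handler
        # can return a streaming response immediately.
        threading.Thread(target=self._run_task, daemon=True).start()

    def _run_task(self) -> None:
        # Placeholder for the real SQL-generation / analysis logic.
        for token in ["SELECT", " *", " FROM", " demo_table"]:
            self._queue.put(json.dumps({"content": token}))
        self._queue.put(self._DONE)

    async def await_result(self):
        # Async generator suitable for StreamingResponse: drain the queue
        # without blocking the event loop.
        loop = asyncio.get_running_loop()
        while True:
            chunk = await loop.run_in_executor(None, self._queue.get)
            if chunk is self._DONE:
                break
            yield f"data: {chunk}\n\n"  # SSE framing for text/event-stream


@app.post("/demo/question")
async def demo_question(question: str):
    service = DemoLLMService(question)
    service.run_task_async()
    return StreamingResponse(service.await_result(), media_type="text/event-stream")
```

Starting the producer before returning lets the endpoint send the text/event-stream response right away while tokens are still being generated, which matches how the updated handlers call run_task_async() inside the try block and hand await_result() to StreamingResponse afterwards.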
