88from fastapi import APIRouter , HTTPException , Path
99from fastapi .responses import StreamingResponse
1010from sqlalchemy import and_ , select
11+ from starlette .responses import JSONResponse
1112
1213from apps .chat .curd .chat import list_chats , get_chat_with_records , create_chat , rename_chat , \
1314 delete_chat , get_chat_chart_data , get_chat_predict_data , get_chat_with_records_with_data , get_chat_record_by_id , \
1415 format_json_data , format_json_list_data , get_chart_config , list_recent_questions
1516from apps .chat .models .chat_model import CreateChat , ChatRecord , RenameChat , ChatQuestion , AxisObj , QuickCommand , \
16- ChatInfo , Chat
17+ ChatInfo , Chat , ChatFinishStep
1718from apps .chat .task .llm import LLMService
1819from apps .swagger .i18n import PLACEHOLDER_PREFIX
1920from apps .system .schemas .permission import SqlbotPermission , require_permissions
@@ -166,11 +167,18 @@ def find_base_question(record_id: int, session: SessionDep):
@require_permissions(permission=SqlbotPermission(type='chat', keyExpression="request_question.chat_id"))
async def question_answer(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
                          current_assistant: CurrentAssistant):
    """Chat-UI entry point for answering a question.

    Delegates to ``question_answer_inner`` with the chat-interface defaults
    (in-chat mode, streaming SSE response) and embedding lookup enabled.
    """
    response = await question_answer_inner(session, current_user, request_question, current_assistant,
                                           embedding=True)
    return response
171+
172+
173+ async def question_answer_inner (session : SessionDep , current_user : CurrentUser , request_question : ChatQuestion ,
174+ current_assistant : Optional [CurrentAssistant ] = None , in_chat : bool = True ,
175+ stream : bool = True ,
176+ finish_step : ChatFinishStep = ChatFinishStep .GENERATE_CHART , embedding : bool = False ):
169177 try :
170178 command , text_before_command , record_id , warning_info = parse_quick_command (request_question .question )
171179 if command :
172- # todo 暂不支持分析和预测,需要改造前端
173- if command == QuickCommand .ANALYSIS or command == QuickCommand .PREDICT_DATA :
180+ # todo 对话界面下, 暂不支持分析和预测,需要改造前端
181+ if in_chat and ( command == QuickCommand .ANALYSIS or command == QuickCommand .PREDICT_DATA ) :
174182 raise Exception (f'Command: { command .value } temporary not supported' )
175183
176184 if record_id is not None :
@@ -221,53 +229,83 @@ async def question_answer(session: SessionDep, current_user: CurrentUser, reques
221229 if command == QuickCommand .REGENERATE :
222230 request_question .question = text_before_command
223231 request_question .regenerate_record_id = rec_id
224- return await stream_sql (session , current_user , request_question , current_assistant )
232+ return await stream_sql (session , current_user , request_question , current_assistant , in_chat , stream ,
233+ finish_step , embedding )
225234
226235 elif command == QuickCommand .ANALYSIS :
227- return await analysis_or_predict (session , current_user , rec_id , 'analysis' , current_assistant )
236+ return await analysis_or_predict (session , current_user , rec_id , 'analysis' , current_assistant , in_chat , stream )
228237
229238 elif command == QuickCommand .PREDICT_DATA :
230- return await analysis_or_predict (session , current_user , rec_id , 'predict' , current_assistant )
239+ return await analysis_or_predict (session , current_user , rec_id , 'predict' , current_assistant , in_chat , stream )
231240 else :
232241 raise Exception (f'Unknown command: { command .value } ' )
233242 else :
234- return await stream_sql (session , current_user , request_question , current_assistant )
243+ return await stream_sql (session , current_user , request_question , current_assistant , in_chat , stream ,
244+ finish_step , embedding )
235245 except Exception as e :
236246 traceback .print_exc ()
237247
238- def _err (_e : Exception ):
239- yield 'data:' + orjson .dumps ({'content' : str (_e ), 'type' : 'error' }).decode () + '\n \n '
248+ if stream :
249+ def _err (_e : Exception ):
250+ yield 'data:' + orjson .dumps ({'content' : str (_e ), 'type' : 'error' }).decode () + '\n \n '
240251
241- return StreamingResponse (_err (e ), media_type = "text/event-stream" )
252+ return StreamingResponse (_err (e ), media_type = "text/event-stream" )
253+ else :
254+ return JSONResponse (
255+ content = {'message' : str (e )},
256+ status_code = 500 ,
257+ )
242258
243259
async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion,
                     current_assistant: Optional[CurrentAssistant] = None, in_chat: bool = True,
                     stream: bool = True,
                     finish_step: ChatFinishStep = ChatFinishStep.GENERATE_CHART, embedding: bool = False):
    """Create an LLM service for the question, start the generation task and
    return its result, either as a server-sent-event stream or as one JSON body.

    :param session: database session dependency.
    :param current_user: authenticated user issuing the question.
    :param request_question: chat question payload (chat id + question text).
    :param current_assistant: optional assistant context forwarded to the LLM service.
    :param in_chat: whether the call originates from the chat UI; forwarded to the task.
    :param stream: True -> ``text/event-stream`` response; False -> buffered ``JSONResponse``.
    :param finish_step: last pipeline step the task should run before finishing.
    :param embedding: whether to enable embedding lookup when creating the service.
    """
    try:
        llm_service = await LLMService.create(session, current_user, request_question, current_assistant,
                                              embedding=embedding)
        llm_service.init_record(session=session)
        llm_service.run_task_async(in_chat=in_chat, stream=stream, finish_step=finish_step)
    except Exception as e:
        traceback.print_exc()
        # Report the failure in whichever transport the caller asked for.
        return _failure_response(e, stream)

    if stream:
        return StreamingResponse(llm_service.await_result(), media_type="text/event-stream")
    return _drain_to_json(llm_service.await_result())


def _failure_response(e: Exception, stream: bool):
    """Build an error response matching the requested transport.

    Streaming callers get a one-event SSE stream carrying an ``error`` payload;
    non-streaming callers get a plain JSON body with HTTP 500.
    """
    if stream:
        def _err(_e: Exception):
            yield 'data:' + orjson.dumps({'content': str(_e), 'type': 'error'}).decode() + '\n\n'

        return StreamingResponse(_err(e), media_type="text/event-stream")
    return JSONResponse(
        content={'message': str(e)},
        status_code=500,
    )


def _drain_to_json(result_iter):
    """Consume the task result iterator and wrap the final non-empty chunk in a JSONResponse.

    NOTE(review): assumes that in non-stream mode the iterator yields dict
    chunks whose last complete payload carries a ``success`` key -- confirm
    against ``LLMService.await_result``. An empty/unsuccessful payload maps
    to HTTP 500.
    """
    raw_data = {}
    for chunk in result_iter:
        if chunk:
            # Later chunks supersede earlier ones; only the final payload is returned.
            raw_data = chunk
    status_code = 200 if raw_data.get('success') else 500
    return JSONResponse(
        content=raw_data,
        status_code=status_code,
    )
260297
261298
@router.post("/record/{chat_record_id}/{action_type}", summary=f"{PLACEHOLDER_PREFIX}analysis_or_predict")
async def analysis_or_predict_question(session: SessionDep, current_user: CurrentUser,
                                       current_assistant: CurrentAssistant, chat_record_id: int,
                                       action_type: str = Path(
                                           ...,
                                           description=f"{PLACEHOLDER_PREFIX}analysis_or_predict_action_type")):
    """Route entry point: run the analysis or prediction task for an existing chat record.

    ``action_type`` is taken from the URL path and passed through unchanged to
    ``analysis_or_predict`` together with the current user/assistant context.
    """
    result = await analysis_or_predict(session, current_user, chat_record_id, action_type, current_assistant)
    return result
267305
268306
269307async def analysis_or_predict (session : SessionDep , current_user : CurrentUser , chat_record_id : int , action_type : str ,
270- current_assistant : CurrentAssistant ):
308+ current_assistant : CurrentAssistant , in_chat : bool = True , stream : bool = True ):
271309 try :
272310 if action_type != 'analysis' and action_type != 'predict' :
273311 raise Exception (f"Type { action_type } Not Found" )
@@ -294,16 +332,35 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
294332 request_question = ChatQuestion (chat_id = record .chat_id , question = record .question )
295333
296334 llm_service = await LLMService .create (session , current_user , request_question , current_assistant )
297- llm_service .run_analysis_or_predict_task_async (session , action_type , record )
335+ llm_service .run_analysis_or_predict_task_async (session , action_type , record , in_chat , stream )
298336 except Exception as e :
299337 traceback .print_exc ()
338+ if stream :
339+ def _err (_e : Exception ):
340+ yield 'data:' + orjson .dumps ({'content' : str (_e ), 'type' : 'error' }).decode () + '\n \n '
300341
301- def _err (_e : Exception ):
302- yield 'data:' + orjson .dumps ({'content' : str (_e ), 'type' : 'error' }).decode () + '\n \n '
303-
304- return StreamingResponse (_err (e ), media_type = "text/event-stream" )
305-
306- return StreamingResponse (llm_service .await_result (), media_type = "text/event-stream" )
342+ return StreamingResponse (_err (e ), media_type = "text/event-stream" )
343+ else :
344+ return JSONResponse (
345+ content = {'message' : str (e )},
346+ status_code = 500 ,
347+ )
348+ if stream :
349+ return StreamingResponse (llm_service .await_result (), media_type = "text/event-stream" )
350+ else :
351+ res = llm_service .await_result ()
352+ raw_data = {}
353+ for chunk in res :
354+ if chunk :
355+ raw_data = chunk
356+ status_code = 200
357+ if not raw_data .get ('success' ):
358+ status_code = 500
359+
360+ return JSONResponse (
361+ content = raw_data ,
362+ status_code = status_code ,
363+ )
307364
308365
309366@router .get ("/record/{chat_record_id}/excel/export" , summary = f"{ PLACEHOLDER_PREFIX } export_chart_data" )
0 commit comments