11import traceback
2- from typing import List
32
43import orjson
54from fastapi import APIRouter , HTTPException
65from fastapi .responses import StreamingResponse
7- from sqlmodel import select
86
97from apps .chat .curd .chat import list_chats , get_chat_with_records , create_chat , rename_chat , \
10- delete_chat , list_records
11- from apps .chat .models .chat_model import CreateChat , ChatRecord , RenameChat , Chat , ChatQuestion , ChatMcp
12- from apps .chat .task .llm import LLMService
13- from apps .datasource .crud .datasource import get_table_schema
14- from apps .datasource .models .datasource import CoreDatasource
15- from apps .system .crud .user import get_user_info
16- from apps .system .models .system_model import AiModelDetail
17- from common .core .deps import SessionDep , CurrentUser , get_current_user
8+ delete_chat
9+ from apps .chat .models .chat_model import CreateChat , ChatRecord , RenameChat , ChatQuestion
10+ from apps .chat .task .llm import LLMService , run_task
11+ from common .core .deps import SessionDep , CurrentUser
1812
1913router = APIRouter (tags = ["Data Q&A" ], prefix = "/chat" )
2014
@@ -57,18 +51,6 @@ async def delete(session: SessionDep, chart_id: int):
5751 )
5852
5953
60- @router .post ("/mcp_start" , operation_id = "mcp_start" )
61- async def mcp_start (session : SessionDep , chat : ChatMcp ):
62- user = await get_current_user (session , chat .token )
63- return create_chat (session , user , CreateChat (), False )
64-
65-
66- @router .post ("/mcp_question" , operation_id = "mcp_question" )
67- async def mcp_question (session : SessionDep , chat : ChatMcp ):
68- user = await get_current_user (session , chat .token )
69- return await stream_sql (session , user , chat )
70-
71-
7254@router .post ("/start" )
7355async def start_chat (session : SessionDep , current_user : CurrentUser , create_chat_obj : CreateChat ):
7456 try :
@@ -80,7 +62,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
8062 )
8163
8264
83- @router .post ("/question" , operation_id = "question" )
65+ @router .post ("/question" )
8466async def stream_sql (session : SessionDep , current_user : CurrentUser , request_question : ChatQuestion ):
8567 """Stream SQL analysis results
8668
@@ -93,107 +75,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
9375 Streaming response with analysis results
9476 """
9577
96- chat = session .query (Chat ).filter (Chat .id == request_question .chat_id ).first ()
97- if not chat :
98- raise HTTPException (
99- status_code = 400 ,
100- detail = f"Chat with id { request_question .chat_id } not found"
101- )
102- ds : CoreDatasource | None = None
103- if chat .datasource :
104- # Get available datasource
105- ds = session .query (CoreDatasource ).filter (CoreDatasource .id == chat .datasource ).first ()
106- if not ds :
107- raise HTTPException (
108- status_code = 500 ,
109- detail = "No available datasource configuration found"
110- )
111-
112- request_question .engine = ds .type_name if ds .type != 'excel' else 'PostgreSQL'
113-
114- # Get available AI model
115- aimodel = session .exec (select (AiModelDetail ).where (
116- AiModelDetail .status == True ,
117- AiModelDetail .api_key .is_not (None )
118- )).first ()
119- if not aimodel :
78+ try :
79+ llm_service = LLMService (session , current_user , request_question )
80+ llm_service .init_record ()
81+ except Exception as e :
82+ traceback .print_exc ()
12083 raise HTTPException (
12184 status_code = 500 ,
122- detail = "No available AI model configuration found"
85+ detail = str ( e )
12386 )
12487
125- history_records : List [ChatRecord ] = list (filter (lambda r : True if r .first_chat != True else False ,
126- list_records (session = session , current_user = current_user ,
127- chart_id = request_question .chat_id )))
128- # get schema
129- if ds :
130- request_question .db_schema = get_table_schema (session = session , ds = ds )
131-
132- db_user = get_user_info (session = session , user_id = current_user .id )
133- request_question .lang = db_user .language
134-
135- llm_service = LLMService (request_question , aimodel , history_records ,
136- CoreDatasource (** ds .model_dump ()) if ds else None )
137-
138- llm_service .init_record (session = session , current_user = current_user )
139-
140- def run_task ():
141- try :
142- # return id
143- yield orjson .dumps ({'type' : 'id' , 'id' : llm_service .get_record ().id }).decode () + '\n \n '
144-
145- # select datasource if datasource is none
146- if not ds :
147- ds_res = llm_service .select_datasource (session = session )
148- for chunk in ds_res :
149- yield orjson .dumps ({'content' : chunk , 'type' : 'datasource-result' }).decode () + '\n \n '
150- yield orjson .dumps ({'id' : llm_service .ds .id , 'datasource_name' : llm_service .ds .name ,
151- 'engine_type' : llm_service .ds .type_name , 'type' : 'datasource' }).decode () + '\n \n '
152-
153- llm_service .chat_question .db_schema = get_table_schema (session = session , ds = llm_service .ds )
154-
155- # generate sql
156- sql_res = llm_service .generate_sql (session = session )
157- full_sql_text = ''
158- for chunk in sql_res :
159- full_sql_text += chunk
160- yield orjson .dumps ({'content' : chunk , 'type' : 'sql-result' }).decode () + '\n \n '
161- yield orjson .dumps ({'type' : 'info' , 'msg' : 'sql generated' }).decode () + '\n \n '
162-
163- # filter sql
164- print (full_sql_text )
165- sql = llm_service .check_save_sql (session = session , res = full_sql_text )
166- print (sql )
167- yield orjson .dumps ({'content' : sql , 'type' : 'sql' }).decode () + '\n \n '
168-
169- # execute sql
170- result = llm_service .execute_sql (sql = sql )
171- llm_service .save_sql_data (session = session , data_obj = result )
172- yield orjson .dumps ({'content' : orjson .dumps (result ).decode (), 'type' : 'sql-data' }).decode () + '\n \n '
173-
174- # generate chart
175- chart_res = llm_service .generate_chart (session = session )
176- full_chart_text = ''
177- for chunk in chart_res :
178- full_chart_text += chunk
179- yield orjson .dumps ({'content' : chunk , 'type' : 'chart-result' }).decode () + '\n \n '
180- yield orjson .dumps ({'type' : 'info' , 'msg' : 'chart generated' }).decode () + '\n \n '
181-
182- # filter chart
183- print (full_chart_text )
184- chart = llm_service .check_save_chart (session = session , res = full_chart_text )
185- print (chart )
186- yield orjson .dumps ({'content' : orjson .dumps (chart ).decode (), 'type' : 'chart' }).decode () + '\n \n '
187-
188- llm_service .finish (session = session )
189- yield orjson .dumps ({'type' : 'finish' }).decode () + '\n \n '
190-
191- except Exception as e :
192- traceback .print_exc ()
193- llm_service .save_error (session = session , message = str (e ))
194- yield orjson .dumps ({'content' : str (e ), 'type' : 'error' }).decode () + '\n \n '
195-
196- return StreamingResponse (run_task (), media_type = "text/event-stream" )
88+ return StreamingResponse (run_task (llm_service , session ), media_type = "text/event-stream" )
19789
19890
19991@router .post ("/record/{chart_record_id}/{action_type}" )
@@ -217,35 +109,9 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
217109 detail = f"Chat record with id { chart_record_id } has not generated chart, do not support to analyze it"
218110 )
219111
220- chat = session .query (Chat ).filter (Chat .id == record .chat_id ).first ()
221- if not chat :
222- raise HTTPException (
223- status_code = 400 ,
224- detail = f"Chat with id { record .chart_id } not found"
225- )
226-
227- if chat .create_by != current_user .id :
228- raise HTTPException (
229- status_code = 401 ,
230- detail = f"You cannot use the chat with id { record .chart_id } "
231- )
232-
233- # Get available AI model
234- aimodel = session .exec (select (AiModelDetail ).where (
235- AiModelDetail .status == True ,
236- AiModelDetail .api_key .is_not (None )
237- )).first ()
238- if not aimodel :
239- raise HTTPException (
240- status_code = 500 ,
241- detail = "No available AI model configuration found"
242- )
243-
244- request_question = ChatQuestion (chat_id = chat .id , question = '' )
245- db_user = get_user_info (session = session , user_id = current_user .id )
246- request_question .lang = db_user .language
112+ request_question = ChatQuestion (chat_id = record .chat_id , question = '' )
247113
248- llm_service = LLMService (request_question , aimodel )
114+ llm_service = LLMService (session , current_user , request_question )
249115 llm_service .set_record (record )
250116
251117 def run_task ():
@@ -261,14 +127,14 @@ def run_task():
261127
262128 elif action_type == 'predict' :
263129 # generate predict
264- analysis_res = llm_service .generate_predict (session = session )
130+ analysis_res = llm_service .generate_predict ()
265131 full_text = ''
266132 for chunk in analysis_res :
267133 yield orjson .dumps ({'content' : chunk , 'type' : 'predict-result' }).decode () + '\n \n '
268134 full_text += chunk
269135 yield orjson .dumps ({'type' : 'info' , 'msg' : 'predict generated' }).decode () + '\n \n '
270136
271- _data = llm_service .check_save_predict_data (session = session , res = full_text )
137+ _data = llm_service .check_save_predict_data (res = full_text )
272138 yield orjson .dumps ({'type' : 'predict' , 'content' : _data }).decode () + '\n \n '
273139
274140 yield orjson .dumps ({'type' : 'predict_finish' }).decode () + '\n \n '
0 commit comments