@@ -1,20 +1,14 @@
 import traceback
-from typing import List
 
 import orjson
 from fastapi import APIRouter, HTTPException
 from fastapi.responses import StreamingResponse
-from sqlmodel import select
 
 from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, rename_chat, \
-    delete_chat, list_records
-from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat, ChatQuestion, ChatMcp
-from apps.chat.task.llm import LLMService
-from apps.datasource.crud.datasource import get_table_schema
-from apps.datasource.models.datasource import CoreDatasource
-from apps.system.crud.user import get_user_info
-from apps.system.models.system_model import AiModelDetail
-from common.core.deps import SessionDep, CurrentUser, get_current_user
+    delete_chat
+from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion
+from apps.chat.task.llm import LLMService, run_task
+from common.core.deps import SessionDep, CurrentUser
 
 router = APIRouter(tags=["Data Q&A"], prefix="/chat")
 
@@ -57,34 +51,6 @@ async def delete(session: SessionDep, chart_id: int):
         )
 
 
-@router.post("/mcp_start", operation_id="mcp_start")
-async def mcp_start(session: SessionDep, chat: ChatMcp):
-    user = await get_current_user(session, chat.token)
-    return create_chat(session, user, CreateChat(), False)
-
-
-@router.post("/mcp_question", operation_id="mcp_question")
-async def mcp_question(session: SessionDep, chat: ChatMcp):
-    user = await get_current_user(session, chat.token)
-    # return await stream_sql(session, user, chat)
-    return {"content": """这是一段写死的测试内容:
-
-步骤1: 确定需要查询的字段。
-我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
-
-步骤2: 确定筛选条件。
-问题要求统计上海的订单总数,所以我们需要在SQL语句中添加WHERE "城市" = '上海'来筛选出符合条件的记录。
-
-步骤3: 避免关键字冲突。
-因为这个Excel/CSV数据库是 PostgreSQL 类型,所以在schema、表名、字段名和别名外层加双引号。
-
-最终答案:
-```json
-{"success":true,"sql":"SELECT COUNT(*) AS \"TotalOrders\" FROM \"public\".\"Sheet1_c27345b66e\" WHERE \"城市\" = '上海';"}
-```
-<img src="https://sqlbot.fit2cloud.cn/images/111.png">"""}
-
-
 @router.post("/start")
 async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat):
     try:
@@ -96,7 +62,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
         )
 
 
-@router.post("/question", operation_id="question")
+@router.post("/question")
 async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion):
     """Stream SQL analysis results
 
@@ -109,107 +75,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
     Streaming response with analysis results
     """
 
-    chat = session.query(Chat).filter(Chat.id == request_question.chat_id).first()
-    if not chat:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Chat with id {request_question.chat_id} not found"
-        )
-    ds: CoreDatasource | None = None
-    if chat.datasource:
-        # Get available datasource
-        ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
-        if not ds:
-            raise HTTPException(
-                status_code=500,
-                detail="No available datasource configuration found"
-            )
-
-        request_question.engine = ds.type_name if ds.type != 'excel' else 'PostgreSQL'
-
-    # Get available AI model
-    aimodel = session.exec(select(AiModelDetail).where(
-        AiModelDetail.status == True,
-        AiModelDetail.api_key.is_not(None)
-    )).first()
-    if not aimodel:
+    try:
+        llm_service = LLMService(session, current_user, request_question)
+        llm_service.init_record()
+    except Exception as e:
+        traceback.print_exc()
         raise HTTPException(
             status_code=500,
-            detail="No available AI model configuration found"
+            detail=str(e)
         )
 
-    history_records: List[ChatRecord] = list(filter(lambda r: True if r.first_chat != True else False,
-                                                    list_records(session=session, current_user=current_user,
-                                                                 chart_id=request_question.chat_id)))
-    # get schema
-    if ds:
-        request_question.db_schema = get_table_schema(session=session, ds=ds)
-
-    db_user = get_user_info(session=session, user_id=current_user.id)
-    request_question.lang = db_user.language
-
-    llm_service = LLMService(request_question, aimodel,
-                             history_records,
-
-    llm_service.init_record(session=session, current_user=current_user)
-
-    def run_task():
-        try:
-            # return id
-            yield orjson.dumps({'type': 'id', 'id': llm_service.get_record().id}).decode() + '\n\n'
-
-            # select datasource if datasource is none
-            if not ds:
-                ds_res = llm_service.select_datasource(session=session)
-                for chunk in ds_res:
-                    yield orjson.dumps({'content': chunk, 'type': 'datasource-result'}).decode() + '\n\n'
-                yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
-                                    'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
-
-                llm_service.chat_question.db_schema = get_table_schema(session=session, ds=llm_service.ds)
-
-            # generate sql
-            sql_res = llm_service.generate_sql(session=session)
-            full_sql_text = ''
-            for chunk in sql_res:
-                full_sql_text += chunk
-                yield orjson.dumps({'content': chunk, 'type': 'sql-result'}).decode() + '\n\n'
-            yield orjson.dumps({'type': 'info', 'msg': 'sql generated'}).decode() + '\n\n'
-
-            # filter sql
-            print(full_sql_text)
-            sql = llm_service.check_save_sql(session=session, res=full_sql_text)
-            print(sql)
-            yield orjson.dumps({'content': sql, 'type': 'sql'}).decode() + '\n\n'
-
-            # execute sql
-            result = llm_service.execute_sql(sql=sql)
-            llm_service.save_sql_data(session=session, data_obj=result)
-            yield orjson.dumps({'content': orjson.dumps(result).decode(), 'type': 'sql-data'}).decode() + '\n\n'
-
-            # generate chart
-            chart_res = llm_service.generate_chart(session=session)
-            full_chart_text = ''
-            for chunk in chart_res:
-                full_chart_text += chunk
-                yield orjson.dumps({'content': chunk, 'type': 'chart-result'}).decode() + '\n\n'
-            yield orjson.dumps({'type': 'info', 'msg': 'chart generated'}).decode() + '\n\n'
-
-            # filter chart
-            print(full_chart_text)
-            chart = llm_service.check_save_chart(session=session, res=full_chart_text)
-            print(chart)
-            yield orjson.dumps({'content': orjson.dumps(chart).decode(), 'type': 'chart'}).decode() + '\n\n'
-
-            llm_service.finish(session=session)
-            yield orjson.dumps({'type': 'finish'}).decode() + '\n\n'
-
-        except Exception as e:
-            traceback.print_exc()
-            llm_service.save_error(session=session, message=str(e))
-            yield orjson.dumps({'content': str(e), 'type': 'error'}).decode() + '\n\n'
-
-    return StreamingResponse(run_task(), media_type="text/event-stream")
+    return StreamingResponse(run_task(llm_service, session), media_type="text/event-stream")
 
 
 @router.post("/record/{chart_record_id}/{action_type}")
@@ -233,35 +109,9 @@ async def analysis_or_predict(session: SessionDep, current_user: CurrentUser, ch
             detail=f"Chat record with id {chart_record_id} has not generated chart, do not support to analyze it"
         )
 
-    chat = session.query(Chat).filter(Chat.id == record.chat_id).first()
-    if not chat:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Chat with id {record.chart_id} not found"
-        )
-
-    if chat.create_by != current_user.id:
-        raise HTTPException(
-            status_code=401,
-            detail=f"You cannot use the chat with id {record.chart_id}"
-        )
-
-    # Get available AI model
-    aimodel = session.exec(select(AiModelDetail).where(
-        AiModelDetail.status == True,
-        AiModelDetail.api_key.is_not(None)
-    )).first()
-    if not aimodel:
-        raise HTTPException(
-            status_code=500,
-            detail="No available AI model configuration found"
-        )
-
-    request_question = ChatQuestion(chat_id=chat.id, question='')
-    db_user = get_user_info(session=session, user_id=current_user.id)
-    request_question.lang = db_user.language
+    request_question = ChatQuestion(chat_id=record.chat_id, question='')
 
-    llm_service = LLMService(request_question, aimodel)
+    llm_service = LLMService(session, current_user, request_question)
     llm_service.set_record(record)
 
     def run_task():
@@ -277,14 +127,14 @@ def run_task():
 
         elif action_type == 'predict':
             # generate predict
-            analysis_res = llm_service.generate_predict(session=session)
+            analysis_res = llm_service.generate_predict()
             full_text = ''
             for chunk in analysis_res:
                 yield orjson.dumps({'content': chunk, 'type': 'predict-result'}).decode() + '\n\n'
                 full_text += chunk
             yield orjson.dumps({'type': 'info', 'msg': 'predict generated'}).decode() + '\n\n'
 
-            _data = llm_service.check_save_predict_data(session=session, res=full_text)
+            _data = llm_service.check_save_predict_data(res=full_text)
            yield orjson.dumps({'type': 'predict', 'content': _data}).decode() + '\n\n'
 
             yield orjson.dumps({'type': 'predict_finish'}).decode() + '\n\n'