88
99from apps .chat .curd .chat import list_chats , get_chat_with_records , create_chat , rename_chat , \
1010 delete_chat , list_records
11- from apps .chat .models .chat_model import CreateChat , ChatRecord , RenameChat , Chat , ChatQuestion
11+ from apps .chat .models .chat_model import CreateChat , ChatRecord , RenameChat , Chat , ChatQuestion , ChatMcp
1212from apps .chat .task .llm import LLMService
1313from apps .datasource .crud .datasource import get_table_schema
1414from apps .datasource .models .datasource import CoreDatasource
@@ -57,6 +57,18 @@ async def delete(session: SessionDep, chart_id: int):
5757 )
5858
5959
60+ @router .post ("/mcp_start" , operation_id = "mcp_start" )
61+ async def mcp_start (session : SessionDep , chat : ChatMcp ):
62+ user = await get_current_user (session , chat .token )
63+ return create_chat (session , user , CreateChat (), False )
64+
65+
66+ @router .post ("/mcp_question" , operation_id = "mcp_question" )
67+ async def mcp_question (session : SessionDep , chat : ChatMcp ):
68+ user = await get_current_user (session , chat .token )
69+ return await stream_sql (session , user , chat )
70+
71+
6072@router .post ("/start" )
6173async def start_chat (session : SessionDep , current_user : CurrentUser , create_chat_obj : CreateChat ):
6274 try :
@@ -68,25 +80,6 @@ async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat
6880 )
6981
7082
71- @router .post ("/mcp_question" , operation_id = "mcp_question" )
72- async def mcp_question (session : SessionDep , token : str , request_question : ChatQuestion ):
73- user = await get_current_user (session , token )
74- # return await stream_sql(session, user, request_question)
75- return {"content" :"""步骤1: 确定需要查询的字段。
76- 我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
77-
78- 步骤2: 确定筛选条件。
79- 问题要求统计上海的订单总数,所以我们需要在SQL语句中添加WHERE "城市" = '上海'来筛选出符合条件的记录。
80-
81- 步骤3: 避免关键字冲突。
82- 因为这个Excel/CSV数据库是 PostgreSQL 类型,所以在schema、表名、字段名和别名外层加双引号。
83-
84- 最终答案:
85- ```json
86- {"success":true,"sql":"SELECT COUNT(*) AS \" TotalOrders\" FROM \" public\" .\" Sheet1_c27345b66e\" WHERE \" 城市\" = '上海';"}
87- ```""" }
88-
89-
9083@router .post ("/question" , operation_id = "question" )
9184async def stream_sql (session : SessionDep , current_user : CurrentUser , request_question : ChatQuestion ):
9285 """Stream SQL analysis results
@@ -106,16 +99,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
10699 status_code = 400 ,
107100 detail = f"Chat with id { request_question .chat_id } not found"
108101 )
109-
110- # Get available datasource
111- ds = session .query (CoreDatasource ).filter (CoreDatasource .id == chat .datasource ).first ()
112- if not ds :
113- raise HTTPException (
114- status_code = 500 ,
115- detail = "No available datasource configuration found"
116- )
117-
118- request_question .engine = ds .type_name if ds .type != 'excel' else 'PostgreSQL'
102+ ds : CoreDatasource | None = None
103+ if chat .datasource :
104+ # Get available datasource
105+ ds = session .query (CoreDatasource ).filter (CoreDatasource .id == chat .datasource ).first ()
106+ if not ds :
107+ raise HTTPException (
108+ status_code = 500 ,
109+ detail = "No available datasource configuration found"
110+ )
111+
112+ request_question .engine = ds .type_name if ds .type != 'excel' else 'PostgreSQL'
119113
120114 # Get available AI model
121115 aimodel = session .exec (select (AiModelDetail ).where (
@@ -128,14 +122,18 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
128122 detail = "No available AI model configuration found"
129123 )
130124
131- history_records : List [ChatRecord ] = list_records (session = session , current_user = current_user ,
132- chart_id = request_question .chat_id )
125+ history_records : List [ChatRecord ] = list (filter (lambda r : True if r .first_chat != True else False ,
126+ list_records (session = session , current_user = current_user ,
127+ chart_id = request_question .chat_id )))
133128 # get schema
134- request_question .db_schema = get_table_schema (session = session , ds = ds )
129+ if ds :
130+ request_question .db_schema = get_table_schema (session = session , ds = ds )
131+
135132 db_user = get_user_info (session = session , user_id = current_user .id )
136133 request_question .lang = db_user .language
137134
138- llm_service = LLMService (request_question , aimodel , history_records , CoreDatasource (** ds .model_dump ()))
135+ llm_service = LLMService (request_question , aimodel , history_records ,
136+ CoreDatasource (** ds .model_dump ()) if ds else None )
139137
140138 llm_service .init_record (session = session , current_user = current_user )
141139
@@ -144,6 +142,16 @@ def run_task():
144142 # return id
145143 yield orjson .dumps ({'type' : 'id' , 'id' : llm_service .get_record ().id }).decode () + '\n \n '
146144
145+ # select datasource if datasource is none
146+ if not ds :
147+ ds_res = llm_service .select_datasource (session = session )
148+ for chunk in ds_res :
149+ yield orjson .dumps ({'content' : chunk , 'type' : 'datasource-result' }).decode () + '\n \n '
150+ yield orjson .dumps ({'id' : llm_service .ds .id , 'datasource_name' : llm_service .ds .name ,
151+ 'engine_type' : llm_service .ds .type_name , 'type' : 'datasource' }).decode () + '\n \n '
152+
153+ llm_service .chat_question .db_schema = get_table_schema (session = session , ds = llm_service .ds )
154+
147155 # generate sql
148156 sql_res = llm_service .generate_sql (session = session )
149157 full_sql_text = ''
0 commit comments