1+ # Author: Junjun
2+ # Date: 2025/7/1
3+
4+ from typing import Annotated
5+ from fastapi import APIRouter , Depends , HTTPException
6+ from fastapi .security import OAuth2PasswordRequestForm
7+ from common .core .deps import SessionDep , get_current_user
8+ from apps .system .crud .user import authenticate
9+ from common .core .security import create_access_token
10+ from datetime import timedelta
11+ from common .core .config import settings
12+ from common .core .schemas import Token
13+ from apps .datasource .crud .datasource import get_datasource_list
14+ from apps .system .models .system_model import AiModelDetail
15+ from apps .chat .models .chat_model import ChatMcp , CreateChat
16+ from apps .chat .api .chat import create_chat , stream_sql
17+
18+ router = APIRouter (tags = ["mcp" ], prefix = "/mcp" )
19+
20+
21+ @router .post ("/access_token" , operation_id = "access_token" )
22+ def local_login (
23+ session : SessionDep ,
24+ form_data : Annotated [OAuth2PasswordRequestForm , Depends ()]
25+ ) -> Token :
26+ user = authenticate (session = session , account = form_data .username , password = form_data .password )
27+ if not user :
28+ raise HTTPException (status_code = 400 , detail = "Incorrect account or password" )
29+ access_token_expires = timedelta (minutes = settings .ACCESS_TOKEN_EXPIRE_MINUTES )
30+ user_dict = user .to_dict ()
31+ return Token (access_token = create_access_token (
32+ user_dict , expires_delta = access_token_expires
33+ ))
34+
35+
36+ @router .get ("/ds_list" , operation_id = "get_datasource_list" )
37+ async def datasource_list (session : SessionDep ):
38+ return get_datasource_list (session = session )
39+
40+
41+ @router .get ("/model_list" , operation_id = "get_model_list" )
42+ async def get_model_list (session : SessionDep ):
43+ return session .query (AiModelDetail ).all ()
44+
45+
46+ @router .post ("/mcp_start" , operation_id = "mcp_start" )
47+ async def mcp_start (session : SessionDep , chat : ChatMcp ):
48+ user = await get_current_user (session , chat .token )
49+ return create_chat (session , user , CreateChat (), False )
50+
51+
52+ @router .post ("/mcp_question" , operation_id = "mcp_question" )
53+ async def mcp_question (session : SessionDep , chat : ChatMcp ):
54+ user = await get_current_user (session , chat .token )
55+ # return await stream_sql(session, user, chat)
56+ return {"content" : """这是一段写死的测试内容:
57+
58+ 步骤1: 确定需要查询的字段。
59+ 我们需要统计上海的订单总数,因此需要从"城市"字段中筛选出值为"上海"的记录,并使用COUNT函数计算这些记录的数量。
60+
61+ 步骤2: 确定筛选条件。
62+ 问题要求统计上海的订单总数,所以我们需要在SQL语句中添加WHERE "城市" = '上海'来筛选出符合条件的记录。
63+
64+ 步骤3: 避免关键字冲突。
65+ 因为这个Excel/CSV数据库是 PostgreSQL 类型,所以在schema、表名、字段名和别名外层加双引号。
66+
67+ 最终答案:
68+ ```json
69+ {"success":true,"sql":"SELECT COUNT(*) AS \" TotalOrders\" FROM \" public\" .\" Sheet1_c27345b66e\" WHERE \" 城市\" = '上海';"}
70+ ```
71+ <img src="https://sqlbot.fit2cloud.cn/images/111.png">""" }
0 commit comments