33
44from datetime import timedelta
55
6- from fastapi import APIRouter , HTTPException
6+ import jwt
7+ from fastapi import HTTPException , status , APIRouter
78from fastapi .responses import StreamingResponse
9+ # from fastapi.security import OAuth2PasswordBearer
10+ from jwt .exceptions import InvalidTokenError
11+ from pydantic import ValidationError
12+ from sqlmodel import select
813
914from apps .chat .api .chat import create_chat
1015from apps .chat .models .chat_model import ChatMcp , CreateChat , ChatStart
1116from apps .chat .task .llm import LLMService , run_task
12- from apps .datasource .crud .datasource import get_datasource_list
1317from apps .system .crud .user import authenticate
14- from apps .system .models .system_model import AiModelDetail
18+ from apps .system .crud .user import get_db_user
19+ from apps .system .models .system_model import UserWsModel
20+ from apps .system .models .user import UserModel
21+ from apps .system .schemas .system_schema import BaseUserDTO
22+ from apps .system .schemas .system_schema import UserInfoDTO
23+ from common .core import security
1524from common .core .config import settings
16- from common .core .deps import SessionDep , get_current_user
17- from common .core .schemas import Token
25+ from common .core .deps import SessionDep
26+ from common .core .schemas import TokenPayload , XOAuth2PasswordBearer , Token
1827from common .core .security import create_access_token
1928
29+ reusable_oauth2 = XOAuth2PasswordBearer (
30+ tokenUrl = f"{ settings .API_V1_STR } /login/access-token"
31+ )
32+
2033router = APIRouter (tags = ["mcp" ], prefix = "/mcp" )
2134
2235
3548# ))
3649
3750
38- @router .get ("/ds_list" , operation_id = "get_datasource_list" )
39- async def datasource_list (session : SessionDep ):
40- return get_datasource_list (session = session )
41-
42-
43- @router .get ("/model_list" , operation_id = "get_model_list" )
44- async def get_model_list (session : SessionDep ):
45- return session .query (AiModelDetail ).all ()
51+ # @router.get("/ds_list", operation_id="get_datasource_list")
52+ # async def datasource_list(session: SessionDep):
53+ # return get_datasource_list(session=session)
54+ #
55+ #
56+ # @router.get("/model_list", operation_id="get_model_list")
57+ # async def get_model_list(session: SessionDep):
58+ # return session.query(AiModelDetail).all()
4659
4760
4861@router .post ("/mcp_start" , operation_id = "mcp_start" )
4962async def mcp_start (session : SessionDep , chat : ChatStart ):
50- user = authenticate (session = session , account = chat .username , password = chat .password )
63+ user : BaseUserDTO = authenticate (session = session , account = chat .username , password = chat .password )
5164 if not user :
5265 raise HTTPException (status_code = 400 , detail = "Incorrect account or password" )
66+
67+ if not user .oid or user .oid == 0 :
68+ raise HTTPException (status_code = 400 , detail = "No associated workspace, Please contact the administrator" )
5369 access_token_expires = timedelta (minutes = settings .ACCESS_TOKEN_EXPIRE_MINUTES )
5470 user_dict = user .to_dict ()
5571 t = Token (access_token = create_access_token (
@@ -61,9 +77,36 @@ async def mcp_start(session: SessionDep, chat: ChatStart):
6177
6278@router .post ("/mcp_question" , operation_id = "mcp_question" )
6379async def mcp_question (session : SessionDep , chat : ChatMcp ):
64- user = await get_current_user (session , chat .token )
65-
66- llm_service = LLMService (session , user , chat )
80+ try :
81+ payload = jwt .decode (
82+ chat .token , settings .SECRET_KEY , algorithms = [security .ALGORITHM ]
83+ )
84+ token_data = TokenPayload (** payload )
85+ except (InvalidTokenError , ValidationError ):
86+ raise HTTPException (
87+ status_code = status .HTTP_403_FORBIDDEN ,
88+ detail = "Could not validate credentials" ,
89+ )
90+ # session_user = await get_user_info(session=session, user_id=token_data.id)
91+
92+ db_user : UserModel = get_db_user (session = session , user_id = token_data .id )
93+ session_user = UserInfoDTO .model_validate (db_user .model_dump ())
94+ session_user .isAdmin = session_user .id == 1 and session_user .account == 'admin'
95+ if session_user .isAdmin :
96+ session_user = session_user
97+ ws_model : UserWsModel = session .exec (
98+ select (UserWsModel ).where (UserWsModel .uid == session_user .id , UserWsModel .oid == session_user .oid )).first ()
99+ session_user .weight = ws_model .weight if ws_model else - 1
100+
101+ session_user = UserInfoDTO .model_validate (session_user )
102+ if not session_user :
103+ raise HTTPException (status_code = 404 , detail = "User not found" )
104+
105+ if session_user .status != 1 :
106+ raise HTTPException (status_code = 400 , detail = "Inactive user" )
107+
108+ # ask
109+ llm_service = LLMService (session , session_user , chat )
67110 llm_service .init_record ()
68111
69112 return StreamingResponse (run_task (llm_service , False ), media_type = "text/event-stream" )
0 commit comments