77import jwt
88from sqlmodel import Session
99from starlette .middleware .base import BaseHTTPMiddleware
10- from apps .system .models .system_model import AssistantModel
10+ from apps .system .crud .apikey_manage import get_api_key
11+ from apps .system .models .system_model import ApiKeyModel , AssistantModel
1112from common .core .db import engine
1213from apps .system .crud .assistant import get_assistant_info , get_assistant_user
1314from apps .system .crud .user import get_user_by_account , get_user_info
@@ -33,7 +34,15 @@ async def dispatch(self, request, call_next):
3334 return await call_next (request )
3435 assistantTokenKey = settings .ASSISTANT_TOKEN_KEY
3536 assistantToken = request .headers .get (assistantTokenKey )
37+ askToken = request .headers .get ("X-SQLBOT-ASK-TOKEN" )
3638 trans = await get_i18n (request )
39+ if askToken :
40+ validate_pass , data = await self .validateAskToken (askToken , trans )
41+ if validate_pass :
42+ request .state .current_user = data
43+ return await call_next (request )
44+ message = trans ('i18n_permission.authenticate_invalid' , msg = data )
45+ return JSONResponse (message , status_code = 401 , headers = {"Access-Control-Allow-Origin" : "*" })
3746 #if assistantToken and assistantToken.lower().startswith("assistant "):
3847 if assistantToken :
3948 validator : tuple [any ] = await self .validateAssistant (assistantToken , trans )
@@ -62,6 +71,50 @@ async def dispatch(self, request, call_next):
6271 def is_options (self , request : Request ):
6372 return request .method == "OPTIONS"
6473
74+ async def validateAskToken (self , askToken : Optional [str ], trans : I18n ):
75+ if not askToken :
76+ return False , f"Miss Token[X-SQLBOT-ASK-TOKEN]!"
77+ schema , param = get_authorization_scheme_param (askToken )
78+ if schema .lower () != "sk" :
79+ return False , f"Token schema error!"
80+ try :
81+ payload = jwt .decode (
82+ param , options = {"verify_signature" : False , "verify_exp" : False }, algorithms = [security .ALGORITHM ]
83+ )
84+ access_key = payload .get ('access_key' , None )
85+
86+ if not access_key :
87+ return False , f"Miss access_key payload error!"
88+ with Session (engine ) as session :
89+ api_key_model = await get_api_key (session , access_key )
90+ api_key_model = ApiKeyModel .model_validate (api_key_model ) if api_key_model else None
91+ if not api_key_model :
92+ return False , f"Invalid access_key!"
93+ if not api_key_model .status :
94+ return False , f"Disabled access_key!"
95+ payload = jwt .decode (
96+ param , api_key_model .secret_key , algorithms = [security .ALGORITHM ]
97+ )
98+ uid = api_key_model .uid
99+ session_user = await get_user_info (session = session , user_id = uid )
100+ if not session_user :
101+ message = trans ('i18n_not_exist' , msg = trans ('i18n_user.account' ))
102+ raise Exception (message )
103+ session_user = UserInfoDTO .model_validate (session_user )
104+ if session_user .status != 1 :
105+ message = trans ('i18n_login.user_disable' , msg = trans ('i18n_concat_admin' ))
106+ raise Exception (message )
107+ if not session_user .oid or session_user .oid == 0 :
108+ message = trans ('i18n_login.no_associated_ws' , msg = trans ('i18n_concat_admin' ))
109+ raise Exception (message )
110+ return True , session_user
111+ except Exception as e :
112+ msg = str (e )
113+ SQLBotLogUtil .exception (f"Token validation error: { msg } " )
114+ if 'expired' in msg :
115+ return False , jwt .ExpiredSignatureError (trans ('i18n_permission.token_expired' ))
116+ return False , e
117+
65118 async def validateToken (self , token : Optional [str ], trans : I18n ):
66119 if not token :
67120 return False , f"Miss Token[{ settings .TOKEN_KEY } ]!"
0 commit comments