11
2- from fastapi import Depends
2+ from typing import Optional
3+ from fastapi import Request
34from fastapi .responses import JSONResponse
5+ import jwt
6+ from sqlmodel import Session
47from starlette .middleware .base import BaseHTTPMiddleware
8+ from common .core .db import engine
9+ from apps .system .crud .assistant import get_assistant_info
10+ from apps .system .crud .user import get_user_info
11+ from apps .system .schemas .system_schema import UserInfoDTO
12+ from common .core import security
513from common .core .config import settings
6- # from common.core.deps import get_current_user
14+ from common .core .schemas import TokenPayload
715from common .utils .whitelist import whiteUtils
8-
16+ from fastapi . security . utils import get_authorization_scheme_param
917class TokenMiddleware (BaseHTTPMiddleware ):
1018
19+
20+
1121 def __init__ (self , app ):
1222 super ().__init__ (app )
1323
1424 async def dispatch (self , request , call_next ):
15- tokenkey = settings . TOKEN_KEY
25+
1626 if self .is_options (request ) or whiteUtils .is_whitelisted (request .url .path ):
1727 return await call_next (request )
28+ assistantTokenKey = settings .ASSISTANT_TOKEN_KEY
29+ assistantToken = request .headers .get (assistantTokenKey )
30+ #if assistantToken and assistantToken.lower().startswith("assistant "):
31+ if assistantToken :
32+ validate_pass , data , assistant = await self .validateAssistant (assistantToken )
33+ if validate_pass :
34+ request .state .current_user = data
35+ request .state .assistant = assistant
36+ return await call_next (request )
37+ return JSONResponse ({"error" : f"Unauthorized:[{ data } ]" }, status_code = 401 )
38+ #validate pass
39+ tokenkey = settings .TOKEN_KEY
1840 token = request .headers .get (tokenkey )
19- if not token or not token . startswith ( "Bearer " ):
20- return JSONResponse ({ "error" : "Unauthorized" }, status_code = 401 )
21- """ user = await get_current_user()
22- request.state.user = user """
23- return await call_next ( request )
41+ validate_pass , data = await self . validateToken ( token )
42+ if validate_pass :
43+ request . state . current_user = data
44+ return await call_next ( request )
45+ return JSONResponse ({ "error" : f"Unauthorized:[ { data } ]" }, status_code = 401 )
2446
25- def is_options (self , request ):
47+ def is_options (self , request : Request ):
2648 return request .method == "OPTIONS"
2749
28-
50+ async def validateToken (self , token : Optional [str ]):
51+ if not token :
52+ return False , f"Miss Token[{ settings .TOKEN_KEY } ]!"
53+ schema , param = get_authorization_scheme_param (token )
54+ if schema .lower () != "bearer" :
55+ return False , f"Token schema error!"
56+ payload = jwt .decode (
57+ param , settings .SECRET_KEY , algorithms = [security .ALGORITHM ]
58+ )
59+ token_data = TokenPayload (** payload )
60+ try :
61+ with Session (engine ) as session :
62+ session_user = await get_user_info (session = session , user_id = token_data .id )
63+ session_user = UserInfoDTO .model_validate (session_user )
64+ return True , session_user
65+ except Exception as e :
66+ return False , e
67+
68+
69+ async def validateAssistant (self , assistantToken : Optional [str ]):
70+ if not assistantToken :
71+ return False , f"Miss Token[{ settings .TOKEN_KEY } ]!"
72+ schema , param = get_authorization_scheme_param (assistantToken )
73+ if schema .lower () != "assistant" :
74+ return False , f"Token schema error!"
75+ payload = jwt .decode (
76+ param , settings .SECRET_KEY , algorithms = [security .ALGORITHM ]
77+ )
78+ token_data = TokenPayload (** payload )
79+ if not payload ['assistant_id' ]:
80+ return False , f"Miss assistant payload error!"
81+ try :
82+ with Session (engine ) as session :
83+ session_user = await get_user_info (session = session , user_id = token_data .id )
84+ session_user = UserInfoDTO .model_validate (session_user )
85+ assistant_info = get_assistant_info (session , payload ['assistant_id' ])
86+ return True , session_user , assistant_info
87+ except Exception as e :
88+ return False , e
0 commit comments