Skip to content

Commit e69caad

Browse files
perf: Assistant token
1 parent 4bd69b8 commit e69caad

File tree

10 files changed

+149
-58
lines changed

10 files changed

+149
-58
lines changed

backend/apps/system/api/assistant.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
from datetime import timedelta
22
from fastapi import APIRouter, FastAPI, Request
33
from sqlmodel import Session, select
4-
from apps.system.crud.user import get_user_by_account
4+
from apps.system.crud.assistant import get_assistant_info
55
from apps.system.models.system_model import AssistantModel
6+
from apps.system.schemas.auth import CacheName, CacheNamespace
67
from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantValidator
78
from common.core.deps import SessionDep
89
from common.core.security import create_access_token
10+
from common.core.sqlbot_cache import clear_cache
911
from common.utils.time import get_timestamp
1012
from starlette.middleware.cors import CORSMiddleware
1113
from common.core.config import settings
1214
router = APIRouter(tags=["system/assistant"], prefix="/system/assistant")
1315

1416
@router.get("/validator/{id}", response_model=AssistantValidator)
1517
async def info(session: SessionDep, id: int):
16-
db_model = session.get(AssistantModel, id)
18+
db_model = get_assistant_info(session, id)
1719
if not db_model:
1820
return AssistantValidator()
1921
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
20-
user = get_user_by_account(session=session, account='admin')
22+
assistantDict = {
23+
"id": 1, "account": 'admin', "oid": 1, "assistant_id": id
24+
}
2125
access_token = create_access_token(
22-
user.to_dict(), expires_delta=access_token_expires
26+
assistantDict, expires_delta=access_token_expires
2327
)
2428
return AssistantValidator(True, True, True, access_token)
2529

@@ -39,6 +43,7 @@ async def add(request: Request, session: SessionDep, creator: AssistantBase):
3943

4044

4145
@router.put("")
46+
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="editor.id")
4247
async def update(request: Request, session: SessionDep, editor: AssistantDTO):
4348
id = editor.id
4449
db_model = session.get(AssistantModel, id)
@@ -52,12 +57,13 @@ async def update(request: Request, session: SessionDep, editor: AssistantDTO):
5257

5358
@router.get("/{id}", response_model=AssistantModel)
5459
async def get_one(session: SessionDep, id: int):
55-
db_model = session.get(AssistantModel, id)
60+
db_model = get_assistant_info(session, id)
5661
if not db_model:
5762
raise ValueError(f"AssistantModel with id {id} not found")
5863
return db_model
5964

60-
@router.delete("/{id}")
65+
@router.delete("/{id}")
66+
@clear_cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="id")
6167
async def delete(request: Request, session: SessionDep, id: int):
6268
db_model = session.get(AssistantModel, id)
6369
if not db_model:
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
2+
3+
from sqlmodel import Session
4+
from apps.system.models.system_model import AssistantModel
5+
from apps.system.schemas.auth import CacheName, CacheNamespace
6+
from common.core.sqlbot_cache import cache
7+
8+
9+
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
10+
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
11+
db_model = session.get(AssistantModel, assistant_id)
12+
return db_model
Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,88 @@
11

2-
from fastapi import Depends
2+
from typing import Optional
3+
from fastapi import Request
34
from fastapi.responses import JSONResponse
5+
import jwt
6+
from sqlmodel import Session
47
from starlette.middleware.base import BaseHTTPMiddleware
8+
from common.core.db import engine
9+
from apps.system.crud.assistant import get_assistant_info
10+
from apps.system.crud.user import get_user_info
11+
from apps.system.schemas.system_schema import UserInfoDTO
12+
from common.core import security
513
from common.core.config import settings
6-
# from common.core.deps import get_current_user
14+
from common.core.schemas import TokenPayload
715
from common.utils.whitelist import whiteUtils
8-
16+
from fastapi.security.utils import get_authorization_scheme_param
917
class TokenMiddleware(BaseHTTPMiddleware):
1018

19+
20+
1121
def __init__(self, app):
1222
super().__init__(app)
1323

1424
async def dispatch(self, request, call_next):
15-
tokenkey = settings.TOKEN_KEY
25+
1626
if self.is_options(request) or whiteUtils.is_whitelisted(request.url.path):
1727
return await call_next(request)
28+
assistantTokenKey = settings.ASSISTANT_TOKEN_KEY
29+
assistantToken = request.headers.get(assistantTokenKey)
30+
#if assistantToken and assistantToken.lower().startswith("assistant "):
31+
if assistantToken:
32+
validate_pass, data, assistant = await self.validateAssistant(assistantToken)
33+
if validate_pass:
34+
request.state.current_user = data
35+
request.state.assistant = assistant
36+
return await call_next(request)
37+
return JSONResponse({"error": f"Unauthorized:[{data}]"}, status_code=401)
38+
#validate pass
39+
tokenkey = settings.TOKEN_KEY
1840
token = request.headers.get(tokenkey)
19-
if not token or not token.startswith("Bearer "):
20-
return JSONResponse({"error": "Unauthorized"}, status_code=401)
21-
""" user = await get_current_user()
22-
request.state.user = user """
23-
return await call_next(request)
41+
validate_pass, data = await self.validateToken(token)
42+
if validate_pass:
43+
request.state.current_user = data
44+
return await call_next(request)
45+
return JSONResponse({"error": f"Unauthorized:[{data}]"}, status_code=401)
2446

25-
def is_options(self, request):
47+
def is_options(self, request: Request):
2648
return request.method == "OPTIONS"
2749

28-
50+
async def validateToken(self, token: Optional[str]):
51+
if not token:
52+
return False, f"Miss Token[{settings.TOKEN_KEY}]!"
53+
schema, param = get_authorization_scheme_param(token)
54+
if schema.lower() != "bearer":
55+
return False, f"Token schema error!"
56+
payload = jwt.decode(
57+
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
58+
)
59+
token_data = TokenPayload(**payload)
60+
try:
61+
with Session(engine) as session:
62+
session_user = await get_user_info(session = session, user_id = token_data.id)
63+
session_user = UserInfoDTO.model_validate(session_user)
64+
return True, session_user
65+
except Exception as e:
66+
return False, e
67+
68+
69+
async def validateAssistant(self, assistantToken: Optional[str]):
70+
if not assistantToken:
71+
return False, f"Miss Token[{settings.TOKEN_KEY}]!"
72+
schema, param = get_authorization_scheme_param(assistantToken)
73+
if schema.lower() != "assistant":
74+
return False, f"Token schema error!"
75+
payload = jwt.decode(
76+
param, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
77+
)
78+
token_data = TokenPayload(**payload)
79+
if not payload['assistant_id']:
80+
return False, f"Miss assistant payload error!"
81+
try:
82+
with Session(engine) as session:
83+
session_user = await get_user_info(session = session, user_id = token_data.id)
84+
session_user = UserInfoDTO.model_validate(session_user)
85+
assistant_info = get_assistant_info(session, payload['assistant_id'])
86+
return True, session_user, assistant_info
87+
except Exception as e:
88+
return False, e

backend/apps/system/schemas/auth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ class LocalLoginSchema(BaseModel):
88

99
class CacheNamespace(Enum):
1010
AUTH_INFO = "sqlbot:auth"
11+
EMBEDDED_INFO = "sqlbot:embedded"
1112
def __str__(self):
1213
return self.value
1314
class CacheName(Enum):
1415
USER_INFO = "user:info"
15-
16+
ASSISTANT_INFO = "assistant:info"
1617
def __str__(self):
1718
return self.value

backend/common/core/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def all_cors_origins(self) -> list[str]:
5757
POSTGRES_PASSWORD: str = ""
5858
POSTGRES_DB: str = ""
5959

60-
TOKEN_KEY: str
61-
DEFAULT_PWD: str
60+
TOKEN_KEY: str = "X-SQLBOT-TOKEN"
61+
DEFAULT_PWD: str = "SQLBot@123456"
62+
ASSISTANT_TOKEN_KEY: str = "X-SQLBOT-ASSISTANT-TOKEN"
6263

6364
@computed_field # type: ignore[prop-decorator]
6465
@property

backend/common/core/deps.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,21 @@
11
from typing import Annotated
22

3-
import jwt
4-
from fastapi import Depends, HTTPException, Request, status
5-
# from fastapi.security import OAuth2PasswordBearer
6-
from jwt.exceptions import InvalidTokenError
7-
from pydantic import ValidationError
3+
from fastapi import Depends, Request
84
from sqlmodel import Session
9-
from apps.system.crud.user import get_user_info
105
from apps.system.schemas.system_schema import UserInfoDTO
11-
from common.core.schemas import TokenPayload, XOAuth2PasswordBearer
12-
from common.core import security
13-
from common.core.config import settings
146
from common.core.db import get_session
157
from common.utils.locale import I18n
16-
reusable_oauth2 = XOAuth2PasswordBearer(
17-
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
18-
)
198

209

21-
22-
2310
SessionDep = Annotated[Session, Depends(get_session)]
24-
TokenDep = Annotated[str, Depends(reusable_oauth2)]
2511
i18n = I18n()
2612
async def get_i18n(request: Request):
2713
return i18n(request)
2814

2915
Trans = Annotated[I18n, Depends(get_i18n)]
30-
async def get_current_user(session: SessionDep, token: TokenDep) -> UserInfoDTO:
31-
try:
32-
payload = jwt.decode(
33-
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
34-
)
35-
token_data = TokenPayload(**payload)
36-
except (InvalidTokenError, ValidationError):
37-
raise HTTPException(
38-
status_code=status.HTTP_403_FORBIDDEN,
39-
detail="Could not validate credentials",
40-
)
41-
session_user = await get_user_info(session = session, user_id = token_data.id)
42-
session_user = UserInfoDTO.model_validate(session_user)
43-
if not session_user:
44-
raise HTTPException(status_code=404, detail="User not found")
45-
46-
if session_user.status != 1:
47-
raise HTTPException(status_code=400, detail="Inactive user")
48-
return session_user
16+
async def get_current_user(request: Request) -> UserInfoDTO:
17+
return request.state.current_user
18+
4919
CurrentUser = Annotated[UserInfoDTO, Depends(get_current_user)]
5020

5121

backend/common/core/schemas.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ class Token(SQLModel):
1818
class XOAuth2PasswordBearer(OAuth2PasswordBearer):
1919
async def __call__(self, request: Request) -> Optional[str]:
2020
authorization = request.headers.get(settings.TOKEN_KEY)
21+
if request.headers.get(settings.ASSISTANT_TOKEN_KEY):
22+
authorization = request.headers.get(settings.ASSISTANT_TOKEN_KEY)
2123
scheme, param = get_authorization_scheme_param(authorization)
22-
if not authorization or scheme.lower() != "bearer":
24+
25+
if not authorization or scheme.lower() not in ["bearer", "assistant"]:
2326
if self.auto_error:
2427
raise HTTPException(
2528
status_code=HTTP_401_UNAUTHORIZED,

frontend/src/stores/assistant.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { defineStore } from 'pinia'
2+
import { store } from './index'
3+
4+
interface AssistantState {
5+
token: string
6+
}
7+
8+
export const AssistantStore = defineStore('assistant', {
9+
state: (): AssistantState => {
10+
return {
11+
token: '',
12+
}
13+
},
14+
getters: {
15+
getToken(): string {
16+
return this.token
17+
},
18+
},
19+
actions: {
20+
setToken(token: string) {
21+
this.token = token
22+
},
23+
clear() {
24+
this.$reset()
25+
},
26+
},
27+
})
28+
29+
export const useAssistantStore = () => {
30+
return AssistantStore(store)
31+
}

frontend/src/utils/request.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import axios, {
1010

1111
import { useCache } from '@/utils/useCache'
1212
import { getLocale } from './utils'
13+
import { useAssistantStore } from '@/stores/assistant'
14+
const assistantStore = useAssistantStore()
1315
const { wsCache } = useCache()
1416
// Response data structure
1517
export interface ApiResponse<T = unknown> {
@@ -70,6 +72,10 @@ class HttpService {
7072
if (token && config.headers) {
7173
config.headers['X-SQLBOT-TOKEN'] = `Bearer ${token}`
7274
}
75+
if (assistantStore.getToken) {
76+
config.headers['X-SQLBOT-ASSISTANT-TOKEN'] = `Assistant ${assistantStore.getToken}`
77+
if (config.headers['X-SQLBOT-TOKEN']) config.headers.delete('X-SQLBOT-TOKEN')
78+
}
7379
const locale = getLocale()
7480
if (locale) {
7581
/* const mapping = {

frontend/src/views/embedded/assistant.vue

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ import IconOpeEdit from '@/assets/svg/operate/ope-edit.svg'
6262
import IconOpeDelete from '@/assets/svg/operate/ope-delete.svg'
6363
import { useRoute } from 'vue-router'
6464
import { assistantApi } from '@/api/assistant'
65-
import { useUserStore } from '@/stores/user'
66-
const userStore = useUserStore()
65+
import { useAssistantStore } from '@/stores/assistant'
66+
67+
const assistantStore = useAssistantStore()
6768
const route = useRoute()
6869
6970
const chatRef = ref()
@@ -96,7 +97,7 @@ const loading = ref(true)
9697
onBeforeMount(async () => {
9798
const assistantId = route.params.id
9899
validator.value = await assistantApi.validate(assistantId)
99-
userStore.setToken(validator.value.token)
100+
assistantStore.setToken(validator.value.token)
100101
loading.value = false
101102
})
102103
</script>

0 commit comments

Comments
 (0)