Skip to content

Commit c8fc476

Browse files
feat: Backend cache
1 parent b22111a commit c8fc476

File tree

9 files changed

+237
-34
lines changed

9 files changed

+237
-34
lines changed

backend/alembic/versions/021_user_ws_ddl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ def upgrade():
2929

3030

3131
def downgrade():
32-
op.drop_index(op.f('ix_sys_user_ws_id'), table_name='sys_user_ws')
32+
#op.drop_index(op.f('ix_sys_user_ws_id'), table_name='sys_user_ws')
3333
op.drop_table('sys_user_ws')

backend/apps/system/api/user.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
from fastapi import APIRouter
22
from apps.system.crud.user import get_db_user
33
from apps.system.models.user import UserModel
4+
from apps.system.schemas.auth import CacheName, CacheNamespace
45
from apps.system.schemas.system_schema import PwdEditor, UserCreator, UserEditor, UserGrid, UserLanguage
56
from common.core.deps import CurrentUser, SessionDep
67
from common.core.pagination import Paginator
78
from common.core.schemas import PaginatedResponse, PaginationParams
89
from common.core.security import md5pwd, verify_md5pwd
9-
10+
from common.core.sqlbot_cache import clear_cache
1011
router = APIRouter(tags=["user"], prefix="/user")
1112

12-
1313
@router.get("/info")
14-
async def user_info(session: SessionDep, current_user: CurrentUser):
15-
db_user = get_db_user(session=session, user_id=current_user.id)
16-
if not db_user:
17-
return {"message": "User not found"}
18-
db_user.password = None
19-
return db_user
14+
async def user_info(current_user: CurrentUser):
15+
return current_user
2016

2117

2218
@router.get("/pager/{pageNum}/{pageSize}", response_model=PaginatedResponse[UserGrid])
@@ -48,20 +44,23 @@ async def create(session: SessionDep, creator: UserCreator):
4844
session.commit()
4945

5046
@router.put("")
47+
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="editor.id")
5148
async def update(session: SessionDep, editor: UserEditor):
5249
user_model: UserModel = get_db_user(session = session, user_id = editor.id)
5350
data = editor.model_dump(exclude_unset=True)
5451
user_model.sqlmodel_update(data)
5552
session.add(user_model)
5653
session.commit()
5754

58-
@router.delete("/{id}")
55+
@router.delete("/{id}")
56+
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id")
5957
async def delete(session: SessionDep, id: int):
6058
user_model: UserModel = get_db_user(session = session, user_id = id)
6159
session.delete(user_model)
6260
session.commit()
6361

6462
@router.put("/language")
63+
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
6564
async def langChange(session: SessionDep, current_user: CurrentUser, language: UserLanguage):
6665
lang = language.language
6766
if lang not in ["zh-CN", "en"]:
@@ -73,6 +72,7 @@ async def langChange(session: SessionDep, current_user: CurrentUser, language: U
7372
return {"message": "Language changed successfully", "language": lang}
7473

7574
@router.put("/pwd")
75+
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
7676
async def pwdUpdate(session: SessionDep, current_user: CurrentUser, editor: PwdEditor):
7777
db_user: UserModel = get_db_user(session=session, user_id=current_user.id)
7878
if not verify_md5pwd(editor.pwd, db_user.password):

backend/apps/system/crud/user.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11

22
from sqlmodel import Session, select
33

4-
from apps.system.schemas.system_schema import BaseUserDTO
4+
from apps.system.models.system_model import UserWsModel
5+
from apps.system.schemas.auth import CacheName, CacheNamespace
6+
from apps.system.schemas.system_schema import BaseUserDTO, UserInfoDTO
7+
from common.core.sqlbot_cache import cache
58
from ..models.user import UserModel
69
from common.core.security import verify_md5pwd
710

811
def get_db_user(*, session: Session, user_id: int) -> UserModel:
912
db_user = session.get(UserModel, user_id)
10-
if not db_user:
11-
raise RuntimeError("user not exist")
1213
return db_user
1314

1415
def get_user_by_account(*, session: Session, account: str) -> BaseUserDTO | None:
@@ -18,9 +19,16 @@ def get_user_by_account(*, session: Session, account: str) -> BaseUserDTO | None
1819
return None
1920
return BaseUserDTO.model_validate(db_user.model_dump())
2021

21-
def get_user_info(*, session: Session, user_id: int) -> BaseUserDTO | None:
22-
db_user = get_db_user(session = session, user_id = user_id)
23-
return BaseUserDTO.model_validate(db_user.model_dump())
22+
@cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="user_id")
23+
async def get_user_info(*, session: Session, user_id: int) -> UserInfoDTO | None:
24+
db_user: UserModel = get_db_user(session = session, user_id = user_id)
25+
userInfo = UserInfoDTO.model_validate(db_user.model_dump())
26+
userInfo.isAdmin = userInfo.id == 1 and userInfo.account == 'admin'
27+
if userInfo.isAdmin:
28+
return userInfo
29+
ws_model: UserWsModel = session.exec(select(UserWsModel).where(UserWsModel.uid == userInfo.id, UserWsModel.oid == userInfo.oid)).first()
30+
userInfo.weight = ws_model.weight if ws_model else -1
31+
return userInfo
2432

2533
def authenticate(*, session: Session, account: str, password: str) -> BaseUserDTO | None:
2634
db_user = get_user_by_account(session=session, account=account)
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11

22
from pydantic import BaseModel
3-
3+
from enum import Enum
44

55
class LocalLoginSchema(BaseModel):
66
account: str
7-
password: str
7+
password: str
8+
9+
class CacheNamespace(Enum):
10+
AUTH_INFO = "sqlbot:auth"
11+
def __str__(self):
12+
return self.value
13+
class CacheName(Enum):
14+
USER_INFO = "user:info"
15+
16+
def __str__(self):
17+
return self.value

backend/apps/system/schemas/system_schema.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ def to_dict(self):
2525
"oid": self.oid
2626
}
2727

28-
29-
3028
class UserCreator(BaseUser):
3129
name: str
3230
email: str
@@ -47,4 +45,10 @@ class UserWsBase(BaseModel):
4745
uid: int
4846
oid: int
4947
class UserWsDTO(UserWsBase):
50-
weight: int = 0
48+
weight: int = 0
49+
50+
51+
class UserInfoDTO(UserEditor):
52+
language: str = "zh-CN"
53+
weight: int = 0
54+
isAdmin: bool = False

backend/common/core/deps.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from jwt.exceptions import InvalidTokenError
77
from pydantic import ValidationError
88
from sqlmodel import Session
9-
from apps.system.crud.user import get_db_user
10-
from apps.system.schemas.system_schema import BaseUserDTO
9+
from apps.system.crud.user import get_user_info
10+
from apps.system.schemas.system_schema import UserInfoDTO
1111
from common.core.schemas import TokenPayload, XOAuth2PasswordBearer
1212
from common.core import security
1313
from common.core.config import settings
1414
from common.core.db import get_session
15-
from apps.system.models.user import UserModel
1615
from common.utils.locale import I18n
1716
reusable_oauth2 = XOAuth2PasswordBearer(
1817
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
@@ -28,7 +27,7 @@ async def get_i18n(request: Request):
2827
return i18n(request)
2928

3029
Trans = Annotated[I18n, Depends(get_i18n)]
31-
async def get_current_user(session: SessionDep, token: TokenDep) -> BaseUserDTO:
30+
async def get_current_user(session: SessionDep, token: TokenDep) -> UserInfoDTO:
3231
try:
3332
payload = jwt.decode(
3433
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
@@ -39,16 +38,15 @@ async def get_current_user(session: SessionDep, token: TokenDep) -> BaseUserDTO:
3938
status_code=status.HTTP_403_FORBIDDEN,
4039
detail="Could not validate credentials",
4140
)
42-
session_user: UserModel = get_db_user(session = session, user_id = token_data.id)
41+
session_user = await get_user_info(session = session, user_id = token_data.id)
42+
session_user = UserInfoDTO.model_validate(session_user)
4343
if not session_user:
4444
raise HTTPException(status_code=404, detail="User not found")
45-
user = BaseUserDTO.model_validate(session_user.model_dump())
46-
if not user:
47-
raise HTTPException(status_code=404, detail="User not found")
48-
""" if not user.is_active:
49-
raise HTTPException(status_code=400, detail="Inactive user") """
50-
return user
51-
CurrentUser = Annotated[BaseUserDTO, Depends(get_current_user)]
45+
46+
if session_user.status != 1:
47+
raise HTTPException(status_code=400, detail="Inactive user")
48+
return session_user
49+
CurrentUser = Annotated[UserInfoDTO, Depends(get_current_user)]
5250

5351

5452

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
from fastapi_cache import FastAPICache
2+
from fastapi_cache.decorator import cache as original_cache
3+
from functools import partial, wraps
4+
from typing import Optional, Set, Any, Dict, Tuple
5+
from inspect import Parameter, signature
6+
import logging
7+
8+
logger = logging.getLogger(__name__)
9+
10+
def should_skip_param(param: Parameter) -> bool:
11+
"""判断参数是否应该被忽略(依赖注入参数)"""
12+
return (
13+
param.kind == Parameter.VAR_KEYWORD or # **kwargs
14+
param.kind == Parameter.VAR_POSITIONAL or # *args
15+
hasattr(param.annotation, "__module__") and
16+
param.annotation.__module__.startswith(('fastapi', 'starlette', "sqlmodel.orm.session"))
17+
)
18+
19+
def custom_key_builder(
20+
func: Any,
21+
namespace: str = "",
22+
*,
23+
args: Tuple[Any, ...] = (),
24+
kwargs: Dict[str, Any],
25+
additional_skip_args: Optional[Set[str]] = None,
26+
cacheName: Optional[str] = None,
27+
keyExpression: Optional[str] = None,
28+
) -> str:
29+
"""
30+
完全兼容FastAPICache的键生成器
31+
"""
32+
if cacheName:
33+
base_key = f"{namespace}:{cacheName}:"
34+
35+
if keyExpression:
36+
try:
37+
sig = signature(func)
38+
bound_args = sig.bind_partial(*args, **kwargs)
39+
bound_args.apply_defaults()
40+
41+
42+
if keyExpression.startswith("args["):
43+
import re
44+
match = re.match(r"args\[(\d+)\]", keyExpression)
45+
if match:
46+
index = int(match.group(1))
47+
value = bound_args.args[index]
48+
base_key += f"{value}:"
49+
else:
50+
51+
parts = keyExpression.split('.')
52+
value = bound_args.arguments[parts[0]]
53+
for part in parts[1:]:
54+
value = getattr(value, part)
55+
base_key += f"{value}:"
56+
57+
except (IndexError, KeyError, AttributeError) as e:
58+
logger.warning(f"Failed to evaluate keyExpression '{keyExpression}': {str(e)}")
59+
60+
return base_key
61+
# 获取函数签名
62+
sig = signature(func)
63+
64+
# 自动识别要跳过的参数
65+
auto_skip_args = {
66+
name for name, param in sig.parameters.items()
67+
if should_skip_param(param)
68+
}
69+
70+
# 合并用户指定的额外跳过参数
71+
skip_args = auto_skip_args.union(additional_skip_args or set())
72+
73+
# 过滤kwargs
74+
filtered_kwargs = {
75+
k: v for k, v in kwargs.items() if k not in skip_args
76+
}
77+
78+
# 过滤args - 将位置参数映射到它们的参数名
79+
bound_args = sig.bind_partial(*args, **kwargs)
80+
bound_args.apply_defaults()
81+
82+
filtered_args = []
83+
for i, (name, value) in enumerate(bound_args.arguments.items()):
84+
# 只处理位置参数 (在args中的参数)
85+
if i < len(args) and name not in skip_args:
86+
filtered_args.append(value)
87+
filtered_args = tuple(filtered_args)
88+
89+
# 获取默认键生成器
90+
default_key_builder = FastAPICache.get_key_builder()
91+
# 调用默认键生成器(严格按照其要求的参数格式)
92+
return default_key_builder(
93+
func=func,
94+
namespace=namespace,
95+
args=filtered_args,
96+
kwargs=filtered_kwargs,
97+
)
98+
99+
def cache(
100+
expire: Optional[int] = 60 * 60 * 24,
101+
namespace: Optional[str] = None,
102+
key_builder: Optional[Any] = None,
103+
*,
104+
additional_skip_args: Optional[Set[str]] = None,
105+
cacheName: Optional[str] = None,
106+
keyExpression: Optional[str] = None,
107+
):
108+
"""
109+
完全兼容的缓存装饰器
110+
"""
111+
def decorator(func):
112+
if key_builder is None:
113+
used_key_builder = partial(
114+
custom_key_builder,
115+
additional_skip_args=additional_skip_args,
116+
cacheName=cacheName,
117+
keyExpression=keyExpression
118+
)
119+
else:
120+
used_key_builder = key_builder
121+
122+
@wraps(func)
123+
async def wrapper(*args, **kwargs):
124+
# 准备键生成器参数
125+
key_builder_args = {
126+
"func": func,
127+
"namespace": namespace,
128+
"args": args,
129+
"kwargs": kwargs
130+
}
131+
132+
# 生成缓存键
133+
cache_key = used_key_builder(**key_builder_args)
134+
logger.debug(f"Generated cache key: {cache_key}")
135+
136+
# 使用原始缓存装饰器
137+
return await original_cache(
138+
expire=expire,
139+
namespace=namespace,
140+
key_builder=lambda *_, **__: cache_key # 直接使用预生成的key
141+
)(func)(*args, **kwargs)
142+
return wrapper
143+
return decorator
144+
145+
def clear_cache(
146+
namespace: Optional[str] = None,
147+
cacheName: Optional[str] = None,
148+
keyExpression: Optional[str] = None,
149+
):
150+
"""
151+
清除缓存的装饰器,参数与 @cache 保持一致
152+
使用方式:
153+
@clear_cache(namespace="user", cacheName="info", keyExpression="user_id")
154+
async def update_user(user_id: int):
155+
...
156+
"""
157+
def decorator(func):
158+
@wraps(func)
159+
async def wrapper(*args, **kwargs):
160+
# 1. 生成缓存键(复用 custom_key_builder 逻辑)
161+
cache_key = custom_key_builder(
162+
func=func,
163+
namespace=namespace or "",
164+
args=args,
165+
kwargs=kwargs,
166+
cacheName=cacheName,
167+
keyExpression=keyExpression,
168+
)
169+
170+
logger.debug(f"Clearing cache for key: {cache_key}")
171+
172+
# 2. 清除缓存
173+
await FastAPICache.clear(key=cache_key)
174+
175+
# 3. 执行原函数
176+
return await func(*args, **kwargs)
177+
178+
return wrapper
179+
return decorator

backend/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from fastapi_mcp import FastApiMCP
1515
from fastapi.staticfiles import StaticFiles
1616
import sqlbot_xpack
17-
17+
from fastapi_cache import FastAPICache
18+
from fastapi_cache.backends.inmemory import InMemoryBackend
1819

1920
def run_migrations():
2021
alembic_cfg = Config("alembic.ini")
@@ -24,6 +25,8 @@ def run_migrations():
2425
@asynccontextmanager
2526
async def lifespan(app: FastAPI):
2627
run_migrations()
28+
FastAPICache.init(InMemoryBackend())
29+
print("✅ FastAPICache 初始化完成")
2730
yield
2831

2932

0 commit comments

Comments
 (0)