1+ from fastapi_cache import FastAPICache
2+ from fastapi_cache .decorator import cache as original_cache
3+ from functools import partial , wraps
4+ from typing import Optional , Set , Any , Dict , Tuple
5+ from inspect import Parameter , signature
6+ import logging
7+
8+ logger = logging .getLogger (__name__ )
9+
10+ def should_skip_param (param : Parameter ) -> bool :
11+ """判断参数是否应该被忽略(依赖注入参数)"""
12+ return (
13+ param .kind == Parameter .VAR_KEYWORD or # **kwargs
14+ param .kind == Parameter .VAR_POSITIONAL or # *args
15+ hasattr (param .annotation , "__module__" ) and
16+ param .annotation .__module__ .startswith (('fastapi' , 'starlette' , "sqlmodel.orm.session" ))
17+ )
18+
19+ def custom_key_builder (
20+ func : Any ,
21+ namespace : str = "" ,
22+ * ,
23+ args : Tuple [Any , ...] = (),
24+ kwargs : Dict [str , Any ],
25+ additional_skip_args : Optional [Set [str ]] = None ,
26+ cacheName : Optional [str ] = None ,
27+ keyExpression : Optional [str ] = None ,
28+ ) -> str :
29+ """
30+ 完全兼容FastAPICache的键生成器
31+ """
32+ if cacheName :
33+ base_key = f"{ namespace } :{ cacheName } :"
34+
35+ if keyExpression :
36+ try :
37+ sig = signature (func )
38+ bound_args = sig .bind_partial (* args , ** kwargs )
39+ bound_args .apply_defaults ()
40+
41+
42+ if keyExpression .startswith ("args[" ):
43+ import re
44+ match = re .match (r"args\[(\d+)\]" , keyExpression )
45+ if match :
46+ index = int (match .group (1 ))
47+ value = bound_args .args [index ]
48+ base_key += f"{ value } :"
49+ else :
50+
51+ parts = keyExpression .split ('.' )
52+ value = bound_args .arguments [parts [0 ]]
53+ for part in parts [1 :]:
54+ value = getattr (value , part )
55+ base_key += f"{ value } :"
56+
57+ except (IndexError , KeyError , AttributeError ) as e :
58+ logger .warning (f"Failed to evaluate keyExpression '{ keyExpression } ': { str (e )} " )
59+
60+ return base_key
61+ # 获取函数签名
62+ sig = signature (func )
63+
64+ # 自动识别要跳过的参数
65+ auto_skip_args = {
66+ name for name , param in sig .parameters .items ()
67+ if should_skip_param (param )
68+ }
69+
70+ # 合并用户指定的额外跳过参数
71+ skip_args = auto_skip_args .union (additional_skip_args or set ())
72+
73+ # 过滤kwargs
74+ filtered_kwargs = {
75+ k : v for k , v in kwargs .items () if k not in skip_args
76+ }
77+
78+ # 过滤args - 将位置参数映射到它们的参数名
79+ bound_args = sig .bind_partial (* args , ** kwargs )
80+ bound_args .apply_defaults ()
81+
82+ filtered_args = []
83+ for i , (name , value ) in enumerate (bound_args .arguments .items ()):
84+ # 只处理位置参数 (在args中的参数)
85+ if i < len (args ) and name not in skip_args :
86+ filtered_args .append (value )
87+ filtered_args = tuple (filtered_args )
88+
89+ # 获取默认键生成器
90+ default_key_builder = FastAPICache .get_key_builder ()
91+ # 调用默认键生成器(严格按照其要求的参数格式)
92+ return default_key_builder (
93+ func = func ,
94+ namespace = namespace ,
95+ args = filtered_args ,
96+ kwargs = filtered_kwargs ,
97+ )
98+
99+ def cache (
100+ expire : Optional [int ] = 60 * 60 * 24 ,
101+ namespace : Optional [str ] = None ,
102+ key_builder : Optional [Any ] = None ,
103+ * ,
104+ additional_skip_args : Optional [Set [str ]] = None ,
105+ cacheName : Optional [str ] = None ,
106+ keyExpression : Optional [str ] = None ,
107+ ):
108+ """
109+ 完全兼容的缓存装饰器
110+ """
111+ def decorator (func ):
112+ if key_builder is None :
113+ used_key_builder = partial (
114+ custom_key_builder ,
115+ additional_skip_args = additional_skip_args ,
116+ cacheName = cacheName ,
117+ keyExpression = keyExpression
118+ )
119+ else :
120+ used_key_builder = key_builder
121+
122+ @wraps (func )
123+ async def wrapper (* args , ** kwargs ):
124+ # 准备键生成器参数
125+ key_builder_args = {
126+ "func" : func ,
127+ "namespace" : namespace ,
128+ "args" : args ,
129+ "kwargs" : kwargs
130+ }
131+
132+ # 生成缓存键
133+ cache_key = used_key_builder (** key_builder_args )
134+ logger .debug (f"Generated cache key: { cache_key } " )
135+
136+ # 使用原始缓存装饰器
137+ return await original_cache (
138+ expire = expire ,
139+ namespace = namespace ,
140+ key_builder = lambda * _ , ** __ : cache_key # 直接使用预生成的key
141+ )(func )(* args , ** kwargs )
142+ return wrapper
143+ return decorator
144+
145+ def clear_cache (
146+ namespace : Optional [str ] = None ,
147+ cacheName : Optional [str ] = None ,
148+ keyExpression : Optional [str ] = None ,
149+ ):
150+ """
151+ 清除缓存的装饰器,参数与 @cache 保持一致
152+ 使用方式:
153+ @clear_cache(namespace="user", cacheName="info", keyExpression="user_id")
154+ async def update_user(user_id: int):
155+ ...
156+ """
157+ def decorator (func ):
158+ @wraps (func )
159+ async def wrapper (* args , ** kwargs ):
160+ # 1. 生成缓存键(复用 custom_key_builder 逻辑)
161+ cache_key = custom_key_builder (
162+ func = func ,
163+ namespace = namespace or "" ,
164+ args = args ,
165+ kwargs = kwargs ,
166+ cacheName = cacheName ,
167+ keyExpression = keyExpression ,
168+ )
169+
170+ logger .debug (f"Clearing cache for key: { cache_key } " )
171+
172+ # 2. 清除缓存
173+ await FastAPICache .clear (key = cache_key )
174+
175+ # 3. 执行原函数
176+ return await func (* args , ** kwargs )
177+
178+ return wrapper
179+ return decorator
0 commit comments