Skip to content

Commit f0bf758

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents dc41cab + 46523fa commit f0bf758

File tree

7 files changed

+159
-122
lines changed

7 files changed

+159
-122
lines changed

backend/apps/datasource/api/datasource.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ async def edit_field(session: SessionDep, field: CoreField):
9999

100100

101101
@router.post("/previewData/{id}")
102-
async def edit_local(session: SessionDep, id: int, data: TableObj):
103-
return preview(session, id, data)
102+
async def edit_local(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
103+
return preview(session, current_user, id, data)
104104

105105

106106
@router.post("/fieldEnum/{id}")

backend/apps/datasource/crud/datasource.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,27 @@ def updateField(session: SessionDep, field: CoreField):
237237
update_field(session, field)
238238

239239

240-
def preview(session: SessionDep, id: int, data: TableObj):
240+
def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj):
241241
if data.fields is None or len(data.fields) == 0:
242242
return {"fields": [], "data": [], "sql": ''}
243243

244-
fields = [f.field_name for f in data.fields if f.checked]
244+
# column is checked, and, column permission for data.fields
245+
f_list = [f for f in data.fields if f.checked]
246+
column_permissions = session.query(DsPermission).filter(
247+
and_(DsPermission.table_id == data.table.id, DsPermission.type == 'column')).all()
248+
if column_permissions is not None:
249+
for permission in column_permissions:
250+
# check permission and user in same rules
251+
obj = session.query(DsRules).filter(
252+
and_(DsRules.permission_list.op('@>')(cast([permission.id], JSONB)),
253+
or_(DsRules.user_list.op('@>')(cast([f'{current_user.id}'], JSONB)),
254+
DsRules.user_list.op('@>')(cast([current_user.id], JSONB))))
255+
).first()
256+
if obj is not None:
257+
permission_list = json.loads(permission.permissions)
258+
f_list = filter_list(f_list, permission_list)
259+
260+
fields = [f.field_name for f in f_list]
245261
if fields is None or len(fields) == 0:
246262
return {"fields": [], "data": [], "sql": ''}
247263

backend/apps/system/api/user.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
from fastapi import APIRouter, HTTPException, Query
33
from sqlmodel import func, or_, select, delete as sqlmodel_delete
4-
from apps.system.crud.user import clean_user_cache, get_db_user, single_delete, user_ws_options
4+
from apps.system.crud.user import get_db_user, single_delete, user_ws_options
55
from apps.system.models.system_model import UserWsModel
66
from apps.system.models.user import UserModel
77
from apps.system.schemas.auth import CacheName, CacheNamespace
@@ -95,14 +95,13 @@ async def ws_options(session: SessionDep, current_user: CurrentUser, trans: Tran
9595
return await user_ws_options(session, current_user.id, trans)
9696

9797
@router.put("/ws/{oid}")
98-
# @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
98+
#@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
9999
async def ws_change(session: SessionDep, current_user: CurrentUser, oid: int):
100100
ws_list: list[UserWs] = await user_ws_options(session, current_user.id)
101101
if not any(x.id == oid for x in ws_list):
102102
raise HTTPException(f"oid [{oid}] is invalid!")
103103
user_model: UserModel = get_db_user(session = session, user_id = current_user.id)
104104
user_model.oid = oid
105-
await clean_user_cache(user_model.id)
106105
session.add(user_model)
107106
session.commit()
108107

@@ -181,7 +180,6 @@ async def langChange(session: SessionDep, current_user: CurrentUser, language: U
181180
return {"message": "Language not supported"}
182181
db_user: UserModel = get_db_user(session=session, user_id=current_user.id)
183182
db_user.language = lang
184-
await clean_user_cache(db_user.id)
185183
session.add(db_user)
186184
session.commit()
187185

@@ -202,6 +200,5 @@ async def pwdUpdate(session: SessionDep, current_user: CurrentUser, editor: PwdE
202200
if not verify_md5pwd(editor.pwd, db_user.password):
203201
raise HTTPException("pwd error")
204202
db_user.password = md5pwd(editor.new_pwd)
205-
await clean_user_cache(db_user.id)
206203
session.add(db_user)
207204
session.commit()

backend/apps/system/middleware/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def validateToken(self, token: Optional[str]):
7171
) """
7272
return True, session_user
7373
except Exception as e:
74-
SQLBotLogUtil.exception(f"Token validation error: {str(e)}", exc_info=True)
74+
SQLBotLogUtil.exception(f"Token validation error: {str(e)}")
7575
return False, e
7676

7777

@@ -97,6 +97,6 @@ async def validateAssistant(self, assistantToken: Optional[str]) -> tuple[any]:
9797
assistant_info = AssistantHeader.model_validate(assistant_info.model_dump(exclude_unset=True))
9898
return True, session_user, assistant_info
9999
except Exception as e:
100-
SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}", exc_info=True)
100+
SQLBotLogUtil.exception(f"Assistant validation error: {str(e)}")
101101
# Return False and the exception message
102102
return False, e

backend/common/core/response_middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def dispatch(self, request, call_next):
6868
class exception_handler():
6969
@staticmethod
7070
async def http_exception_handler(request: Request, exc: HTTPException):
71-
SQLBotLogUtil.exception(f"HTTP Exception: {exc.detail}", exc_info=True)
71+
SQLBotLogUtil.error(f"HTTP Exception: {exc.detail}", exc_info=True)
7272
return JSONResponse(
7373
status_code=exc.status_code,
7474
content={
@@ -82,7 +82,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
8282

8383
@staticmethod
8484
async def global_exception_handler(request: Request, exc: Exception):
85-
SQLBotLogUtil.exception(f"Unhandled Exception: {str(exc)}", exc_info=True)
85+
SQLBotLogUtil.error(f"Unhandled Exception: {str(exc)}", exc_info=True)
8686
return JSONResponse(
8787
status_code=500,
8888
content={
Lines changed: 102 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,173 +1,167 @@
1+
import contextvars
12
from fastapi_cache import FastAPICache
2-
from fastapi_cache.decorator import cache as original_cache
33
from functools import partial, wraps
4-
from typing import Optional, Set, Any, Dict, Tuple
5-
from inspect import Parameter, signature
4+
from typing import Optional, Any, Dict, Tuple
5+
from inspect import signature
66
from contextlib import asynccontextmanager
77
import asyncio
8+
import random
9+
from collections import defaultdict
810

911
from common.utils.utils import SQLBotLogUtil
1012

11-
12-
# 锁管理
13-
_cache_locks = {}
13+
# 使用contextvar来跟踪当前线程已持有的锁
14+
_held_locks = contextvars.ContextVar('held_locks', default=set())
15+
# 高效锁管理器
16+
class LockManager:
17+
_locks = defaultdict(asyncio.Lock)
18+
19+
@classmethod
20+
def get_lock(cls, key: str) -> asyncio.Lock:
21+
return cls._locks[key]
1422

1523
@asynccontextmanager
1624
async def _get_cache_lock(key: str):
17-
lock = _cache_locks.setdefault(key, asyncio.Lock())
18-
async with lock:
19-
try:
20-
yield
21-
finally:
22-
if key in _cache_locks and not lock.locked():
23-
del _cache_locks[key]
24-
25-
def should_skip_param(param: Parameter) -> bool:
26-
"""判断参数是否应该被忽略(依赖注入参数)"""
27-
return (
28-
param.kind == Parameter.VAR_KEYWORD or # **kwargs
29-
param.kind == Parameter.VAR_POSITIONAL or # *args
30-
hasattr(param.annotation, "__module__") and
31-
param.annotation.__module__.startswith(('fastapi', 'starlette', "sqlmodel.orm.session"))
32-
)
25+
# 获取当前已持有的锁集合
26+
current_locks = _held_locks.get()
27+
28+
# 如果已经持有这个锁,直接yield(锁传递)
29+
if key in current_locks:
30+
yield
31+
return
32+
33+
# 否则获取锁并添加到当前上下文中
34+
lock = LockManager.get_lock(key)
35+
try:
36+
await lock.acquire()
37+
# 更新当前持有的锁集合
38+
new_locks = current_locks | {key}
39+
token = _held_locks.set(new_locks)
40+
41+
yield
42+
43+
finally:
44+
# 恢复之前的锁集合
45+
_held_locks.reset(token)
46+
if lock.locked():
47+
lock.release()
3348

3449
def custom_key_builder(
3550
func: Any,
3651
namespace: str = "",
3752
*,
3853
args: Tuple[Any, ...] = (),
3954
kwargs: Dict[str, Any],
40-
additional_skip_args: Optional[Set[str]] = None,
41-
cacheName: Optional[str] = None,
55+
cacheName: str,
4256
keyExpression: Optional[str] = None,
4357
) -> str:
44-
"""完全兼容FastAPICache的键生成器"""
45-
if cacheName:
58+
try:
4659
base_key = f"{namespace}:{cacheName}:"
4760

4861
if keyExpression:
49-
try:
50-
sig = signature(func)
51-
bound_args = sig.bind_partial(*args, **kwargs)
52-
bound_args.apply_defaults()
53-
54-
if keyExpression.startswith("args["):
55-
import re
56-
match = re.match(r"args\[(\d+)\]", keyExpression)
57-
if match:
58-
index = int(match.group(1))
59-
value = bound_args.args[index]
60-
base_key += f"{value}:"
61-
else:
62-
parts = keyExpression.split('.')
63-
value = bound_args.arguments[parts[0]]
64-
for part in parts[1:]:
65-
value = getattr(value, part)
66-
base_key += f"{value}:"
67-
68-
except (IndexError, KeyError, AttributeError) as e:
69-
SQLBotLogUtil.warning(f"Failed to evaluate keyExpression '{keyExpression}': {str(e)}")
62+
sig = signature(func)
63+
bound_args = sig.bind_partial(*args, **kwargs)
64+
bound_args.apply_defaults()
65+
66+
# 支持args[0]格式
67+
if keyExpression.startswith("args["):
68+
import re
69+
if match := re.match(r"args\[(\d+)\]", keyExpression):
70+
index = int(match.group(1))
71+
value = bound_args.args[index]
72+
return f"{base_key}{value}"
73+
74+
# 支持属性路径格式
75+
parts = keyExpression.split('.')
76+
value = bound_args.arguments[parts[0]]
77+
for part in parts[1:]:
78+
value = getattr(value, part)
79+
return f"{base_key}{value}"
7080

71-
return base_key
72-
73-
sig = signature(func)
74-
auto_skip_args = {
75-
name for name, param in sig.parameters.items()
76-
if should_skip_param(param)
77-
}
78-
skip_args = auto_skip_args.union(additional_skip_args or set())
79-
filtered_kwargs = {
80-
k: v for k, v in kwargs.items() if k not in skip_args
81-
}
82-
83-
bound_args = sig.bind_partial(*args, **kwargs)
84-
bound_args.apply_defaults()
85-
86-
filtered_args = []
87-
for i, (name, value) in enumerate(bound_args.arguments.items()):
88-
if i < len(args) and name not in skip_args:
89-
filtered_args.append(value)
90-
filtered_args = tuple(filtered_args)
91-
92-
default_key_builder = FastAPICache.get_key_builder()
93-
return default_key_builder(
94-
func=func,
95-
namespace=namespace,
96-
args=filtered_args,
97-
kwargs=filtered_kwargs,
98-
)
81+
# 默认使用第一个参数作为key
82+
return f"{base_key}{args[0] if args else 'default'}"
83+
84+
except Exception as e:
85+
SQLBotLogUtil.error(f"Key builder error: {str(e)}")
86+
raise ValueError(f"Invalid cache key generation: {e}") from e
9987

10088
def cache(
101-
expire: Optional[int] = 60 * 60 * 24,
102-
namespace: Optional[str] = None,
103-
key_builder: Optional[Any] = None,
89+
expire: int = 60 * 60 * 24,
90+
namespace: str = "",
10491
*,
105-
additional_skip_args: Optional[Set[str]] = None,
106-
cacheName: Optional[str] = None,
92+
cacheName: str, # 必须提供cacheName
10793
keyExpression: Optional[str] = None,
94+
jitter: int = 60, # 默认抖动60秒
10895
):
109-
"""完全兼容的缓存装饰器"""
11096
def decorator(func):
111-
if key_builder is None:
112-
used_key_builder = partial(
113-
custom_key_builder,
114-
additional_skip_args=additional_skip_args,
115-
cacheName=cacheName,
116-
keyExpression=keyExpression
117-
)
118-
else:
119-
used_key_builder = key_builder
120-
97+
# 预先生成key builder
98+
used_key_builder = partial(
99+
custom_key_builder,
100+
cacheName=cacheName,
101+
keyExpression=keyExpression
102+
)
103+
121104
@wraps(func)
122105
async def wrapper(*args, **kwargs):
106+
# 生成缓存键
123107
cache_key = used_key_builder(
124108
func=func,
125-
namespace=namespace or "",
109+
namespace=namespace,
126110
args=args,
127111
kwargs=kwargs
128112
)
129113

114+
# 防击穿锁
130115
async with _get_cache_lock(cache_key):
131-
SQLBotLogUtil.info(f"Using cache key: {cache_key}")
132116
backend = FastAPICache.get_backend()
133-
cached_value = await backend.get(cache_key)
134-
if cached_value is not None:
135-
SQLBotLogUtil.info(f"Cache hit for key: {cache_key}, the value is: {cached_value}")
136-
return cached_value
137117

118+
# 双重检查
119+
if (cached := await backend.get(cache_key)) is not None:
120+
SQLBotLogUtil.debug(f"Cache hit: {cache_key}")
121+
return cached
122+
123+
# 执行函数并缓存结果
138124
result = await func(*args, **kwargs)
139-
await backend.set(cache_key, result, expire)
140-
SQLBotLogUtil.info(f"Cache miss for key: {cache_key}, result cached.")
125+
126+
actual_expire = expire + random.randint(-jitter, jitter)
127+
await backend.set(cache_key, result, actual_expire)
128+
129+
SQLBotLogUtil.debug(f"Cache set: {cache_key} (expire: {actual_expire}s)")
141130
return result
142131

143132
return wrapper
144133
return decorator
145134

146135
def clear_cache(
147-
namespace: Optional[str] = None,
148-
cacheName: Optional[str] = None,
136+
namespace: str = "",
137+
*,
138+
cacheName: str,
149139
keyExpression: Optional[str] = None,
150140
):
151-
"""清除缓存的装饰器"""
141+
"""精确清除单个缓存项的装饰器"""
152142
def decorator(func):
153143
@wraps(func)
154144
async def wrapper(*args, **kwargs):
155145
cache_key = custom_key_builder(
156146
func=func,
157-
namespace=namespace or "",
147+
namespace=namespace,
158148
args=args,
159149
kwargs=kwargs,
160150
cacheName=cacheName,
161151
keyExpression=keyExpression,
162152
)
163153

154+
# 加锁防止竞争
164155
async with _get_cache_lock(cache_key):
165-
if await FastAPICache.get_backend().get(cache_key):
166-
await FastAPICache.clear(key=cache_key)
156+
backend = FastAPICache.get_backend()
157+
result = None
158+
if await backend.get(cache_key):
159+
await backend.clear(cache_key)
167160
result = await func(*args, **kwargs)
168-
SQLBotLogUtil.info(f"Clearing cache for key: {cache_key}")
169-
return result
170-
return await func(*args, **kwargs)
161+
if await backend.get(cache_key):
162+
await backend.clear(cache_key)
163+
SQLBotLogUtil.info(f"Cache cleared: {cache_key}")
164+
return result
171165

172166
return wrapper
173167
return decorator

0 commit comments

Comments
 (0)