Skip to content

Commit 3ffeb41

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents 34bb7fb + ea6e654 commit 3ffeb41

File tree

10 files changed

+176
-62
lines changed

10 files changed

+176
-62
lines changed

backend/apps/chat/task/llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,8 @@ def run_analysis_or_predict_task(llm_service: LLMService, action_type: str, base
846846
try:
847847
llm_service.set_record(save_analysis_predict_record(llm_service.session, base_record, action_type))
848848

849+
yield orjson.dumps({'type': 'id', 'id': llm_service.get_record().id}).decode() + '\n\n'
850+
849851
if action_type == 'analysis':
850852
# generate analysis
851853
analysis_res = llm_service.generate_analysis()

backend/apps/system/api/user.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
from fastapi import APIRouter, HTTPException, Query
33
from sqlmodel import func, or_, select, delete as sqlmodel_delete
4-
from apps.system.crud.user import get_db_user, single_delete, user_ws_options
4+
from apps.system.crud.user import clean_user_cache, get_db_user, single_delete, user_ws_options
55
from apps.system.models.system_model import UserWsModel
66
from apps.system.models.user import UserModel
77
from apps.system.schemas.auth import CacheName, CacheNamespace
@@ -95,13 +95,14 @@ async def ws_options(session: SessionDep, current_user: CurrentUser, trans: Tran
9595
return await user_ws_options(session, current_user.id, trans)
9696

9797
@router.put("/ws/{oid}")
98-
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
98+
# @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
9999
async def ws_change(session: SessionDep, current_user: CurrentUser, oid: int):
100100
ws_list: list[UserWs] = await user_ws_options(session, current_user.id)
101101
if not any(x.id == oid for x in ws_list):
102102
raise HTTPException(f"oid [{oid}] is invalid!")
103103
user_model: UserModel = get_db_user(session = session, user_id = current_user.id)
104104
user_model.oid = oid
105+
await clean_user_cache(user_model.id)
105106
session.add(user_model)
106107
session.commit()
107108

@@ -173,13 +174,14 @@ async def batch_del(session: SessionDep, id_list: list[int]):
173174
await single_delete(session, id)
174175

175176
@router.put("/language")
176-
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
177+
#@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
177178
async def langChange(session: SessionDep, current_user: CurrentUser, language: UserLanguage):
178179
lang = language.language
179180
if lang not in ["zh-CN", "en"]:
180181
return {"message": "Language not supported"}
181182
db_user: UserModel = get_db_user(session=session, user_id=current_user.id)
182183
db_user.language = lang
184+
await clean_user_cache(db_user.id)
183185
session.add(db_user)
184186
session.commit()
185187

@@ -194,11 +196,12 @@ async def pwdReset(session: SessionDep, current_user: CurrentUser, id: int):
194196
session.commit()
195197

196198
@router.put("/pwd")
197-
@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
199+
#@clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="current_user.id")
198200
async def pwdUpdate(session: SessionDep, current_user: CurrentUser, editor: PwdEditor):
199201
db_user: UserModel = get_db_user(session=session, user_id=current_user.id)
200202
if not verify_md5pwd(editor.pwd, db_user.password):
201203
raise HTTPException("pwd error")
202204
db_user.password = md5pwd(editor.new_pwd)
205+
await clean_user_cache(db_user.id)
203206
session.add(db_user)
204207
session.commit()
Lines changed: 39 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
11
from fastapi_cache import FastAPICache
22
from fastapi_cache.decorator import cache as original_cache
33
from functools import partial, wraps
4-
from typing import Optional, Set, Any, Dict, Tuple
4+
from typing import Optional, Set, Any, Dict, Tuple, Callable
55
from inspect import Parameter, signature
66
import logging
7+
from contextlib import asynccontextmanager
8+
import asyncio
79

810
logger = logging.getLogger(__name__)
911

12+
# 锁管理
13+
_cache_locks = {}
14+
15+
@asynccontextmanager
16+
async def _get_cache_lock(key: str):
17+
lock = _cache_locks.setdefault(key, asyncio.Lock())
18+
async with lock:
19+
try:
20+
yield
21+
finally:
22+
if key in _cache_locks and not lock.locked():
23+
del _cache_locks[key]
24+
1025
def should_skip_param(param: Parameter) -> bool:
1126
"""判断参数是否应该被忽略(依赖注入参数)"""
1227
return (
@@ -26,9 +41,7 @@ def custom_key_builder(
2641
cacheName: Optional[str] = None,
2742
keyExpression: Optional[str] = None,
2843
) -> str:
29-
"""
30-
完全兼容FastAPICache的键生成器
31-
"""
44+
"""完全兼容FastAPICache的键生成器"""
3245
if cacheName:
3346
base_key = f"{namespace}:{cacheName}:"
3447

@@ -38,7 +51,6 @@ def custom_key_builder(
3851
bound_args = sig.bind_partial(*args, **kwargs)
3952
bound_args.apply_defaults()
4053

41-
4254
if keyExpression.startswith("args["):
4355
import re
4456
match = re.match(r"args\[(\d+)\]", keyExpression)
@@ -47,7 +59,6 @@ def custom_key_builder(
4759
value = bound_args.args[index]
4860
base_key += f"{value}:"
4961
else:
50-
5162
parts = keyExpression.split('.')
5263
value = bound_args.arguments[parts[0]]
5364
for part in parts[1:]:
@@ -58,37 +69,27 @@ def custom_key_builder(
5869
logger.warning(f"Failed to evaluate keyExpression '{keyExpression}': {str(e)}")
5970

6071
return base_key
61-
# 获取函数签名
62-
sig = signature(func)
6372

64-
# 自动识别要跳过的参数
73+
sig = signature(func)
6574
auto_skip_args = {
6675
name for name, param in sig.parameters.items()
6776
if should_skip_param(param)
6877
}
69-
70-
# 合并用户指定的额外跳过参数
7178
skip_args = auto_skip_args.union(additional_skip_args or set())
72-
73-
# 过滤kwargs
7479
filtered_kwargs = {
7580
k: v for k, v in kwargs.items() if k not in skip_args
7681
}
7782

78-
# 过滤args - 将位置参数映射到它们的参数名
7983
bound_args = sig.bind_partial(*args, **kwargs)
8084
bound_args.apply_defaults()
8185

8286
filtered_args = []
8387
for i, (name, value) in enumerate(bound_args.arguments.items()):
84-
# 只处理位置参数 (在args中的参数)
8588
if i < len(args) and name not in skip_args:
8689
filtered_args.append(value)
8790
filtered_args = tuple(filtered_args)
8891

89-
# 获取默认键生成器
9092
default_key_builder = FastAPICache.get_key_builder()
91-
# 调用默认键生成器(严格按照其要求的参数格式)
9293
return default_key_builder(
9394
func=func,
9495
namespace=namespace,
@@ -105,9 +106,7 @@ def cache(
105106
cacheName: Optional[str] = None,
106107
keyExpression: Optional[str] = None,
107108
):
108-
"""
109-
完全兼容的缓存装饰器
110-
"""
109+
"""完全兼容的缓存装饰器"""
111110
def decorator(func):
112111
if key_builder is None:
113112
used_key_builder = partial(
@@ -121,24 +120,23 @@ def decorator(func):
121120

122121
@wraps(func)
123122
async def wrapper(*args, **kwargs):
124-
# 准备键生成器参数
125-
key_builder_args = {
126-
"func": func,
127-
"namespace": namespace,
128-
"args": args,
129-
"kwargs": kwargs
130-
}
131-
132-
# 生成缓存键
133-
cache_key = used_key_builder(**key_builder_args)
134-
logger.debug(f"Generated cache key: {cache_key}")
123+
cache_key = used_key_builder(
124+
func=func,
125+
namespace=namespace or "",
126+
args=args,
127+
kwargs=kwargs
128+
)
135129

136-
# 使用原始缓存装饰器
137-
return await original_cache(
138-
expire=expire,
139-
namespace=namespace,
140-
key_builder=lambda *_, **__: cache_key # 直接使用预生成的key
141-
)(func)(*args, **kwargs)
130+
async with _get_cache_lock(cache_key):
131+
backend = FastAPICache.get_backend()
132+
cached_value = await backend.get(cache_key)
133+
if cached_value is not None:
134+
return cached_value
135+
136+
result = await func(*args, **kwargs)
137+
await backend.set(cache_key, result, expire)
138+
return result
139+
142140
return wrapper
143141
return decorator
144142

@@ -147,17 +145,10 @@ def clear_cache(
147145
cacheName: Optional[str] = None,
148146
keyExpression: Optional[str] = None,
149147
):
150-
"""
151-
清除缓存的装饰器,参数与 @cache 保持一致
152-
使用方式:
153-
@clear_cache(namespace="user", cacheName="info", keyExpression="user_id")
154-
async def update_user(user_id: int):
155-
...
156-
"""
148+
"""清除缓存的装饰器"""
157149
def decorator(func):
158150
@wraps(func)
159151
async def wrapper(*args, **kwargs):
160-
# 1. 生成缓存键(复用 custom_key_builder 逻辑)
161152
cache_key = custom_key_builder(
162153
func=func,
163154
namespace=namespace or "",
@@ -167,13 +158,10 @@ async def wrapper(*args, **kwargs):
167158
keyExpression=keyExpression,
168159
)
169160

170-
logger.debug(f"Clearing cache for key: {cache_key}")
171-
# 2. 清除缓存
172-
if await FastAPICache.get_backend().get(cache_key):
161+
async with _get_cache_lock(cache_key):
173162
await FastAPICache.clear(key=cache_key)
174-
175-
# 3. 执行原函数
176-
return await func(*args, **kwargs)
163+
result = await func(*args, **kwargs)
164+
return result
177165

178166
return wrapper
179167
return decorator

backend/template.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ template:
129129
predict:
130130
system: |
131131
### 说明:
132-
你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以json格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json格式返回。
132+
你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以json格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json格式返回,返回的格式需要与传入的数据格式保持一致
133133
```json
134134
135135
无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式):"抱歉,该数据无法进行预测。(有原因则返回无法预测的原因)"

frontend/public/assistant.js

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
const getChatContainerHtml = (data) => {
4141
return `
4242
<div id="sqlbot-assistant-chat-container">
43-
<iframe id="sqlbot-assistant-chat" allow="microphone" src="${data.domain_url}/#/assistant?id=${data.id}"></iframe>
43+
<iframe id="sqlbot-assistant-chat-iframe-${data.id}" allow="microphone" src="${data.domain_url}/#/assistant?id=${data.id}"></iframe>
4444
<div class="sqlbot-assistant-operate">
4545
<div class="sqlbot-assistant-closeviewport sqlbot-assistant-viewportnone">
4646
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 20 20" fill="none">
@@ -311,7 +311,7 @@
311311
#sqlbot-assistant #sqlbot-assistant-chat-container .sqlbot-assistant-viewportnone{
312312
display:none;
313313
}
314-
#sqlbot-assistant #sqlbot-assistant-chat-container #sqlbot-assistant-chat{
314+
#sqlbot-assistant #sqlbot-assistant-chat-container #sqlbot-assistant-chat-iframe-${data.id}{
315315
height:100%;
316316
width:100%;
317317
border: none;
@@ -334,7 +334,79 @@
334334
const url = new URL(src)
335335
return url.searchParams.get(key)
336336
}
337+
function parsrCertificate(config) {
338+
const certificateList = config.certificate
339+
if (!certificateList?.length) {
340+
return null
341+
}
342+
const list = certificateList.map((item) => formatCertificate(item)).filter((item) => !!item)
343+
return JSON.stringify(list)
344+
}
345+
function isEmpty(obj) {
346+
return obj == null || typeof obj == 'undefined'
347+
}
348+
function formatCertificate(item) {
349+
const { type, source, target, target_key, target_val } = item
350+
let source_val = null
351+
if (type.toLocaleLowerCase() == 'localstorage') {
352+
source_val = localStorage.getItem(source)
353+
}
354+
if (type.toLocaleLowerCase() == 'sessionstorage') {
355+
source_val = sessionStorage.getItem(source)
356+
}
357+
if (type.toLocaleLowerCase() == 'cookie') {
358+
source_val = getCookie(source)
359+
}
360+
if (type.toLocaleLowerCase() == 'custom') {
361+
source_val = source
362+
}
363+
if (isEmpty(source_val)) {
364+
return null
365+
}
366+
return {
367+
target,
368+
key: target_key || source,
369+
value: (target_val && eval(target_val)) || source_val,
370+
}
371+
}
372+
function getCookie(key) {
373+
if (!key || !document.cookie) {
374+
return null
375+
}
376+
const cookies = document.cookie.split(';')
377+
for (let i = 0; i < cookies.length; i++) {
378+
const cookie = cookies[i].trim()
337379

380+
if (cookie.startsWith(key + '=')) {
381+
return decodeURIComponent(cookie.substring(key.length + 1))
382+
}
383+
}
384+
return null
385+
}
386+
function registerMessageEvent(id, data) {
387+
const iframe = document.getElementById(`sqlbot-assistant-chat-iframe-${id}`)
388+
const url = iframe.src
389+
const eventName = 'sqlbot_assistant_event'
390+
window.addEventListener('message', (event) => {
391+
if (event.data?.eventName === eventName) {
392+
if (event.data?.messageId !== id) {
393+
return
394+
}
395+
if (event.data?.busi == 'ready' && event.data?.ready) {
396+
const certificate = parsrCertificate(data)
397+
console.log(certificate)
398+
params = {
399+
busi: 'certificate',
400+
certificate,
401+
eventName,
402+
messageId: id,
403+
}
404+
const contentWindow = iframe.contentWindow
405+
contentWindow.postMessage(params, url)
406+
}
407+
}
408+
})
409+
}
338410
function loadScript(src, id) {
339411
const domain_url = getDomain(src)
340412
let url = `${domain_url}/api/v1/system/assistant/info/${id}`
@@ -353,6 +425,10 @@
353425
tempData = Object.assign(tempData, config)
354426
}
355427
initsqlbot_assistant(tempData)
428+
if (data.type == 1) {
429+
registerMessageEvent(id, tempData)
430+
// postMessage the certificate to iframe
431+
}
356432
})
357433
}
358434
function getDomain(src) {

0 commit comments

Comments
 (0)