Skip to content

Commit 19d43d4

Browse files
committed
Merge branch 'main' of https://github.com/dataease/SQLBot
2 parents ff2cfce + ddcbd54 commit 19d43d4

File tree

8 files changed

+181
-40
lines changed

8 files changed

+181
-40
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""036_modify_assistant
2+
3+
Revision ID: 646e7ca28e0e
4+
Revises: 29559ee607af
5+
Create Date: 2025-08-18 16:12:46.041413
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '646e7ca28e0e'
15+
down_revision = '29559ee607af'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
op.add_column('sys_assistant', sa.Column('app_id', sa.String(255), nullable=True, comment='app_id'))
22+
op.add_column('sys_assistant', sa.Column('app_secret', sa.String(255), nullable=True, comment='app_secret'))
23+
24+
25+
def downgrade():
26+
op.drop_column('sys_assistant', 'app_id')
27+
op.drop_column('sys_assistant', 'app_secret')

backend/apps/system/api/assistant.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,23 @@
22
import json
33
import os
44
from typing import List, Optional
5-
from fastapi import APIRouter, FastAPI, Form, HTTPException, Query, Request, Response, UploadFile
5+
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response, UploadFile
66
from fastapi.responses import StreamingResponse
7-
from sqlmodel import Session, select
7+
from sqlmodel import select
88
from apps.system.crud.assistant import get_assistant_info
9+
from apps.system.crud.assistant_manage import dynamic_upgrade_cors, save
910
from apps.system.models.system_model import AssistantModel
1011
from apps.system.schemas.auth import CacheName, CacheNamespace
1112
from apps.system.schemas.system_schema import AssistantBase, AssistantDTO, AssistantUiSchema, AssistantValidator
1213
from common.core.deps import SessionDep, Trans
1314
from common.core.security import create_access_token
1415
from common.core.sqlbot_cache import clear_cache
1516
from common.utils.time import get_timestamp
16-
from starlette.middleware.cors import CORSMiddleware
17+
1718
from common.core.config import settings
1819
from common.utils.utils import get_origin_from_referer
1920
from sqlbot_xpack.file_utils import SQLBotFileUtils
21+
2022
router = APIRouter(tags=["system/assistant"], prefix="/system/assistant")
2123

2224
@router.get("/info/{id}")
@@ -104,16 +106,12 @@ async def ui(session: SessionDep, data: str = Form(), files: List[UploadFile] =
104106

105107
@router.get("", response_model=list[AssistantModel])
106108
async def query(session: SessionDep):
107-
list_result = session.exec(select(AssistantModel).order_by(AssistantModel.name, AssistantModel.create_time)).all()
109+
list_result = session.exec(select(AssistantModel).where(AssistantModel.type.in_([0, 1])).order_by(AssistantModel.name, AssistantModel.create_time)).all()
108110
return list_result
109111

110112
@router.post("")
111113
async def add(request: Request, session: SessionDep, creator: AssistantBase):
112-
db_model = AssistantModel.model_validate(creator)
113-
db_model.create_time = get_timestamp()
114-
session.add(db_model)
115-
session.commit()
116-
dynamic_upgrade_cors(request=request, session=session)
114+
save(request, session, creator)
117115

118116

119117
@router.put("")
@@ -147,26 +145,7 @@ async def delete(request: Request, session: SessionDep, id: int):
147145
session.commit()
148146
dynamic_upgrade_cors(request=request, session=session)
149147

150-
def dynamic_upgrade_cors(request: Request, session: Session):
151-
list_result = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all()
152-
seen = set()
153-
unique_domains = []
154-
for item in list_result:
155-
if item.domain:
156-
for domain in item.domain.split(','):
157-
domain = domain.strip()
158-
if domain and domain not in seen:
159-
seen.add(domain)
160-
unique_domains.append(domain)
161-
app: FastAPI = request.app
162-
cors_middleware = None
163-
for middleware in app.user_middleware:
164-
if middleware.cls == CORSMiddleware:
165-
cors_middleware = middleware
166-
break
167-
if cors_middleware:
168-
updated_origins = list(set(settings.all_cors_origins + unique_domains))
169-
cors_middleware.kwargs['allow_origins'] = updated_origins
148+
170149

171150

172151

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
3+
from fastapi import FastAPI, Request
4+
from sqlmodel import Session, select
5+
from starlette.middleware.cors import CORSMiddleware
6+
from apps.system.schemas.system_schema import AssistantBase
7+
from common.core.config import settings
8+
from apps.system.models.system_model import AssistantModel
9+
from common.utils.time import get_timestamp
10+
11+
12+
def dynamic_upgrade_cors(request: Request, session: Session):
13+
list_result = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all()
14+
seen = set()
15+
unique_domains = []
16+
for item in list_result:
17+
if item.domain:
18+
for domain in item.domain.split(','):
19+
domain = domain.strip()
20+
if domain and domain not in seen:
21+
seen.add(domain)
22+
unique_domains.append(domain)
23+
app: FastAPI = request.app
24+
cors_middleware = None
25+
for middleware in app.user_middleware:
26+
if middleware.cls == CORSMiddleware:
27+
cors_middleware = middleware
28+
break
29+
if cors_middleware:
30+
updated_origins = list(set(settings.all_cors_origins + unique_domains))
31+
cors_middleware.kwargs['allow_origins'] = updated_origins
32+
33+
34+
async def save(request: Request, session: Session, creator: AssistantBase):
35+
db_model = AssistantModel.model_validate(creator)
36+
db_model.create_time = get_timestamp()
37+
session.add(db_model)
38+
session.commit()
39+
dynamic_upgrade_cors(request=request, session=session)

backend/apps/system/models/system_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class AssistantBaseModel(SQLModel):
5050
description: Optional[str] = Field(sa_type = Text(), nullable=True)
5151
configuration: Optional[str] = Field(sa_type = Text(), nullable=True)
5252
create_time: int = Field(default=0, sa_type=BigInteger())
53+
app_id: Optional[str] = Field(default=None, max_length=255, nullable=True)
54+
app_secret: Optional[str] = Field(default=None, max_length=255, nullable=True)
5355

5456
class AssistantModel(SnowflakeBase, AssistantBaseModel, table=True):
5557
__tablename__ = "sys_assistant"

backend/common/core/sqlbot_cache.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def custom_key_builder(
1616
kwargs: Dict[str, Any],
1717
cacheName: str,
1818
keyExpression: Optional[str] = None,
19-
) -> str:
19+
) -> str | list[str]:
2020
try:
2121
base_key = f"{namespace}:{cacheName}:"
2222

@@ -31,13 +31,17 @@ def custom_key_builder(
3131
if match := re.match(r"args\[(\d+)\]", keyExpression):
3232
index = int(match.group(1))
3333
value = bound_args.args[index]
34+
if isinstance(value, list):
35+
return [f"{base_key}{v}" for v in value]
3436
return f"{base_key}{value}"
3537

3638
# 支持属性路径格式
3739
parts = keyExpression.split('.')
3840
value = bound_args.arguments[parts[0]]
3941
for part in parts[1:]:
4042
value = getattr(value, part)
43+
if isinstance(value, list):
44+
return [f"{base_key}{v}" for v in value]
4145
return f"{base_key}{value}"
4246

4347
# 默认使用第一个参数作为key
@@ -102,14 +106,16 @@ async def wrapper(*args, **kwargs):
102106
cacheName=cacheName,
103107
keyExpression=keyExpression,
104108
)
109+
ket_list = cache_key if isinstance(cache_key, list) else [cache_key]
105110
backend = FastAPICache.get_backend()
106-
if await backend.get(cache_key):
107-
if settings.CACHE_TYPE.lower() == "redis":
108-
redis = backend.redis
109-
await redis.delete(cache_key)
110-
else:
111-
await backend.clear(key=cache_key)
112-
SQLBotLogUtil.debug(f"Cache cleared: {cache_key}")
111+
for temp_cache_key in ket_list:
112+
if await backend.get(temp_cache_key):
113+
if settings.CACHE_TYPE.lower() == "redis":
114+
redis = backend.redis
115+
await redis.delete(temp_cache_key)
116+
else:
117+
await backend.clear(key=temp_cache_key)
118+
SQLBotLogUtil.debug(f"Cache cleared: {temp_cache_key}")
113119
return await func(*args, **kwargs)
114120

115121
return wrapper

backend/common/utils/whitelist.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
"/system/assistant/validator*",
3232
"/system/assistant/info/*",
3333
"/system/assistant/picture/*",
34+
"/system/embedded*",
3435
"/datasource/uploadExcel"
3536
]
3637

frontend/src/router/watch.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ const appearanceStore = useAppearanceStoreWithOut()
88
const userStore = useUserStore()
99
const { wsCache } = useCache()
1010
const whiteList = ['/login']
11-
const assistantWhiteList = ['/assistant']
11+
const assistantWhiteList = ['/assistant', '/embeddedPage']
1212
export const watchRouter = (router: any) => {
1313
router.beforeEach(async (to: any, from: any, next: any) => {
1414
await loadXpackStatic()

frontend/src/views/embedded/page.vue

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,100 @@
11
<template>
22
<div class="sqlbot--embedded-page">
3-
<chat-component ref="chatRef" />
3+
<chat-component v-if="!loading" ref="chatRef" />
44
</div>
55
</template>
66
<script setup lang="ts">
77
import ChatComponent from '@/views/chat/index.vue'
8-
import { ref } from 'vue'
8+
import { onBeforeMount, onBeforeUnmount, ref } from 'vue'
9+
import { useRoute } from 'vue-router'
10+
import { assistantApi } from '@/api/assistant'
11+
import { useAssistantStore } from '@/stores/assistant'
912
1013
const chatRef = ref()
14+
const assistantStore = useAssistantStore()
15+
const route = useRoute()
16+
const assistantName = ref('')
17+
18+
const validator = ref({
19+
id: '',
20+
valid: false,
21+
id_match: false,
22+
token: '',
23+
})
24+
const loading = ref(true)
25+
const eventName = 'sqlbot_assistant_event'
26+
const communicationCb = async (event: any) => {
27+
if (event.data?.eventName === eventName) {
28+
if (event.data?.messageId !== route.query.id) {
29+
return
30+
}
31+
if (event.data?.busi == 'certificate') {
32+
const certificate = event.data['certificate']
33+
assistantStore.setType(1)
34+
assistantStore.setCertificate(certificate)
35+
assistantStore.resolveCertificate(certificate)
36+
}
37+
if (event.data?.busi == 'setOnline') {
38+
setFormatOnline(event.data.online)
39+
}
40+
}
41+
}
42+
const setFormatOnline = (text?: any) => {
43+
if (text === null || typeof text === 'undefined') {
44+
assistantStore.setOnline(false)
45+
return
46+
}
47+
if (typeof text === 'boolean') {
48+
assistantStore.setOnline(text)
49+
return
50+
}
51+
if (typeof text === 'string') {
52+
assistantStore.setOnline(text.toLowerCase() === 'true')
53+
return
54+
}
55+
assistantStore.setOnline(false)
56+
}
57+
58+
onBeforeMount(async () => {
59+
debugger
60+
const assistantId = route.query.id
61+
if (!assistantId) {
62+
ElMessage.error('Miss embedded id, please check embedded url')
63+
return
64+
}
65+
const online = route.query.online
66+
setFormatOnline(online)
67+
68+
let name = route.query.name
69+
if (name) {
70+
assistantName.value = decodeURIComponent(name.toString())
71+
}
72+
const now = Date.now()
73+
assistantStore.setFlag(now)
74+
assistantStore.setId(assistantId?.toString() || '')
75+
const param = {
76+
id: assistantId,
77+
virtual: assistantStore.getFlag,
78+
online,
79+
}
80+
validator.value = await assistantApi.validate(param)
81+
assistantStore.setToken(validator.value.token)
82+
assistantStore.setAssistant(true)
83+
loading.value = false
84+
85+
window.addEventListener('message', communicationCb)
86+
const readyData = {
87+
eventName: 'sqlbot_embedded_event',
88+
busi: 'ready',
89+
ready: true,
90+
messageId: assistantId,
91+
}
92+
window.parent.postMessage(readyData, '*')
93+
})
94+
95+
onBeforeUnmount(() => {
96+
window.removeEventListener('message', communicationCb)
97+
})
1198
</script>
1299

13100
<style lang="less" scoped>

0 commit comments

Comments
 (0)