Skip to content

Commit 39bedc5

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 66cbfd2 + 27d7454 commit 39bedc5

File tree

10 files changed

+79
-50
lines changed

10 files changed

+79
-50
lines changed

backend/apps/chat/curd/chat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def create_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj:
189189
if ds:
190190
chat_info.datasource_exists = True
191191
chat_info.datasource_name = ds.name
192+
chat_info.ds_type = ds.type
192193

193194
if require_datasource and ds:
194195
# generate first empty record

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_ds(session: SessionDep, id: int):
4141

4242

4343
def check_status(session: SessionDep, ds: CoreDatasource, is_raise: bool = False):
44-
conn = get_engine(ds, 5)
44+
conn = get_engine(ds, 10)
4545
try:
4646
with conn.connect() as connection:
4747
SQLBotLogUtil.info("success")

backend/apps/db/db.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,28 +52,24 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str:
5252
return db_url
5353

5454

55-
def get_engine(ds: CoreDatasource, timeout: int = 30) -> Engine:
55+
def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine:
5656
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
5757
if conf.timeout is None:
5858
conf.timeout = timeout
59+
if timeout > 0:
60+
conf.timeout = timeout
5961
if ds.type == "pg" and (conf.dbSchema is not None and conf.dbSchema != ""):
6062
engine = create_engine(get_uri(ds),
6163
connect_args={"options": f"-c search_path={urllib.parse.quote(conf.dbSchema)}",
6264
"connect_timeout": conf.timeout},
63-
pool_timeout=conf.timeout, pool_size=20, max_overflow=10)
65+
pool_timeout=conf.timeout)
6466
elif ds.type == 'sqlServer':
65-
engine = create_engine(get_uri(ds), pool_timeout=conf.timeout,
66-
pool_size=20,
67-
max_overflow=10)
67+
engine = create_engine(get_uri(ds), pool_timeout=conf.timeout)
6868
elif ds.type == 'oracle':
6969
engine = create_engine(get_uri(ds),
70-
pool_timeout=conf.timeout,
71-
pool_size=20,
72-
max_overflow=10)
70+
pool_timeout=conf.timeout)
7371
else:
74-
engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout,
75-
pool_size=20,
76-
max_overflow=10)
72+
engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, pool_timeout=conf.timeout)
7773
return engine
7874

7975

backend/apps/db/engine.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ def get_engine_conn():
2525
db_url = get_engine_uri(conf)
2626
engine = create_engine(db_url,
2727
connect_args={"options": f"-c search_path={conf.dbSchema}", "connect_timeout": conf.timeout},
28-
pool_timeout=conf.timeout,
29-
pool_size=20,
30-
max_overflow=10)
28+
pool_timeout=conf.timeout)
3129
return engine
3230

3331

backend/apps/system/crud/assistant.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,33 @@
1-
2-
31
import json
2+
import urllib
43
from typing import Optional
5-
from fastapi import FastAPI
4+
65
import requests
6+
from fastapi import FastAPI
77
from sqlalchemy import Engine, create_engine
88
from sqlmodel import Session, select
9-
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
9+
from starlette.middleware.cors import CORSMiddleware
1010

11+
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
1112
from apps.system.models.system_model import AssistantModel
1213
from apps.system.schemas.auth import CacheName, CacheNamespace
1314
from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO
14-
from common.core.sqlbot_cache import cache
15-
from common.core.db import engine
16-
from starlette.middleware.cors import CORSMiddleware
1715
from common.core.config import settings
16+
from common.core.db import engine
17+
from common.core.sqlbot_cache import cache
1818
from common.utils.utils import string_to_numeric_hash
1919

20+
2021
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
2122
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
2223
db_model = session.get(AssistantModel, assistant_id)
2324
return db_model
2425

26+
2527
def get_assistant_user(*, id: int):
26-
return UserInfoDTO(id=id, account="sqlbot-inner-assistant", oid=1, name="sqlbot-inner-assistant", email="[email protected]")
28+
return UserInfoDTO(id=id, account="sqlbot-inner-assistant", oid=1, name="sqlbot-inner-assistant",
29+
30+
2731

2832
def get_assistant_ds(llm_service) -> list[dict]:
2933
assistant: AssistantHeader = llm_service.current_assistant
@@ -34,13 +38,14 @@ def get_assistant_ds(llm_service) -> list[dict]:
3438
if configuration:
3539
config: dict[any] = json.loads(configuration)
3640
oid: int = int(config['oid'])
37-
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(CoreDatasource.oid == oid)
41+
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
42+
CoreDatasource.oid == oid)
3843
if not assistant.online:
39-
private_list:list[int] = config.get('private_list') or None
44+
private_list: list[int] = config.get('private_list') or None
4045
if private_list:
4146
stmt = stmt.where(~CoreDatasource.id.in_(private_list))
4247
db_ds_list = session.exec(stmt)
43-
48+
4449
result_list = [
4550
{
4651
"id": ds.id,
@@ -49,7 +54,7 @@ def get_assistant_ds(llm_service) -> list[dict]:
4954
}
5055
for ds in db_ds_list
5156
]
52-
57+
5358
# filter private ds if offline
5459
return result_list
5560
out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
@@ -58,8 +63,9 @@ def get_assistant_ds(llm_service) -> list[dict]:
5863
# format?
5964
return dslist
6065

66+
6167
def init_dynamic_cors(app: FastAPI):
62-
try:
68+
try:
6369
with Session(engine) as session:
6470
list_result = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all()
6571
seen = set()
@@ -81,20 +87,20 @@ def init_dynamic_cors(app: FastAPI):
8187
cors_middleware.kwargs['allow_origins'] = updated_origins
8288
except Exception as e:
8389
return False, e
84-
85-
90+
8691

8792
class AssistantOutDs:
8893
assistant: AssistantHeader
8994
ds_list: Optional[list[AssistantOutDsSchema]] = None
9095
certificate: Optional[str] = None
96+
9197
def __init__(self, assistant: AssistantHeader):
9298
self.assistant = assistant
9399
self.ds_list = None
94100
self.certificate = assistant.certificate
95101
self.get_ds_from_api()
96-
97-
#@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
102+
103+
# @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
98104
def get_ds_from_api(self):
99105
config: dict[any] = json.loads(self.assistant.configuration)
100106
endpoint: str = config['endpoint']
@@ -118,23 +124,23 @@ def get_ds_from_api(self):
118124
self.convert2schema(item)
119125
for item in temp_list
120126
]
121-
127+
122128
return self.ds_list
123129
else:
124130
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
125131
else:
126132
raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")
127-
133+
128134
def get_simple_ds_list(self):
129135
if self.ds_list:
130136
return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
131137
else:
132138
raise Exception("Datasource list is not found.")
133-
139+
134140
def get_db_schema(self, ds_id: int) -> str:
135141
ds = self.get_ds(ds_id)
136142
schema_str = ""
137-
#db_name = ds.db_schema
143+
# db_name = ds.db_schema
138144
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
139145
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
140146
for table in ds.tables:
@@ -144,7 +150,7 @@ def get_db_schema(self, ds_id: int) -> str:
144150
schema_str += '\n[\n'
145151
else:
146152
schema_str += f", {table_comment}\n[\n"
147-
153+
148154
field_list = []
149155
for field in table.fields:
150156
field_comment = field.comment
@@ -155,7 +161,7 @@ def get_db_schema(self, ds_id: int) -> str:
155161
schema_str += ",\n".join(field_list)
156162
schema_str += '\n]\n'
157163
return schema_str
158-
164+
159165
def get_ds(self, ds_id: int):
160166
if self.ds_list:
161167
for ds in self.ds_list:
@@ -175,20 +181,22 @@ def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema:
175181
db_schema = ds_dict.get('schema', ds_dict.get('db_schema', ''))
176182
ds_dict.pop("schema", None)
177183
return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema})
178-
184+
185+
179186
class AssistantOutDsFactory:
180187
@staticmethod
181188
def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
182189
return AssistantOutDs(assistant)
183190

191+
184192
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
185193
timeout: int = 30
186194
connect_args = {"connect_timeout": timeout}
187195
conf = DatasourceConf(
188-
host=ds.host,
189-
port=ds.port,
196+
host=ds.host,
197+
port=ds.port,
190198
username=ds.user,
191-
password=ds.password,
199+
password=ds.password,
192200
database=ds.dataBase,
193201
driver='',
194202
extraJdbc=ds.extraParams,
@@ -197,8 +205,20 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
197205
conf.extraJdbc = ''
198206
from apps.db.db import get_uri_from_config
199207
uri = get_uri_from_config(ds.type, conf)
208+
# if ds.type == "pg" and ds.db_schema:
209+
# connect_args.update({"options": f"-c search_path={ds.db_schema}"})
210+
# engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
211+
200212
if ds.type == "pg" and ds.db_schema:
201-
connect_args.update({"options": f"-c search_path={ds.db_schema}"})
202-
engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
213+
engine = create_engine(uri,
214+
connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}",
215+
"connect_timeout": timeout},
216+
pool_timeout=timeout)
217+
elif ds.type == 'sqlServer':
218+
engine = create_engine(uri, pool_timeout=timeout)
219+
elif ds.type == 'oracle':
220+
engine = create_engine(uri,
221+
pool_timeout=timeout)
222+
else:
223+
engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout)
203224
return engine
204-

backend/template.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ template:
77
根据表结构和问题生成符合{engine}数据库引擎规范的sql语句,以及sql中所用到的表名(不要包含schema和database,用数组返回)。
88
你必须遵守以下规则:
99
- 生成的SQL必须符合{engine}的规范。
10+
- 提问中如果有涉及数据源名称或数据源描述的内容,则忽略数据源的信息,直接根据剩余内容生成SQL
1011
- 根据表结构生成SQL语句,需给每个表名生成一个别名(不要加AS)。
1112
- SQL查询中不能使用星号(*),必须明确指定字段名.
1213
- SQL查询的字段名不要自动翻译,别名必须为英文。

frontend/public/assistant.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
;(function () {
1+
; (function () {
22
window.sqlbot_assistant_handler = window.sqlbot_assistant_handler || {}
33
const defaultData = {
44
id: '1',
@@ -68,7 +68,7 @@
6868
const getChatContainerHtml = (data) => {
6969
return `
7070
<div id="sqlbot-assistant-chat-container">
71-
<iframe id="sqlbot-assistant-chat-iframe-${data.id}" allow="microphone" src="${data.domain_url}/#/assistant?id=${data.id}&online=${!!data.online}"></iframe>
71+
<iframe id="sqlbot-assistant-chat-iframe-${data.id}" allow="microphone;clipboard-read 'src'; clipboard-write 'src'" src="${data.domain_url}/#/assistant?id=${data.id}&online=${!!data.online}"></iframe>
7272
<div class="sqlbot-assistant-operate">
7373
<div class="sqlbot-assistant-closeviewport sqlbot-assistant-viewportnone">
7474
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" viewBox="0 0 20 20" fill="none">

frontend/src/views/chat/chat-block/ChartBlock.vue

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import icon_window_mini_outlined from '@/assets/svg/icon_window-mini_outlined.sv
1818
import icon_copy_outlined from '@/assets/svg/icon_copy_outlined.svg'
1919
import { useI18n } from 'vue-i18n'
2020
import SQLComponent from '@/views/chat/component/SQLComponent.vue'
21+
import { useAssistantStore } from '@/stores/assistant'
2122
import AddViewDashboard from '@/views/dashboard/common/AddViewDashboard.vue'
2223
2324
const props = withDefaults(
@@ -51,6 +52,8 @@ const dataObject = computed<{
5152
}
5253
return {}
5354
})
55+
const assistantStore = useAssistantStore()
56+
const isAssistant = computed(() => assistantStore.getAssistant)
5457
5558
const data = computed(() => {
5659
if (props.isPredict) {
@@ -325,7 +328,7 @@ function copy() {
325328

326329
<el-drawer
327330
v-model="sqlShow"
328-
size="600"
331+
:size="isAssistant ? '100%' : '600px'"
329332
:title="t('chat.show_sql')"
330333
direction="rtl"
331334
body-class="chart-sql-drawer-body"

frontend/src/views/chat/index.vue

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@
296296
type="textarea"
297297
:autosize="{ minRows: 1, maxRows: 8.583 }"
298298
:placeholder="t('qa.question_placeholder')"
299-
@keydown.enter.exact.prevent="sendMessage"
299+
@keydown.enter.exact.prevent="($event: any) => sendMessage($event)"
300300
@keydown.ctrl.enter.exact.prevent="handleCtrlEnter"
301301
/>
302302

@@ -551,7 +551,10 @@ const assistantPrepareSend = async () => {
551551
}
552552
}
553553
}
554-
const sendMessage = async () => {
554+
const sendMessage = async ($event: any = {}) => {
555+
if ($event?.isComposing) {
556+
return
557+
}
555558
if (!inputMessage.value.trim()) return
556559
557560
loading.value = true

frontend/src/views/embedded/index.vue

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,10 @@ onBeforeUnmount(() => {
191191
}
192192
}
193193
</style>
194+
195+
<style lang="less">
196+
.ed-overlay-dialog,
197+
.ed-drawer {
198+
margin-top: 50px;
199+
}
200+
</style>

0 commit comments

Comments
 (0)