Skip to content

Commit 8141700

Browse files
committed
feat: assistant connect
1 parent 17fcfd4 commit 8141700

File tree

2 files changed

+57
-31
lines changed

2 files changed

+57
-31
lines changed

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_ds(session: SessionDep, id: int):
4141

4242

4343
def check_status(session: SessionDep, ds: CoreDatasource, is_raise: bool = False):
44-
conn = get_engine(ds, 5)
44+
conn = get_engine(ds, 10)
4545
try:
4646
with conn.connect() as connection:
4747
SQLBotLogUtil.info("success")

backend/apps/system/crud/assistant.py

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,33 @@
1-
2-
31
import json
2+
import urllib
43
from typing import Optional
5-
from fastapi import FastAPI
4+
65
import requests
6+
from fastapi import FastAPI
77
from sqlalchemy import Engine, create_engine
88
from sqlmodel import Session, select
9-
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
9+
from starlette.middleware.cors import CORSMiddleware
1010

11+
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
1112
from apps.system.models.system_model import AssistantModel
1213
from apps.system.schemas.auth import CacheName, CacheNamespace
1314
from apps.system.schemas.system_schema import AssistantHeader, AssistantOutDsSchema, UserInfoDTO
14-
from common.core.sqlbot_cache import cache
15-
from common.core.db import engine
16-
from starlette.middleware.cors import CORSMiddleware
1715
from common.core.config import settings
16+
from common.core.db import engine
17+
from common.core.sqlbot_cache import cache
1818
from common.utils.utils import string_to_numeric_hash
1919

20+
2021
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
2122
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
2223
db_model = session.get(AssistantModel, assistant_id)
2324
return db_model
2425

26+
2527
def get_assistant_user(*, id: int):
26-
return UserInfoDTO(id=id, account="sqlbot-inner-assistant", oid=1, name="sqlbot-inner-assistant", email="[email protected]")
28+
return UserInfoDTO(id=id, account="sqlbot-inner-assistant", oid=1, name="sqlbot-inner-assistant",
29+
30+
2731

2832
def get_assistant_ds(llm_service) -> list[dict]:
2933
assistant: AssistantHeader = llm_service.current_assistant
@@ -34,13 +38,14 @@ def get_assistant_ds(llm_service) -> list[dict]:
3438
if configuration:
3539
config: dict[any] = json.loads(configuration)
3640
oid: int = int(config['oid'])
37-
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(CoreDatasource.oid == oid)
41+
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(
42+
CoreDatasource.oid == oid)
3843
if not assistant.online:
39-
private_list:list[int] = config.get('private_list') or None
44+
private_list: list[int] = config.get('private_list') or None
4045
if private_list:
4146
stmt = stmt.where(~CoreDatasource.id.in_(private_list))
4247
db_ds_list = session.exec(stmt)
43-
48+
4449
result_list = [
4550
{
4651
"id": ds.id,
@@ -49,7 +54,7 @@ def get_assistant_ds(llm_service) -> list[dict]:
4954
}
5055
for ds in db_ds_list
5156
]
52-
57+
5358
# filter private ds if offline
5459
return result_list
5560
out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
@@ -58,8 +63,9 @@ def get_assistant_ds(llm_service) -> list[dict]:
5863
# format?
5964
return dslist
6065

66+
6167
def init_dynamic_cors(app: FastAPI):
62-
try:
68+
try:
6369
with Session(engine) as session:
6470
list_result = session.exec(select(AssistantModel).order_by(AssistantModel.create_time)).all()
6571
seen = set()
@@ -81,20 +87,20 @@ def init_dynamic_cors(app: FastAPI):
8187
cors_middleware.kwargs['allow_origins'] = updated_origins
8288
except Exception as e:
8389
return False, e
84-
85-
90+
8691

8792
class AssistantOutDs:
8893
assistant: AssistantHeader
8994
ds_list: Optional[list[AssistantOutDsSchema]] = None
9095
certificate: Optional[str] = None
96+
9197
def __init__(self, assistant: AssistantHeader):
9298
self.assistant = assistant
9399
self.ds_list = None
94100
self.certificate = assistant.certificate
95101
self.get_ds_from_api()
96-
97-
#@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
102+
103+
# @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
98104
def get_ds_from_api(self):
99105
config: dict[any] = json.loads(self.assistant.configuration)
100106
endpoint: str = config['endpoint']
@@ -118,23 +124,23 @@ def get_ds_from_api(self):
118124
self.convert2schema(item)
119125
for item in temp_list
120126
]
121-
127+
122128
return self.ds_list
123129
else:
124130
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
125131
else:
126132
raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")
127-
133+
128134
def get_simple_ds_list(self):
129135
if self.ds_list:
130136
return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
131137
else:
132138
raise Exception("Datasource list is not found.")
133-
139+
134140
def get_db_schema(self, ds_id: int) -> str:
135141
ds = self.get_ds(ds_id)
136142
schema_str = ""
137-
#db_name = ds.db_schema
143+
# db_name = ds.db_schema
138144
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
139145
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
140146
for table in ds.tables:
@@ -144,7 +150,7 @@ def get_db_schema(self, ds_id: int) -> str:
144150
schema_str += '\n[\n'
145151
else:
146152
schema_str += f", {table_comment}\n[\n"
147-
153+
148154
field_list = []
149155
for field in table.fields:
150156
field_comment = field.comment
@@ -155,7 +161,7 @@ def get_db_schema(self, ds_id: int) -> str:
155161
schema_str += ",\n".join(field_list)
156162
schema_str += '\n]\n'
157163
return schema_str
158-
164+
159165
def get_ds(self, ds_id: int):
160166
if self.ds_list:
161167
for ds in self.ds_list:
@@ -175,20 +181,22 @@ def convert2schema(self, ds_dict: dict) -> AssistantOutDsSchema:
175181
db_schema = ds_dict.get('schema', ds_dict.get('db_schema', ''))
176182
ds_dict.pop("schema", None)
177183
return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema})
178-
184+
185+
179186
class AssistantOutDsFactory:
180187
@staticmethod
181188
def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
182189
return AssistantOutDs(assistant)
183190

191+
184192
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
185193
timeout: int = 30
186194
connect_args = {"connect_timeout": timeout}
187195
conf = DatasourceConf(
188-
host=ds.host,
189-
port=ds.port,
196+
host=ds.host,
197+
port=ds.port,
190198
username=ds.user,
191-
password=ds.password,
199+
password=ds.password,
192200
database=ds.dataBase,
193201
driver='',
194202
extraJdbc=ds.extraParams,
@@ -197,8 +205,26 @@ def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
197205
conf.extraJdbc = ''
198206
from apps.db.db import get_uri_from_config
199207
uri = get_uri_from_config(ds.type, conf)
208+
# if ds.type == "pg" and ds.db_schema:
209+
# connect_args.update({"options": f"-c search_path={ds.db_schema}"})
210+
# engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
211+
200212
if ds.type == "pg" and ds.db_schema:
201-
connect_args.update({"options": f"-c search_path={ds.db_schema}"})
202-
engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
213+
engine = create_engine(uri,
214+
connect_args={"options": f"-c search_path={urllib.parse.quote(ds.db_schema)}",
215+
"connect_timeout": timeout},
216+
pool_timeout=timeout, pool_size=20, max_overflow=10)
217+
elif ds.type == 'sqlServer':
218+
engine = create_engine(uri, pool_timeout=timeout,
219+
pool_size=20,
220+
max_overflow=10)
221+
elif ds.type == 'oracle':
222+
engine = create_engine(uri,
223+
pool_timeout=timeout,
224+
pool_size=20,
225+
max_overflow=10)
226+
else:
227+
engine = create_engine(uri, connect_args={"connect_timeout": timeout}, pool_timeout=timeout,
228+
pool_size=20,
229+
max_overflow=10)
203230
return engine
204-

0 commit comments

Comments
 (0)