Skip to content

Commit 8dbac2a

Browse files
perf: Assistant dynamic ds
1 parent 517de2f commit 8dbac2a

File tree

6 files changed

+55
-41
lines changed

6 files changed

+55
-41
lines changed

backend/apps/chat/task/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def run_task(llm_service: LLMService, in_chat: bool = True):
748748
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
749749
'engine_type': llm_service.ds.type_name or llm_service.ds.type, 'type': 'datasource'}).decode() + '\n\n'
750750

751-
llm_service.chat_question.db_schema = llm_service.out_ds_instance.get_db_schema() if llm_service.out_ds_instance else get_table_schema(session=llm_service.session, current_user=llm_service.current_user, ds=llm_service.ds)
751+
llm_service.chat_question.db_schema = llm_service.out_ds_instance.get_db_schema(llm_service.ds.id) if llm_service.out_ds_instance else get_table_schema(session=llm_service.session, current_user=llm_service.current_user, ds=llm_service.ds)
752752

753753
# generate sql
754754
sql_res = llm_service.generate_sql()

backend/apps/db/db.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,32 @@
1111
from apps.db.engine import get_engine_config
1212
from decimal import Decimal
1313

14+
from apps.system.crud.assistant import get_ds_engine
15+
from apps.system.schemas.system_schema import AssistantOutDsSchema
1416

15-
def get_uri(ds: CoreDatasource):
17+
18+
def get_uri(ds: CoreDatasource) -> str:
1619
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
20+
return get_uri_from_config(ds.type, conf)
21+
22+
def get_uri_from_config(type: str,conf: DatasourceConf) -> str:
1723
db_url: str
18-
if ds.type == "mysql":
24+
if type == "mysql":
1925
if conf.extraJdbc is not None and conf.extraJdbc != '':
2026
db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}?{conf.extraJdbc}"
2127
else:
2228
db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}"
23-
elif ds.type == "sqlServer":
29+
elif type == "sqlServer":
2430
if conf.extraJdbc is not None and conf.extraJdbc != '':
2531
db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}?{conf.extraJdbc}"
2632
else:
2733
db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}"
28-
elif ds.type == "pg" or ds.type == "excel":
34+
elif type == "pg" or type == "excel":
2935
if conf.extraJdbc is not None and conf.extraJdbc != '':
3036
db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}?{conf.extraJdbc}"
3137
else:
3238
db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{urllib.parse.quote(conf.database)}"
33-
elif ds.type == "oracle":
39+
elif type == "oracle":
3440
if conf.mode == "service_name":
3541
if conf.extraJdbc is not None and conf.extraJdbc != '':
3642
db_url = f"oracle+oracledb://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}?service_name={urllib.parse.quote(conf.database)}&{conf.extraJdbc}"
@@ -62,8 +68,8 @@ def get_engine(ds: CoreDatasource) -> Engine:
6268
return engine
6369

6470

65-
def get_session(ds: CoreDatasource):
66-
engine = get_engine(ds)
71+
def get_session(ds: CoreDatasource | AssistantOutDsSchema):
72+
engine = get_engine(ds) if isinstance(ds, CoreDatasource) else get_ds_engine(ds)
6773
session_maker = sessionmaker(bind=engine)
6874
session = session_maker()
6975
return session
@@ -238,7 +244,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
238244
session.close()
239245

240246

241-
def exec_sql(ds: CoreDatasource, sql: str):
247+
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str):
242248
session = get_session(ds)
243249
result = session.execute(text(sql))
244250
try:

backend/apps/system/crud/assistant.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from typing import Optional
55
from fastapi import FastAPI
66
import requests
7+
from sqlalchemy import Engine, create_engine
78
from sqlmodel import Session, select
8-
from apps.datasource.models.datasource import CoreDatasource
9+
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
10+
911
from apps.system.models.system_model import AssistantModel
1012
from apps.system.schemas.auth import CacheName, CacheNamespace
1113
from apps.system.schemas.system_schema import AssistantOutDsSchema, UserInfoDTO
@@ -92,7 +94,7 @@ def __init__(self, assistant: AssistantModel, certificate: Optional[str] = None)
9294
self.get_ds_from_api(certificate)
9395

9496
#@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
95-
async def get_ds_from_api(self, certificate: Optional[str] = None):
97+
def get_ds_from_api(self, certificate: Optional[str] = None):
9698
config: dict[any] = json.loads(self.assistant.configuration)
9799
endpoint: str = config['endpoint']
98100
certificateList: list[any] = json.loads(certificate)
@@ -106,7 +108,7 @@ async def get_ds_from_api(self, certificate: Optional[str] = None):
106108

107109
res = requests.get(url=endpoint, headers=header, cookies=cookies, timeout=10)
108110
if res.status_code == 200:
109-
result_json: dict[any] = json.loads(res.json())
111+
result_json: dict[any] = json.loads(res.text)
110112
if result_json.get('code') == 0:
111113
temp_list = result_json.get('data', [])
112114
self.ds_list = [
@@ -144,20 +146,35 @@ def get_db_schema(self, ds_id: int) -> str:
144146
def get_ds(self, ds_id: int):
145147
if self.ds_list:
146148
for ds in self.ds_list:
147-
if ds['id'] == ds_id:
149+
if ds.id == ds_id:
148150
return ds
149151
else:
150152
raise Exception("Datasource list is not found.")
151153
raise Exception(f"Datasource with id {ds_id} not found.")
152-
def get_ds_engine(self, ds_id: int):
153-
ds = self.get_ds(ds_id)
154-
ds_type = ds.get('type') if ds else None
155-
if not ds_type:
156-
raise Exception(f"Datasource with id {ds_id} not found or type is not defined.")
157-
return ds_type
154+
158155

159156
class AssistantOutDsFactory:
160157
@staticmethod
161158
def get_instance(assistant: AssistantModel, certificate: Optional[str] = None) -> AssistantOutDs:
162159
return AssistantOutDs(assistant, certificate)
160+
161+
def get_ds_engine(ds: AssistantOutDsSchema) -> Engine:
162+
timeout: int = 30
163+
connect_args = {"connect_timeout": timeout}
164+
conf = DatasourceConf(
165+
host=ds.host,
166+
port=ds.port,
167+
username=ds.user,
168+
password=ds.password,
169+
database=ds.dataBase,
170+
driver='',
171+
extraJdbc=ds.extraParams,
172+
dbSchema=ds.schema or ''
173+
)
174+
from apps.db.db import get_uri_from_config
175+
uri = get_uri_from_config(ds.type, conf)
176+
if ds.type == "pg" and ds.schema:
177+
connect_args.update({"options": f"-c search_path={ds.schema}"})
178+
engine = create_engine(uri, connect_args=connect_args, pool_timeout=timeout, pool_size=20, max_overflow=10)
179+
return engine
163180

backend/apps/system/schemas/system_schema.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,29 +125,14 @@ class AssistantOutDsBase(BaseModel):
125125
comment: Optional[str] = None
126126
description: Optional[str] = None
127127

128-
def __init__(self, id: Optional[int] = None, name: str = '', type: Optional[str] = None,
129-
type_name: Optional[str] = None, comment: Optional[str] = None):
130-
super().__init__(id=id, name=name, type=type, type_name=type_name, comment=comment)
131128

132129
class AssistantOutDsSchema(AssistantOutDsBase):
133130
host: Optional[str] = None
134131
port: Optional[int] = None
132+
dataBase: Optional[str] = None
135133
user: Optional[str] = None
136134
password: Optional[str] = None
137135
schema: Optional[str] = None
136+
extraParams: Optional[str] = None
138137
tables: Optional[list[AssistantTableSchema]] = None
139-
140-
def __init__(self, id: int, name: str, comment: Optional[str] = None, type: Optional[str] = None,
141-
type_name: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None,
142-
username: Optional[str] = None, password: Optional[str] = None, schema: Optional[str] = None):
143-
self.id = id
144-
self.name = name
145-
self.comment = comment
146-
self.type = type
147-
self.type_name = type_name
148-
self.host = host
149-
self.port = port
150-
self.username = username
151-
self.password = password
152-
self.schema = schema
153-
self.description = comment
138+

backend/common/core/deps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ async def get_current_user(request: Request) -> UserInfoDTO:
1919

2020
CurrentUser = Annotated[UserInfoDTO, Depends(get_current_user)]
2121

22-
async def get_current_assistant(request: Request) -> AssistantModel:
23-
return request.state.assistant
22+
async def get_current_assistant(request: Request) -> AssistantModel | None:
23+
return request.state.assistant if hasattr(request.state, "assistant") else None
2424

2525
CurrentAssistant = Annotated[AssistantModel, Depends(get_current_assistant)]
2626

frontend/src/views/chat/answer/ChartAnswer.vue

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
<script setup lang="ts">
22
import BaseAnswer from './BaseAnswer.vue'
33
import { Chat, chatApi, ChatInfo, type ChatMessage, ChatRecord, questionApi } from '@/api/chat.ts'
4+
import { useAssistantStore } from '@/stores/assistant'
45
import { computed, nextTick, ref } from 'vue'
6+
const assistantStore = useAssistantStore()
57
const props = withDefaults(
68
defineProps<{
79
chatList?: Array<ChatInfo>
@@ -96,11 +98,15 @@ const sendMessage = async () => {
9698
9799
try {
98100
const controller: AbortController = new AbortController()
99-
const response = await questionApi.add({
101+
const param = {
100102
question: currentRecord.question,
101103
chat_id: _currentChatId.value,
104+
assistant_certificate: assistantStore.getCertificate,
102105
controller,
103-
})
106+
}
107+
console.log(assistantStore.getCertificate)
108+
console.log(param)
109+
const response = await questionApi.add(param)
104110
const reader = response.body.getReader()
105111
const decoder = new TextDecoder()
106112

0 commit comments

Comments
 (0)