Skip to content

Commit b017498

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 61bcf01 + 1aaf5ed commit b017498

File tree

5 files changed

+92
-23
lines changed

5 files changed

+92
-23
lines changed

backend/apps/chat/curd/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def save_full_select_datasource_message_and_answer(session: SessionDep, record_i
341341
engine_type: str = None, token_usage: dict = None) -> ChatRecord:
342342
if not record_id:
343343
raise Exception("Record id cannot be None")
344-
record = session.query(ChatRecord).filter(ChatRecord.id == record_id).first()
344+
record = session.get(ChatRecord, record_id)
345345
record.full_select_datasource_message = full_message
346346
record.datasource_select_answer = answer
347347

backend/apps/chat/models/chat_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def filter_user_question(self):
151151
class ChatQuestion(AiModelQuestion):
152152
question: str = ''
153153
chat_id: int = 0
154+
assistant_certificate: Optional[str] = None
154155

155156

156157
class ChatMcp(ChatQuestion):

backend/apps/chat/task/llm.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from apps.datasource.crud.row_permission import transFilterTree
3131
from apps.datasource.models.datasource import CoreDatasource, CoreTable
3232
from apps.db.db import exec_sql
33-
from apps.system.crud.assistant import get_assistant_ds
33+
from apps.system.crud.assistant import AssistantOutDs, get_assistant_ds
3434
from common.core.config import settings
3535
from common.core.deps import CurrentAssistant, SessionDep, CurrentUser
3636
from common.utils.utils import extract_nested_json
@@ -52,6 +52,8 @@ class LLMService:
5252
session: SessionDep
5353
current_user: CurrentUser
5454
current_assistant: Optional[CurrentAssistant] = None
55+
assistant_certificate: str
56+
out_ds_instance: Optional[AssistantOutDs] = None
5557
change_title: bool = False
5658

5759
def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question: ChatQuestion,
@@ -60,6 +62,8 @@ def __init__(self, session: SessionDep, current_user: CurrentUser, chat_question
6062
self.session = session
6163
self.current_user = current_user
6264
self.current_assistant = current_assistant
65+
if chat_question.assistant_certificate:
66+
self.assistant_certificate = chat_question.assistant_certificate
6367
# chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first()
6468
chat_id = chat_question.chat_id
6569
chat: Chat = self.session.get(Chat, chat_id)
@@ -346,7 +350,7 @@ def select_datasource(self):
346350
datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = []
347351
datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question()))
348352
if self.current_assistant:
349-
_ds_list = get_assistant_ds(session=self.session, assistant=self.current_assistant)
353+
_ds_list = get_assistant_ds(llm_service=self)
350354
else:
351355
_ds_list = self.session.exec(select(CoreDatasource).options(
352356
load_only(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description))).all()
@@ -398,15 +402,19 @@ def select_datasource(self):
398402

399403
if data['id'] and data['id'] != 0:
400404
_datasource = data['id']
401-
_ds = self.session.query(CoreDatasource).filter(CoreDatasource.id == _datasource).first()
402-
if not _ds:
403-
_datasource = None
404-
raise Exception(f"Datasource configuration with id {_datasource} not found")
405-
self.ds = CoreDatasource(**_ds.model_dump())
405+
if self.current_assistant.type == 1:
406+
_ds = self.out_ds_instance.get_ds(data['id'])
407+
self.ds = _ds
408+
else:
409+
_ds = self.session.get(CoreDatasource, _datasource)
410+
if not _ds:
411+
_datasource = None
412+
raise Exception(f"Datasource configuration with id {_datasource} not found")
413+
self.ds = CoreDatasource(**_ds.model_dump())
406414
self.chat_question.engine = _ds.type_name if _ds.type != 'excel' else 'PostgreSQL'
407415
_engine_type = self.chat_question.engine
408416
# save chat
409-
_chat = self.session.query(Chat).filter(Chat.id == self.record.chat_id).first()
417+
_chat = self.session.get(Chat, self.record.chat_id)
410418
_chat.datasource = _datasource
411419
_chat.engine_type = _ds.type_name
412420

@@ -727,7 +735,7 @@ def run_task(llm_service: LLMService, in_chat: bool = True):
727735
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
728736
'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
729737

730-
llm_service.chat_question.db_schema = get_table_schema(session=llm_service.session, current_user=llm_service.current_user, ds=llm_service.ds)
738+
llm_service.chat_question.db_schema = llm_service.out_ds_instance.get_db_schema() if llm_service.out_ds_instance else get_table_schema(session=llm_service.session, current_user=llm_service.current_user, ds=llm_service.ds)
731739

732740
# generate sql
733741
sql_res = llm_service.generate_sql()

backend/apps/system/crud/assistant.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11

22

33
import json
4+
from typing import Optional
45
from fastapi import FastAPI
6+
import requests
57
from sqlmodel import Session, select
8+
from apps.chat.task.llm import LLMService
69
from apps.datasource.models.datasource import CoreDatasource
710
from apps.system.models.system_model import AssistantModel
811
from apps.system.schemas.auth import CacheName, CacheNamespace
@@ -11,6 +14,7 @@
1114
from common.core.db import engine
1215
from starlette.middleware.cors import CORSMiddleware
1316
from common.core.config import settings
17+
from deps import CurrentUser
1418

1519
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
1620
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
@@ -20,21 +24,26 @@ async def get_assistant_info(*, session: Session, assistant_id: int) -> Assistan
2024
def get_assistant_user(*, id: int):
2125
return UserInfoDTO(id=id, account="sqlbot-inner-assistant", oid=1, name="sqlbot-inner-assistant", email="[email protected]")
2226

23-
def get_assistant_ds(*, session: Session, assistant: AssistantModel):
27+
# def get_assistant_ds(*, session: Session, assistant: AssistantModel):
28+
def get_assistant_ds(llm_service: LLMService) -> list[dict]:
29+
assistant: AssistantModel = llm_service.current_assistant
30+
session: Session = llm_service.session
2431
type = assistant.type
2532
if type == 0:
26-
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description)
2733
configuration = assistant.configuration
2834
if configuration:
29-
config = json.loads(configuration)
35+
config: dict[any] = json.loads(configuration)
36+
oid: str = config['oid']
37+
stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where(CoreDatasource.oid == oid)
3038
private_list:list[int] = config['private_list']
31-
if not private_list:
39+
if private_list:
3240
stmt.where(~CoreDatasource.id.in_(private_list))
3341
db_ds_list = session.exec(stmt).all()
3442
# filter private ds if offline
3543
return db_ds_list
36-
out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant)
37-
dslist = out_ds_instance.get_ds_list()
44+
out_ds_instance: AssistantOutDs = AssistantOutDsFactory.get_instance(assistant, llm_service.assistant_certificate)
45+
llm_service.out_ds_instance = out_ds_instance
46+
dslist = out_ds_instance.get_simple_ds_list()
3847
# format?
3948
return dslist
4049

@@ -66,16 +75,66 @@ def init_dynamic_cors(app: FastAPI):
6675

6776
class AssistantOutDs:
6877
assistant: AssistantModel
69-
def get_ds_list(self):
78+
ds_list: Optional[list[dict]] = None
79+
certificate: Optional[str] = None
80+
def __init__(self, assistant: AssistantModel, certificate: Optional[str] = None):
81+
self.assistant = assistant
82+
self.ds_list = None
83+
self.certificate = certificate
84+
self.get_ds_from_api(certificate)
85+
86+
#@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_DS, keyExpression="current_user.id")
87+
async def get_ds_from_api(self, certificate: Optional[str] = None):
7088
config: dict[any] = json.loads(self.assistant.configuration)
71-
url: str = config['url']
89+
endpoint: str = config['endpoint']
90+
certificateList: list[any] = json.loads(certificate)
91+
header = {}
92+
cookies = {}
93+
for item in certificateList:
94+
if item['target'] == 'head':
95+
header[item['key']] = item['value']
96+
if item['target'] == 'cookie':
97+
cookies[item['key']] = item['value']
98+
99+
res = requests.get(url=endpoint, headers=header, cookies=cookies, timeout=10)
100+
if res.status_code == 200:
101+
result_json: dict[any] = json.loads(res.json())
102+
if result_json.get('code') == 0:
103+
temp_list = result_json.get('data', [])
104+
for idx, item in enumerate(temp_list, start=1):
105+
item["id"] = idx
106+
self.ds_list = temp_list
107+
return self.ds_list
108+
else:
109+
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
110+
else:
111+
raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}")
112+
113+
def get_simple_ds_list(self):
114+
if self.ds_list:
115+
return [{'id': ds['id'], 'name': ds['name'], 'description': ds['comment']} for ds in self.ds_list]
116+
else:
117+
raise Exception("Datasource list is not found.")
118+
119+
def get_db_schema(self, ds_id: int):
72120
return None
121+
def get_ds(self, ds_id: int):
122+
if self.ds_list:
123+
for ds in self.ds_list:
124+
if ds['id'] == ds_id:
125+
return ds
126+
else:
127+
raise Exception("Datasource list is not found.")
128+
raise Exception(f"Datasource with id {ds_id} not found.")
129+
def get_ds_engine(self, ds_id: int):
130+
ds = self.get_ds(ds_id)
131+
ds_type = ds.get('type') if ds else None
132+
if not ds_type:
133+
raise Exception(f"Datasource with id {ds_id} not found or type is not defined.")
134+
return ds_type
73135

74136
class AssistantOutDsFactory:
75-
_instance: AssistantOutDs = None
76137
@staticmethod
77-
def get_instance(cls, assistant: AssistantModel) -> AssistantOutDs:
78-
if not cls._instance:
79-
cls._instance = AssistantOutDs(assistant)
80-
return cls._instance
138+
def get_instance(assistant: AssistantModel, certificate: Optional[str] = None) -> AssistantOutDs:
139+
return AssistantOutDs(assistant, certificate)
81140

backend/apps/system/schemas/auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ def __str__(self):
1414
class CacheName(Enum):
1515
USER_INFO = "user:info"
1616
ASSISTANT_INFO = "assistant:info"
17+
ASSISTANT_DS = "assistant:ds"
1718
def __str__(self):
1819
return self.value

0 commit comments

Comments
 (0)