Skip to content

Commit 10eb41f

Browse files
fix: Assistant ds
1 parent 6d771a4 commit 10eb41f

File tree

3 files changed

+83
-17
lines changed

3 files changed

+83
-17
lines changed

backend/apps/chat/task/llm.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,22 +403,24 @@ def select_datasource(self):
403403

404404
if data['id'] and data['id'] != 0:
405405
_datasource = data['id']
406-
if self.current_assistant.type == 1:
406+
_chat = self.session.get(Chat, self.record.chat_id)
407+
_chat.datasource = _datasource
408+
if self.current_assistant and self.current_assistant.type == 1:
407409
_ds = self.out_ds_instance.get_ds(data['id'])
408410
self.ds = _ds
411+
self.chat_question.engine = _ds.type
412+
_engine_type = self.chat_question.engine
413+
_chat.engine_type = _ds.type
409414
else:
410415
_ds = self.session.get(CoreDatasource, _datasource)
411416
if not _ds:
412417
_datasource = None
413418
raise Exception(f"Datasource configuration with id {_datasource} not found")
414419
self.ds = CoreDatasource(**_ds.model_dump())
415-
self.chat_question.engine = _ds.type_name if _ds.type != 'excel' else 'PostgreSQL'
416-
_engine_type = self.chat_question.engine
420+
self.chat_question.engine = _ds.type_name if _ds.type != 'excel' else 'PostgreSQL'
421+
_engine_type = self.chat_question.engine
422+
_chat.engine_type = _ds.type_name
417423
# save chat
418-
_chat = self.session.get(Chat, self.record.chat_id)
419-
_chat.datasource = _datasource
420-
_chat.engine_type = _ds.type_name
421-
422424
self.session.add(_chat)
423425
self.session.flush()
424426
self.session.refresh(_chat)
@@ -734,7 +736,7 @@ def run_task(llm_service: LLMService, in_chat: bool = True):
734736
'type': 'datasource-result'}).decode() + '\n\n'
735737
if in_chat:
736738
yield orjson.dumps({'id': llm_service.ds.id, 'datasource_name': llm_service.ds.name,
737-
'engine_type': llm_service.ds.type_name, 'type': 'datasource'}).decode() + '\n\n'
739+
'engine_type': llm_service.ds.type_name or llm_service.ds.type, 'type': 'datasource'}).decode() + '\n\n'
738740

739741
llm_service.chat_question.db_schema = llm_service.out_ds_instance.get_db_schema() if llm_service.out_ds_instance else get_table_schema(session=llm_service.session, current_user=llm_service.current_user, ds=llm_service.ds)
740742

backend/apps/system/crud/assistant.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from apps.datasource.models.datasource import CoreDatasource
99
from apps.system.models.system_model import AssistantModel
1010
from apps.system.schemas.auth import CacheName, CacheNamespace
11-
from apps.system.schemas.system_schema import UserInfoDTO
11+
from apps.system.schemas.system_schema import AssistantOutDsSchema, UserInfoDTO
1212
from common.core.sqlbot_cache import cache
1313
from common.core.db import engine
1414
from starlette.middleware.cors import CORSMiddleware
@@ -73,7 +73,7 @@ def init_dynamic_cors(app: FastAPI):
7373

7474
class AssistantOutDs:
7575
assistant: AssistantModel
76-
ds_list: Optional[list[dict]] = None
76+
ds_list: Optional[list[AssistantOutDsSchema]] = None
7777
certificate: Optional[str] = None
7878
def __init__(self, assistant: AssistantModel, certificate: Optional[str] = None):
7979
self.assistant = assistant
@@ -99,9 +99,11 @@ async def get_ds_from_api(self, certificate: Optional[str] = None):
9999
result_json: dict[any] = json.loads(res.json())
100100
if result_json.get('code') == 0:
101101
temp_list = result_json.get('data', [])
102-
for idx, item in enumerate(temp_list, start=1):
103-
item["id"] = idx
104-
self.ds_list = temp_list
102+
self.ds_list = [
103+
AssistantOutDsSchema(**{**item, "id": idx})
104+
for idx, item in enumerate(temp_list, start=1)
105+
]
106+
105107
return self.ds_list
106108
else:
107109
raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}")
@@ -110,12 +112,25 @@ async def get_ds_from_api(self, certificate: Optional[str] = None):
110112

111113
def get_simple_ds_list(self):
112114
if self.ds_list:
113-
return [{'id': ds['id'], 'name': ds['name'], 'description': ds['comment']} for ds in self.ds_list]
115+
return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list]
114116
else:
115117
raise Exception("Datasource list is not found.")
116118

117-
def get_db_schema(self, ds_id: int):
118-
return None
119+
def get_db_schema(self, ds_id: int) -> str:
120+
ds = self.get_ds(ds_id)
121+
schema_str = ""
122+
db_name = ds.schema
123+
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
124+
for table in ds.tables:
125+
schema_str += f"# Table: {db_name}.{table.name}"
126+
schema_str += f", {table.comment}\n[\n"
127+
field_list = []
128+
for field in table.fields:
129+
field_list.append(f"({field.name}:{field.type}, {field.comment})")
130+
schema_str += ",\n".join(field_list)
131+
schema_str += '\n]\n'
132+
return schema_str
133+
119134
def get_ds(self, ds_id: int):
120135
if self.ds_list:
121136
for ds in self.ds_list:

backend/apps/system/schemas/system_schema.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,53 @@ class UserWs(BaseCreatorDTO):
101101
name: str
102102

103103
class UserWsOption(UserWs):
104-
account: str
104+
account: str
105+
106+
107+
class AssistantFieldSchema(BaseModel):
108+
id: Optional[int] = None
109+
name: Optional[str] = None
110+
type: Optional[str] = None
111+
comment: Optional[str] = None
112+
class AssistantTableSchema(BaseModel):
113+
id: Optional[int] = None
114+
name: Optional[str] = None
115+
comment: Optional[str] = None
116+
rule: Optional[str] = None
117+
sql: Optional[str] = None
118+
fields: Optional[list[AssistantFieldSchema]] = None
119+
120+
class AssistantOutDsBase(BaseModel):
121+
id: Optional[int] = None
122+
name: str
123+
type: Optional[str] = None
124+
type_name: Optional[str] = None
125+
comment: Optional[str] = None
126+
description: Optional[str] = None
127+
128+
def __init__(self, id: Optional[int] = None, name: str = '', type: Optional[str] = None,
129+
type_name: Optional[str] = None, comment: Optional[str] = None):
130+
super().__init__(id=id, name=name, type=type, type_name=type_name, comment=comment)
131+
132+
class AssistantOutDsSchema(AssistantOutDsBase):
133+
host: Optional[str] = None
134+
port: Optional[int] = None
135+
user: Optional[str] = None
136+
password: Optional[str] = None
137+
schema: Optional[str] = None
138+
tables: Optional[list[AssistantTableSchema]] = None
139+
140+
def __init__(self, id: int, name: str, comment: Optional[str] = None, type: Optional[str] = None,
141+
type_name: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None,
142+
username: Optional[str] = None, password: Optional[str] = None, schema: Optional[str] = None):
143+
self.id = id
144+
self.name = name
145+
self.comment = comment
146+
self.type = type
147+
self.type_name = type_name
148+
self.host = host
149+
self.port = port
150+
self.username = username
151+
self.password = password
152+
self.schema = schema
153+
self.description = comment

0 commit comments

Comments
 (0)