Skip to content

Commit 49513ae

Browse files
committed
feat: chat use local db schema
1 parent 7020a5a commit 49513ae

File tree

4 files changed

+65
-27
lines changed

4 files changed

+65
-27
lines changed

backend/apps/chat/api/chat.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from apps.chat.curd.chat import list_chats, get_chat_with_records, create_chat, save_question, save_answer, rename_chat, \
66
delete_chat
7-
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat
7+
from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, Chat
88
from apps.chat.schemas.chat_base_schema import LLMConfig
99
from apps.chat.schemas.chat_schema import ChatQuestion
1010
from apps.chat.schemas.llm import AgentService
11+
from apps.datasource.crud.datasource import get_table_obj_by_ds
1112
from apps.datasource.models.datasource import CoreDatasource
1213
from apps.system.models.system_model import AiModelDetail
1314
from common.core.deps import SessionDep, CurrentUser
@@ -80,26 +81,18 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
8081
"""
8182
question = request_question.question
8283

83-
# Get available AI model
84-
aimodel = session.exec(select(AiModelDetail).where(
85-
AiModelDetail.status == True,
86-
AiModelDetail.api_key.is_not(None)
87-
)).first()
88-
89-
# Get available datasource
90-
ds = session.exec(select(CoreDatasource).where(
91-
CoreDatasource.status == 'Success'
92-
)).first()
93-
94-
if not aimodel:
84+
chat = session.query(Chat).filter(Chat.id == request_question.chat_id).first()
85+
if not chat:
9586
raise HTTPException(
9687
status_code=400,
97-
detail="No available AI model configuration found"
88+
detail=f"Chat with id {request_question.chart_id} not found"
9889
)
9990

91+
# Get available datasource
92+
ds = session.query(CoreDatasource).filter(CoreDatasource.id == chat.datasource).first()
10093
if not ds:
10194
raise HTTPException(
102-
status_code=400,
95+
status_code=500,
10396
detail="No available datasource configuration found"
10497
)
10598

@@ -112,6 +105,17 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
112105
detail=str(e1)
113106
)
114107

108+
# Get available AI model
109+
aimodel = session.exec(select(AiModelDetail).where(
110+
AiModelDetail.status == True,
111+
AiModelDetail.api_key.is_not(None)
112+
)).first()
113+
if not aimodel:
114+
raise HTTPException(
115+
status_code=400,
116+
detail="No available AI model configuration found"
117+
)
118+
115119
# Use Tongyi Qianwen
116120
tongyi_config = LLMConfig(
117121
model_type="openai",
@@ -136,10 +140,39 @@ async def stream_sql(session: SessionDep, current_user: CurrentUser, request_que
136140
""" result = llm_service.generate_sql(question)
137141
return result """
138142

143+
# get schema
144+
schema_str = ""
145+
table_objs = get_table_obj_by_ds(session=session, ds=ds)
146+
db_name = table_objs[0].schema
147+
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
148+
for obj in table_objs:
149+
schema_str += f"# Table: {db_name}.{obj.table.table_name}"
150+
table_comment = ''
151+
if obj.table.custom_comment:
152+
table_comment = obj.table.custom_comment.strip()
153+
if table_comment == '':
154+
schema_str += '\n[\n'
155+
else:
156+
schema_str += f", {table_comment}\n[\n"
157+
158+
field_list = []
159+
for field in obj.fields:
160+
field_comment = ''
161+
if field.custom_comment:
162+
field_comment = field.custom_comment.strip()
163+
if field_comment == '':
164+
field_list.append(f"({field.field_name}:{field.field_type})")
165+
else:
166+
field_list.append(f"({field.field_name}:{field.field_type}, {field_comment})")
167+
schema_str += ",\n".join(field_list)
168+
schema_str += '\n]\n'
169+
170+
print(schema_str)
171+
139172
async def event_generator():
140173
all_text = ''
141174
try:
142-
async for chunk in llm_service.async_generate(question):
175+
async for chunk in llm_service.async_generate(question, schema_str):
143176
data = json.loads(chunk.replace('data: ', ''))
144177

145178
if data['type'] in ['final', 'tool_result']:

backend/apps/chat/schemas/llm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,17 @@ def generate_sql(self, question: str) -> str:
101101
schema = self.db.get_table_info()
102102
return chain.invoke({"schema": schema, "question": question})
103103

104-
async def async_generate(self, question: str) -> AsyncGenerator[str, None]:
104+
async def async_generate(self, question: str, schema: str) -> AsyncGenerator[str, None]:
105105

106106
chain = self.prompt | self.agent_executor
107107
# schema = self.db.get_table_info()
108108

109-
schema_engine = SchemaEngine(engine=self.db._engine)
110-
mschema = schema_engine.mschema
111-
mschema_str = mschema.to_mschema()
109+
# schema_engine = SchemaEngine(engine=self.db._engine)
110+
# mschema = schema_engine.mschema
111+
# mschema_str = mschema.to_mschema()
112112

113-
async for chunk in chain.astream({"schema": mschema_str, "question": question}):
113+
# async for chunk in chain.astream({"schema": mschema_str, "question": question}):
114+
async for chunk in chain.astream({"schema": schema, "question": question}):
114115
if not isinstance(chunk, dict):
115116
continue
116117

backend/apps/datasource/crud/datasource.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,12 @@ def preview(session: SessionDep, id: int, data: TableObj):
211211
return exec_sql(ds, sql)
212212

213213

214-
def get_tableobj_by_ds(session: SessionDep, id: int) -> List[TableAndFields]:
215-
list: List = []
216-
tables = session.query(CoreTable).filter(CoreTable.ds_id == id).all()
214+
def get_table_obj_by_ds(session: SessionDep, ds: CoreDatasource) -> List[TableAndFields]:
215+
_list: List = []
216+
tables = session.query(CoreTable).filter(CoreTable.ds_id == ds.id).all()
217+
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
218+
schema = conf.dbSchema if conf.dbSchema is not None and conf.dbSchema != "" else conf.database
217219
for table in tables:
218220
fields = session.query(CoreField).filter(and_(CoreField.table_id == table.id, CoreField.checked == True)).all()
219-
list.append(TableAndFields(table=table, fields=fields))
220-
return list
221+
_list.append(TableAndFields(schema=schema, table=table, fields=fields))
222+
return _list

backend/apps/datasource/models/datasource.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ def __init__(self, attr1, attr2, attr3):
111111

112112

113113
class TableAndFields:
114-
def __init__(self, table, fields):
114+
def __init__(self, schema, table, fields):
115+
self.schema = schema
115116
self.table = table
116117
self.fields = fields
117118

119+
schema: str
118120
table: CoreTable
119121
fields: List[CoreField]

0 commit comments

Comments
 (0)