Skip to content

Commit c546645

Browse files
committed
feat: Vector retrieval matches tables
1 parent a5a37a6 commit c546645

File tree

5 files changed

+51
-29
lines changed

5 files changed

+51
-29
lines changed

backend/apps/chat/task/llm.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk
1919
from sqlalchemy import and_, select
2020
from sqlalchemy.orm import sessionmaker
21+
from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts
22+
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum
23+
from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil
2124
from sqlmodel import Session
2225

2326
from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config
@@ -30,9 +33,6 @@
3033
get_last_execute_sql_error
3134
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
3235
ChatFinishStep
33-
from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil
34-
from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts
35-
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum
3636
from apps.data_training.curd.data_training import get_training_template
3737
from apps.datasource.crud.datasource import get_table_schema
3838
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
@@ -111,7 +111,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
111111
if not ds:
112112
raise SingleMessageError("No available datasource configuration found")
113113
chat_question.engine = ds.type + get_version(ds)
114-
chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id)
114+
chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id, chat_question.question)
115115
else:
116116
ds = self.session.get(CoreDatasource, chat.datasource)
117117
if not ds:
@@ -249,7 +249,7 @@ def generate_analysis(self):
249249
self.current_user.oid, ds_id)
250250
if SQLBotLicenseUtil.valid():
251251
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS,
252-
self.current_user.oid, ds_id)
252+
self.current_user.oid, ds_id)
253253

254254
analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
255255
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
@@ -298,7 +298,7 @@ def generate_predict(self):
298298
if SQLBotLicenseUtil.valid():
299299
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
300300
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA,
301-
self.current_user.oid, ds_id)
301+
self.current_user.oid, ds_id)
302302

303303
predict_msg: List[Union[BaseMessage, dict[str, Any]]] = []
304304
predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question()))
@@ -343,10 +343,11 @@ def generate_recommend_questions_task(self):
343343
# get schema
344344
if self.ds and not self.chat_question.db_schema:
345345
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
346-
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
347-
current_user=self.current_user, ds=self.ds,
348-
question=self.chat_question.question,
349-
embedding=False)
346+
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
347+
session=self.session,
348+
current_user=self.current_user, ds=self.ds,
349+
question=self.chat_question.question,
350+
embedding=False)
350351

351352
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
352353
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
@@ -478,7 +479,8 @@ def select_datasource(self):
478479
_ds = self.out_ds_instance.get_ds(data['id'])
479480
self.ds = _ds
480481
self.chat_question.engine = _ds.type + get_version(self.ds)
481-
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id)
482+
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id,
483+
self.chat_question.question)
482484
_engine_type = self.chat_question.engine
483485
_chat.engine_type = _ds.type
484486
else:
@@ -529,7 +531,7 @@ def select_datasource(self):
529531
oid)
530532
if SQLBotLicenseUtil.valid():
531533
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL,
532-
oid, ds_id)
534+
oid, ds_id)
533535

534536
self.init_messages()
535537

@@ -923,8 +925,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
923925
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
924926
ds_id, oid)
925927
if SQLBotLicenseUtil.valid():
926-
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL,
927-
oid, ds_id)
928+
self.chat_question.custom_prompt = find_custom_prompts(self.session,
929+
CustomPromptTypeEnum.GENERATE_SQL,
930+
oid, ds_id)
928931

929932
self.init_messages()
930933

@@ -961,10 +964,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
961964
'type': 'datasource'}).decode() + '\n\n'
962965

963966
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
964-
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
965-
current_user=self.current_user,
966-
ds=self.ds,
967-
question=self.chat_question.question)
967+
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
968+
session=self.session,
969+
current_user=self.current_user,
970+
ds=self.ds,
971+
question=self.chat_question.question)
968972
else:
969973
self.validate_history_ds()
970974

backend/apps/datasource/crud/datasource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
404404

405405
# do table embedding
406406
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
407-
tables = get_table_embedding(session, current_user, tables, question)
407+
tables = get_table_embedding(tables, question)
408408
# splice schema
409409
if tables:
410410
for s in tables:

backend/apps/datasource/embedding/ds_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
2222
if out_ds.ds_list:
2323
for _ds in out_ds.ds_list:
2424
ds = out_ds.get_ds(_ds.id)
25-
table_schema = out_ds.get_db_schema(_ds.id)
25+
table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False)
2626
ds_info = f"{ds.name}, {ds.description}\n"
2727
ds_schema = ds_info + table_schema
2828
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
from apps.ai_model.embedding import EmbeddingModelCache
88
from apps.datasource.embedding.utils import cosine_similarity
99
from common.core.config import settings
10-
from common.core.deps import SessionDep, CurrentUser
1110
from common.utils.utils import SQLBotLogUtil
1211

1312

14-
def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str):
13+
def get_table_embedding(tables: list[dict], question: str):
1514
_list = []
1615
for table in tables:
1716
_list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0})

backend/apps/system/crud/assistant.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlmodel import Session, select
99
from starlette.middleware.cors import CORSMiddleware
1010

11+
from apps.datasource.embedding.table_embedding import get_table_embedding
1112
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
1213
from apps.system.models.system_model import AssistantModel
1314
from apps.system.schemas.auth import CacheName, CacheNamespace
@@ -18,6 +19,7 @@
1819
from common.utils.aes_crypto import simple_aes_decrypt
1920
from common.utils.utils import string_to_numeric_hash
2021

22+
2123
@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
2224
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
2325
db_model = session.get(AssistantModel, assistant_id)
@@ -141,18 +143,22 @@ def get_simple_ds_list(self):
141143
else:
142144
raise Exception("Datasource list is not found.")
143145

144-
def get_db_schema(self, ds_id: int) -> str:
146+
def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str:
145147
ds = self.get_ds(ds_id)
146148
schema_str = ""
147149
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
148150
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
151+
tables = []
152+
i = 0
149153
for table in ds.tables:
150-
schema_str += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}"
154+
i += 1
155+
schema_table = ''
156+
schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}"
151157
table_comment = table.comment
152158
if table_comment == '':
153-
schema_str += '\n[\n'
159+
schema_table += '\n[\n'
154160
else:
155-
schema_str += f", {table_comment}\n[\n"
161+
schema_table += f", {table_comment}\n[\n"
156162

157163
field_list = []
158164
for field in table.fields:
@@ -161,8 +167,19 @@ def get_db_schema(self, ds_id: int) -> str:
161167
field_list.append(f"({field.name}:{field.type})")
162168
else:
163169
field_list.append(f"({field.name}:{field.type}, {field_comment})")
164-
schema_str += ",\n".join(field_list)
165-
schema_str += '\n]\n'
170+
schema_table += ",\n".join(field_list)
171+
schema_table += '\n]\n'
172+
t_obj = {"id": i, "schema_table": schema_table}
173+
tables.append(t_obj)
174+
175+
# do table embedding
176+
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
177+
tables = get_table_embedding(tables, question)
178+
179+
if tables:
180+
for s in tables:
181+
schema_str += s.get('schema_table')
182+
166183
return schema_str
167184

168185
def get_ds(self, ds_id: int):
@@ -186,7 +203,8 @@ def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSche
186203
try:
187204
ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv)
188205
except Exception as e:
189-
raise Exception(f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}")
206+
raise Exception(
207+
f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}")
190208
for attr in attr_list:
191209
if attr in ds_dict:
192210
id_marker += str(ds_dict.get(attr, '')) + '--sqlbot--'
@@ -195,6 +213,7 @@ def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSche
195213
ds_dict.pop("schema", None)
196214
return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema})
197215

216+
198217
class AssistantOutDsFactory:
199218
@staticmethod
200219
def get_instance(assistant: AssistantHeader) -> AssistantOutDs:

0 commit comments

Comments
 (0)