diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 7b46c83b..3cf88749 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -18,6 +18,9 @@ from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk from sqlalchemy import and_, select from sqlalchemy.orm import sessionmaker +from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts +from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum +from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil from sqlmodel import Session from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config @@ -30,9 +33,6 @@ get_last_execute_sql_error from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ ChatFinishStep -from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil -from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts -from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum from apps.data_training.curd.data_training import get_training_template from apps.datasource.crud.datasource import get_table_schema from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user @@ -111,7 +111,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, if not ds: raise SingleMessageError("No available datasource configuration found") chat_question.engine = ds.type + get_version(ds) - chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id) + chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id, chat_question.question) else: ds = self.session.get(CoreDatasource, chat.datasource) if not ds: @@ -249,7 +249,7 @@ def generate_analysis(self): self.current_user.oid, ds_id) if SQLBotLicenseUtil.valid(): self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS, - self.current_user.oid, ds_id) + self.current_user.oid, ds_id) analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) @@ -298,7 +298,7 @@ def generate_predict(self): if SQLBotLicenseUtil.valid(): ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA, - self.current_user.oid, ds_id) + self.current_user.oid, ds_id) predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) @@ -343,10 +343,11 @@ def generate_recommend_questions_task(self): # get schema if self.ds and not self.chat_question.db_schema: self.chat_question.db_schema = self.out_ds_instance.get_db_schema( - self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, - current_user=self.current_user, ds=self.ds, - question=self.chat_question.question, - embedding=False) + self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( + session=self.session, + current_user=self.current_user, ds=self.ds, + question=self.chat_question.question, + embedding=False) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) @@ -478,7 +479,8 @@ def select_datasource(self): _ds = self.out_ds_instance.get_ds(data['id']) self.ds = _ds self.chat_question.engine = _ds.type + get_version(self.ds) - self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id) + self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id, + self.chat_question.question) _engine_type = self.chat_question.engine _chat.engine_type = _ds.type else: @@ -529,7 +531,7 @@ def select_datasource(self): oid) if SQLBotLicenseUtil.valid(): self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, - oid, ds_id) + oid, ds_id) self.init_messages() @@ -923,8 +925,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True, self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id, oid) if SQLBotLicenseUtil.valid(): - self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, - oid, ds_id) + self.chat_question.custom_prompt = find_custom_prompts(self.session, + CustomPromptTypeEnum.GENERATE_SQL, + oid, ds_id) self.init_messages() @@ -961,10 +964,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True, 'type': 'datasource'}).decode() + '\n\n' self.chat_question.db_schema = self.out_ds_instance.get_db_schema( - self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, - current_user=self.current_user, - ds=self.ds, - question=self.chat_question.question) + self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( + session=self.session, + current_user=self.current_user, + ds=self.ds, + question=self.chat_question.question) else: self.validate_history_ds() diff --git a/backend/apps/datasource/crud/datasource.py b/backend/apps/datasource/crud/datasource.py index c541a152..92efe3c6 100644 --- a/backend/apps/datasource/crud/datasource.py +++ b/backend/apps/datasource/crud/datasource.py @@ -404,7 +404,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat # do table embedding if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: - tables = get_table_embedding(session, current_user, tables, question) + tables = get_table_embedding(tables, question) # splice schema if tables: for s in tables: diff --git a/backend/apps/datasource/embedding/ds_embedding.py b/backend/apps/datasource/embedding/ds_embedding.py index 7aad7d29..a3570178 100644 --- a/backend/apps/datasource/embedding/ds_embedding.py +++ b/backend/apps/datasource/embedding/ds_embedding.py @@ -22,7 +22,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o if out_ds.ds_list: for _ds in out_ds.ds_list: ds = out_ds.get_ds(_ds.id) - table_schema = out_ds.get_db_schema(_ds.id) + table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False) ds_info = f"{ds.name}, {ds.description}\n" ds_schema = ds_info + table_schema _list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds}) diff --git a/backend/apps/datasource/embedding/table_embedding.py b/backend/apps/datasource/embedding/table_embedding.py index 1a3fe896..e8efc405 100644 --- a/backend/apps/datasource/embedding/table_embedding.py +++ b/backend/apps/datasource/embedding/table_embedding.py @@ -7,11 +7,10 @@ from apps.ai_model.embedding import EmbeddingModelCache from apps.datasource.embedding.utils import cosine_similarity from common.core.config import settings -from common.core.deps import SessionDep, CurrentUser from common.utils.utils import SQLBotLogUtil -def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str): +def get_table_embedding(tables: list[dict], question: str): _list = [] for table in tables: _list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0}) diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 912218b9..d690eec2 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -8,6 +8,7 @@ from sqlmodel import Session, select from starlette.middleware.cors import CORSMiddleware +from apps.datasource.embedding.table_embedding import get_table_embedding from apps.datasource.models.datasource import CoreDatasource, DatasourceConf from apps.system.models.system_model import AssistantModel from apps.system.schemas.auth import CacheName, CacheNamespace @@ -18,6 +19,7 @@ from common.utils.aes_crypto import simple_aes_decrypt from common.utils.utils import string_to_numeric_hash + @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id") async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None: db_model = session.get(AssistantModel, assistant_id) @@ -141,18 +143,22 @@ def get_simple_ds_list(self): else: raise Exception("Datasource list is not found.") - def get_db_schema(self, ds_id: int) -> str: + def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str: ds = self.get_ds(ds_id) schema_str = "" db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase schema_str += f"【DB_ID】 {db_name}\n【Schema】\n" + tables = [] + i = 0 for table in ds.tables: - schema_str += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}" + i += 1 + schema_table = '' + schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}" table_comment = table.comment if table_comment == '': - schema_str += '\n[\n' + schema_table += '\n[\n' else: - schema_str += f", {table_comment}\n[\n" + schema_table += f", {table_comment}\n[\n" field_list = [] for field in table.fields: @@ -161,8 +167,19 @@ def get_db_schema(self, ds_id: int) -> str: field_list.append(f"({field.name}:{field.type})") else: field_list.append(f"({field.name}:{field.type}, {field_comment})") - schema_str += ",\n".join(field_list) - schema_str += '\n]\n' + schema_table += ",\n".join(field_list) + schema_table += '\n]\n' + t_obj = {"id": i, "schema_table": schema_table} + tables.append(t_obj) + + # do table embedding + if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: + tables = get_table_embedding(tables, question) + + if tables: + for s in tables: + schema_str += s.get('schema_table') + return schema_str def get_ds(self, ds_id: int): @@ -186,7 +203,8 @@ def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSche try: ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv) except Exception as e: - raise Exception(f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}") + raise Exception( + f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}") for attr in attr_list: if attr in ds_dict: id_marker += str(ds_dict.get(attr, '')) + '--sqlbot--' @@ -195,6 +213,7 @@ def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSche ds_dict.pop("schema", None) return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema}) + class AssistantOutDsFactory: @staticmethod def get_instance(assistant: AssistantHeader) -> AssistantOutDs: