Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk
from sqlalchemy import and_, select
from sqlalchemy.orm import sessionmaker
from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum
from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil
from sqlmodel import Session

from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config
Expand All @@ -30,9 +33,6 @@
get_last_execute_sql_error
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep
from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil
from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts
from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum
from apps.data_training.curd.data_training import get_training_template
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion,
if not ds:
raise SingleMessageError("No available datasource configuration found")
chat_question.engine = ds.type + get_version(ds)
chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id)
chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id, chat_question.question)
else:
ds = self.session.get(CoreDatasource, chat.datasource)
if not ds:
Expand Down Expand Up @@ -249,7 +249,7 @@ def generate_analysis(self):
self.current_user.oid, ds_id)
if SQLBotLicenseUtil.valid():
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS,
self.current_user.oid, ds_id)
self.current_user.oid, ds_id)

analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question()))
analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question()))
Expand Down Expand Up @@ -298,7 +298,7 @@ def generate_predict(self):
if SQLBotLicenseUtil.valid():
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA,
self.current_user.oid, ds_id)
self.current_user.oid, ds_id)

predict_msg: List[Union[BaseMessage, dict[str, Any]]] = []
predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question()))
Expand Down Expand Up @@ -343,10 +343,11 @@ def generate_recommend_questions_task(self):
# get schema
if self.ds and not self.chat_question.db_schema:
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question,
embedding=False)
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
session=self.session,
current_user=self.current_user, ds=self.ds,
question=self.chat_question.question,
embedding=False)

guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question()))
Expand Down Expand Up @@ -478,7 +479,8 @@ def select_datasource(self):
_ds = self.out_ds_instance.get_ds(data['id'])
self.ds = _ds
self.chat_question.engine = _ds.type + get_version(self.ds)
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id)
self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id,
self.chat_question.question)
_engine_type = self.chat_question.engine
_chat.engine_type = _ds.type
else:
Expand Down Expand Up @@ -529,7 +531,7 @@ def select_datasource(self):
oid)
if SQLBotLicenseUtil.valid():
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL,
oid, ds_id)
oid, ds_id)

self.init_messages()

Expand Down Expand Up @@ -923,8 +925,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
self.chat_question.data_training = get_training_template(self.session, self.chat_question.question,
ds_id, oid)
if SQLBotLicenseUtil.valid():
self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL,
oid, ds_id)
self.chat_question.custom_prompt = find_custom_prompts(self.session,
CustomPromptTypeEnum.GENERATE_SQL,
oid, ds_id)

self.init_messages()

Expand Down Expand Up @@ -961,10 +964,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
'type': 'datasource'}).decode() + '\n\n'

self.chat_question.db_schema = self.out_ds_instance.get_db_schema(
self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session,
current_user=self.current_user,
ds=self.ds,
question=self.chat_question.question)
self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema(
session=self.session,
current_user=self.current_user,
ds=self.ds,
question=self.chat_question.question)
else:
self.validate_history_ds()

Expand Down
2 changes: 1 addition & 1 deletion backend/apps/datasource/crud/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat

# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
tables = get_table_embedding(session, current_user, tables, question)
tables = get_table_embedding(tables, question)
# splice schema
if tables:
for s in tables:
Expand Down
2 changes: 1 addition & 1 deletion backend/apps/datasource/embedding/ds_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_ds_embedding(session: SessionDep, current_user: CurrentUser, _ds_list, o
if out_ds.ds_list:
for _ds in out_ds.ds_list:
ds = out_ds.get_ds(_ds.id)
table_schema = out_ds.get_db_schema(_ds.id)
table_schema = out_ds.get_db_schema(_ds.id, question, embedding=False)
ds_info = f"{ds.name}, {ds.description}\n"
ds_schema = ds_info + table_schema
_list.append({"id": ds.id, "ds_schema": ds_schema, "cosine_similarity": 0.0, "ds": ds})
Expand Down
3 changes: 1 addition & 2 deletions backend/apps/datasource/embedding/table_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from apps.ai_model.embedding import EmbeddingModelCache
from apps.datasource.embedding.utils import cosine_similarity
from common.core.config import settings
from common.core.deps import SessionDep, CurrentUser
from common.utils.utils import SQLBotLogUtil


def get_table_embedding(session: SessionDep, current_user: CurrentUser, tables: list[dict], question: str):
def get_table_embedding(tables: list[dict], question: str):
_list = []
for table in tables:
_list.append({"id": table.get('id'), "schema_table": table.get('schema_table'), "cosine_similarity": 0.0})
Expand Down
33 changes: 26 additions & 7 deletions backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlmodel import Session, select
from starlette.middleware.cors import CORSMiddleware

from apps.datasource.embedding.table_embedding import get_table_embedding
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
from apps.system.models.system_model import AssistantModel
from apps.system.schemas.auth import CacheName, CacheNamespace
Expand All @@ -18,6 +19,7 @@
from common.utils.aes_crypto import simple_aes_decrypt
from common.utils.utils import string_to_numeric_hash


@cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id")
async def get_assistant_info(*, session: Session, assistant_id: int) -> AssistantModel | None:
db_model = session.get(AssistantModel, assistant_id)
Expand Down Expand Up @@ -141,18 +143,22 @@ def get_simple_ds_list(self):
else:
raise Exception("Datasource list is not found.")

def get_db_schema(self, ds_id: int) -> str:
def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str:
ds = self.get_ds(ds_id)
schema_str = ""
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
tables = []
i = 0
for table in ds.tables:
schema_str += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}"
i += 1
schema_table = ''
schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" else f"# Table: {table.name}"
table_comment = table.comment
if table_comment == '':
schema_str += '\n[\n'
schema_table += '\n[\n'
else:
schema_str += f", {table_comment}\n[\n"
schema_table += f", {table_comment}\n[\n"

field_list = []
for field in table.fields:
Expand All @@ -161,8 +167,19 @@ def get_db_schema(self, ds_id: int) -> str:
field_list.append(f"({field.name}:{field.type})")
else:
field_list.append(f"({field.name}:{field.type}, {field_comment})")
schema_str += ",\n".join(field_list)
schema_str += '\n]\n'
schema_table += ",\n".join(field_list)
schema_table += '\n]\n'
t_obj = {"id": i, "schema_table": schema_table}
tables.append(t_obj)

# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
tables = get_table_embedding(tables, question)

if tables:
for s in tables:
schema_str += s.get('schema_table')

return schema_str

def get_ds(self, ds_id: int):
Expand All @@ -186,7 +203,8 @@ def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSche
try:
ds_dict[attr] = simple_aes_decrypt(ds_dict[attr], key, iv)
except Exception as e:
raise Exception(f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}")
raise Exception(
f"Failed to encrypt {attr} for datasource {ds_dict.get('name')}, error: {str(e)}")
for attr in attr_list:
if attr in ds_dict:
id_marker += str(ds_dict.get(attr, '')) + '--sqlbot--'
Expand All @@ -195,6 +213,7 @@ def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSche
ds_dict.pop("schema", None)
return AssistantOutDsSchema(**{**ds_dict, "id": id, "db_schema": db_schema})


class AssistantOutDsFactory:
@staticmethod
def get_instance(assistant: AssistantHeader) -> AssistantOutDs:
Expand Down
Loading