|
18 | 18 | from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk |
19 | 19 | from sqlalchemy import and_, select |
20 | 20 | from sqlalchemy.orm import sessionmaker |
| 21 | +from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts |
| 22 | +from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum |
| 23 | +from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil |
21 | 24 | from sqlmodel import Session |
22 | 25 |
|
23 | 26 | from apps.ai_model.model_factory import LLMConfig, LLMFactory, get_default_config |
|
30 | 33 | get_last_execute_sql_error |
31 | 34 | from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ |
32 | 35 | ChatFinishStep |
33 | | -from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil |
34 | | -from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts |
35 | | -from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum |
36 | 36 | from apps.data_training.curd.data_training import get_training_template |
37 | 37 | from apps.datasource.crud.datasource import get_table_schema |
38 | 38 | from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user |
@@ -111,7 +111,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, |
111 | 111 | if not ds: |
112 | 112 | raise SingleMessageError("No available datasource configuration found") |
113 | 113 | chat_question.engine = ds.type + get_version(ds) |
114 | | - chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id) |
| 114 | + chat_question.db_schema = self.out_ds_instance.get_db_schema(ds.id, chat_question.question) |
115 | 115 | else: |
116 | 116 | ds = self.session.get(CoreDatasource, chat.datasource) |
117 | 117 | if not ds: |
@@ -249,7 +249,7 @@ def generate_analysis(self): |
249 | 249 | self.current_user.oid, ds_id) |
250 | 250 | if SQLBotLicenseUtil.valid(): |
251 | 251 | self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS, |
252 | | - self.current_user.oid, ds_id) |
| 252 | + self.current_user.oid, ds_id) |
253 | 253 |
|
254 | 254 | analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) |
255 | 255 | analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) |
@@ -298,7 +298,7 @@ def generate_predict(self): |
298 | 298 | if SQLBotLicenseUtil.valid(): |
299 | 299 | ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None |
300 | 300 | self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA, |
301 | | - self.current_user.oid, ds_id) |
| 301 | + self.current_user.oid, ds_id) |
302 | 302 |
|
303 | 303 | predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] |
304 | 304 | predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) |
@@ -343,10 +343,11 @@ def generate_recommend_questions_task(self): |
343 | 343 | # get schema |
344 | 344 | if self.ds and not self.chat_question.db_schema: |
345 | 345 | self.chat_question.db_schema = self.out_ds_instance.get_db_schema( |
346 | | - self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, |
347 | | - current_user=self.current_user, ds=self.ds, |
348 | | - question=self.chat_question.question, |
349 | | - embedding=False) |
| 346 | + self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( |
| 347 | + session=self.session, |
| 348 | + current_user=self.current_user, ds=self.ds, |
| 349 | + question=self.chat_question.question, |
| 350 | + embedding=False) |
350 | 351 |
|
351 | 352 | guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] |
352 | 353 | guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question())) |
@@ -478,7 +479,8 @@ def select_datasource(self): |
478 | 479 | _ds = self.out_ds_instance.get_ds(data['id']) |
479 | 480 | self.ds = _ds |
480 | 481 | self.chat_question.engine = _ds.type + get_version(self.ds) |
481 | | - self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id) |
| 482 | + self.chat_question.db_schema = self.out_ds_instance.get_db_schema(self.ds.id, |
| 483 | + self.chat_question.question) |
482 | 484 | _engine_type = self.chat_question.engine |
483 | 485 | _chat.engine_type = _ds.type |
484 | 486 | else: |
@@ -529,7 +531,7 @@ def select_datasource(self): |
529 | 531 | oid) |
530 | 532 | if SQLBotLicenseUtil.valid(): |
531 | 533 | self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, |
532 | | - oid, ds_id) |
| 534 | + oid, ds_id) |
533 | 535 |
|
534 | 536 | self.init_messages() |
535 | 537 |
|
@@ -923,8 +925,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
923 | 925 | self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, |
924 | 926 | ds_id, oid) |
925 | 927 | if SQLBotLicenseUtil.valid(): |
926 | | - self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, |
927 | | - oid, ds_id) |
| 928 | + self.chat_question.custom_prompt = find_custom_prompts(self.session, |
| 929 | + CustomPromptTypeEnum.GENERATE_SQL, |
| 930 | + oid, ds_id) |
928 | 931 |
|
929 | 932 | self.init_messages() |
930 | 933 |
|
@@ -961,10 +964,11 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
961 | 964 | 'type': 'datasource'}).decode() + '\n\n' |
962 | 965 |
|
963 | 966 | self.chat_question.db_schema = self.out_ds_instance.get_db_schema( |
964 | | - self.ds.id) if self.out_ds_instance else get_table_schema(session=self.session, |
965 | | - current_user=self.current_user, |
966 | | - ds=self.ds, |
967 | | - question=self.chat_question.question) |
| 967 | + self.ds.id, self.chat_question.question) if self.out_ds_instance else get_table_schema( |
| 968 | + session=self.session, |
| 969 | + current_user=self.current_user, |
| 970 | + ds=self.ds, |
| 971 | + question=self.chat_question.question) |
968 | 972 | else: |
969 | 973 | self.validate_history_ds() |
970 | 974 |
|
|
0 commit comments