|
30 | 30 | get_last_execute_sql_error |
31 | 31 | from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ |
32 | 32 | ChatFinishStep |
| 33 | +from sqlbot_xpack.license.license_manage import SQLBotLicenseUtil |
| 34 | +from sqlbot_xpack.custom_prompt.curd.custom_prompt import find_custom_prompts |
| 35 | +from sqlbot_xpack.custom_prompt.models.custom_prompt_model import CustomPromptTypeEnum |
33 | 36 | from apps.data_training.curd.data_training import get_training_template |
34 | 37 | from apps.datasource.crud.datasource import get_table_schema |
35 | 38 | from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user |
@@ -244,6 +247,9 @@ def generate_analysis(self): |
244 | 247 | ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None |
245 | 248 | self.chat_question.terminologies = get_terminology_template(self.session, self.chat_question.question, |
246 | 249 | self.current_user.oid, ds_id) |
| 250 | + if SQLBotLicenseUtil.valid(): |
| 251 | + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.ANALYSIS, |
| 252 | + self.current_user.oid, ds_id) |
247 | 253 |
|
248 | 254 | analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) |
249 | 255 | analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) |
@@ -288,6 +294,12 @@ def generate_predict(self): |
288 | 294 | self.chat_question.fields = orjson.dumps(fields).decode() |
289 | 295 | data = get_chat_chart_data(self.session, self.record.id) |
290 | 296 | self.chat_question.data = orjson.dumps(data.get('data')).decode() |
| 297 | + |
| 298 | + if SQLBotLicenseUtil.valid(): |
| 299 | + ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None |
| 300 | + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.PREDICT_DATA, |
| 301 | + self.current_user.oid, ds_id) |
| 302 | + |
291 | 303 | predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] |
292 | 304 | predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) |
293 | 305 | predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question())) |
@@ -509,6 +521,9 @@ def select_datasource(self): |
509 | 521 | ds_id) |
510 | 522 | self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, ds_id, |
511 | 523 | oid) |
| 524 | + if SQLBotLicenseUtil.valid(): |
| 525 | + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, |
| 526 | + oid, ds_id) |
512 | 527 |
|
513 | 528 | self.init_messages() |
514 | 529 |
|
@@ -902,6 +917,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True, |
902 | 917 | oid, ds_id) |
903 | 918 | self.chat_question.data_training = get_training_template(self.session, self.chat_question.question, |
904 | 919 | ds_id, oid) |
| 920 | + if SQLBotLicenseUtil.valid(): |
| 921 | + self.chat_question.custom_prompt = find_custom_prompts(self.session, CustomPromptTypeEnum.GENERATE_SQL, |
| 922 | + oid, ds_id) |
905 | 923 |
|
906 | 924 | self.init_messages() |
907 | 925 |
|
|
0 commit comments