|
38 | 38 | from apps.system.schemas.system_schema import AssistantOutDsSchema |
39 | 39 | from common.core.config import settings |
40 | 40 | from common.core.deps import CurrentAssistant, CurrentUser |
41 | | -from common.utils.utils import SQLBotLogUtil, extract_nested_json |
| 41 | +from common.utils.utils import SQLBotLogUtil, extract_nested_json, prepare_for_orjson |
42 | 42 |
|
43 | 43 | warnings.filterwarnings("ignore") |
44 | 44 |
|
@@ -71,7 +71,7 @@ def __init__(self, current_user: CurrentUser, chat_question: ChatQuestion, |
71 | 71 | engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) |
72 | 72 | session_maker = sessionmaker(bind=engine) |
73 | 73 | self.session = session_maker() |
74 | | - |
| 74 | + self.session.exec = self.session.exec if hasattr(self.session, "exec") else self.session.execute |
75 | 75 | self.current_user = current_user |
76 | 76 | self.current_assistant = current_assistant |
77 | 77 | # chat = self.session.query(Chat).filter(Chat.id == chat_question.chat_id).first() |
@@ -365,7 +365,7 @@ def select_datasource(self): |
365 | 365 | datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = [] |
366 | 366 | datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question())) |
367 | 367 | if self.current_assistant: |
368 | | - _ds_list = get_assistant_ds(llm_service=self) |
| 368 | + _ds_list = get_assistant_ds(session=self.session, llm_service=self) |
369 | 369 | else: |
370 | 370 | oid: str = self.current_user.oid |
371 | 371 | stmt = select(CoreDatasource.id, CoreDatasource.name, CoreDatasource.description).where( |
@@ -516,6 +516,41 @@ def generate_sql(self): |
516 | 516 | [{'type': msg.type, 'content': msg.content} for msg in |
517 | 517 | self.sql_message]).decode()) |
518 | 518 |
|
| 519 | + def generate_with_sub_sql(self, sql, sub_mappings: list): |
| 520 | + sub_query = json.dumps(sub_mappings, ensure_ascii=False) |
| 521 | + self.chat_question.sql = sql |
| 522 | + self.chat_question.sub_query = sub_query |
| 523 | + msg: List[Union[BaseMessage, dict[str, Any]]] = [] |
| 524 | + msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question())) |
| 525 | + msg.append(HumanMessage(content=self.chat_question.dynamic_user_question())) |
| 526 | + full_thinking_text = '' |
| 527 | + full_dynamic_text = '' |
| 528 | + res = self.llm.stream(msg) |
| 529 | + token_usage = {} |
| 530 | + for chunk in res: |
| 531 | + SQLBotLogUtil.info(chunk) |
| 532 | + reasoning_content_chunk = '' |
| 533 | + if 'reasoning_content' in chunk.additional_kwargs: |
| 534 | + reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '') |
| 535 | + if reasoning_content_chunk is None: |
| 536 | + reasoning_content_chunk = '' |
| 537 | + full_thinking_text += reasoning_content_chunk |
| 538 | + full_dynamic_text += chunk.content |
| 539 | + get_token_usage(chunk, token_usage) |
| 540 | + |
| 541 | + SQLBotLogUtil.info(full_dynamic_text) |
| 542 | + return full_dynamic_text |
| 543 | + |
| 544 | + def generate_assistant_dynamic_sql(self, sql, tables: List): |
| 545 | + ds: AssistantOutDsSchema = self.ds |
| 546 | + sub_query = [] |
| 547 | + for table in ds.tables: |
| 548 | + if table.name in tables and table.sql: |
| 549 | + sub_query.append({"table": table.name, "query": table.sql}) |
| 550 | + if not sub_query: |
| 551 | + return None |
| 552 | + return self.generate_with_sub_sql(sql=sql, sub_mappings=sub_query) |
| 553 | + |
519 | 554 | def build_table_filter(self, sql: str, filters: list): |
520 | 555 | filter = json.dumps(filters, ensure_ascii=False) |
521 | 556 | self.chat_question.sql = sql |
@@ -635,27 +670,23 @@ def generate_chart(self): |
635 | 670 | full_message=orjson.dumps( |
636 | 671 | [{'type': msg.type, 'content': msg.content} for msg in |
637 | 672 | self.chart_message]).decode()) |
638 | | - |
639 | | - def check_save_sql(self, res: str) -> str: |
640 | | - |
| 673 | + def check_sql(self, res: str) -> tuple[any]: |
641 | 674 | json_str = extract_nested_json(res) |
642 | | - data = orjson.loads(json_str) |
643 | | - |
| 675 | + data: dict = orjson.loads(json_str) |
644 | 676 | sql = '' |
645 | 677 | message = '' |
646 | | - error = False |
647 | | - |
648 | 678 | if data['success']: |
649 | 679 | sql = data['sql'] |
650 | 680 | else: |
651 | 681 | message = data['message'] |
652 | | - error = True |
653 | | - |
654 | | - if error: |
655 | 682 | raise Exception(message) |
| 683 | + |
656 | 684 | if sql.strip() == '': |
657 | 685 | raise Exception("SQL query is empty") |
658 | | - |
| 686 | + return sql, data.get('tables') |
| 687 | + |
| 688 | + def check_save_sql(self, res: str) -> str: |
| 689 | + sql, *_ = self.check_sql(res=res) |
659 | 690 | save_sql(session=self.session, sql=sql, record_id=self.record.id) |
660 | 691 |
|
661 | 692 | self.chat_question.sql = sql |
@@ -716,6 +747,10 @@ def save_error(self, message: str): |
716 | 747 | return save_error_message(session=self.session, record_id=self.record.id, message=message) |
717 | 748 |
|
718 | 749 | def save_sql_data(self, data_obj: Dict[str, Any]): |
| 750 | + data_result = data_obj.get('data') |
| 751 | + if data_result: |
| 752 | + data_result = prepare_for_orjson(data_result) |
| 753 | + data_obj['data'] = data_result |
719 | 754 | return save_sql_exec_data(session=self.session, record_id=self.record.id, |
720 | 755 | data=orjson.dumps(data_obj).decode()) |
721 | 756 |
|
@@ -812,34 +847,27 @@ def run_task(self, in_chat: bool = True): |
812 | 847 |
|
813 | 848 | # filter sql |
814 | 849 | SQLBotLogUtil.info(full_sql_text) |
| 850 | + use_dynamic_ds: bool = self.current_assistant and self.current_assistant.type == 1 |
815 | 851 |
|
816 | 852 | # todo row permission |
817 | | - if (not self.current_assistant and is_normal_user(self.current_user)) or ( |
818 | | - self.current_assistant and self.current_assistant.type == 1): |
819 | | - sql_json_str = extract_nested_json(full_sql_text) |
820 | | - data = orjson.loads(sql_json_str) |
821 | | - |
822 | | - sql = '' |
823 | | - message = '' |
824 | | - error = False |
825 | | - if data['success']: |
826 | | - sql = data['sql'] |
827 | | - else: |
828 | | - message = data['message'] |
829 | | - error = True |
830 | | - if error: |
831 | | - raise Exception(message) |
832 | | - if sql.strip() == '': |
833 | | - raise Exception("SQL query is empty") |
| 853 | + if (not self.current_assistant and is_normal_user(self.current_user)) or use_dynamic_ds: |
| 854 | + sql, tables = self.check_sql(res=full_sql_text) |
834 | 855 |
|
835 | 856 | if self.current_assistant: |
836 | | - sql_result = self.generate_assistant_filter(data.get('sql'), data.get('tables')) |
| 857 | + dynamic_sql_result = self.generate_assistant_dynamic_sql(sql, tables) |
| 858 | + if dynamic_sql_result: |
| 859 | + SQLBotLogUtil.info(dynamic_sql_result) |
| 860 | + sql, *_ = self.check_sql(res=dynamic_sql_result) |
| 861 | + |
| 862 | + sql_result = self.generate_assistant_filter(sql, tables) |
837 | 863 | else: |
838 | | - sql_result = self.generate_filter(data.get('sql'), data.get('tables')) # maybe no sql and tables |
| 864 | + sql_result = self.generate_filter(sql, tables) # maybe no sql and tables |
839 | 865 |
|
840 | 866 | if sql_result: |
841 | 867 | SQLBotLogUtil.info(sql_result) |
842 | 868 | sql = self.check_save_sql(res=sql_result) |
| 869 | + elif dynamic_sql_result: |
| 870 | + sql = self.check_save_sql(res=dynamic_sql_result) |
843 | 871 | else: |
844 | 872 | sql = self.check_save_sql(res=full_sql_text) |
845 | 873 | else: |
|
0 commit comments