|
15 | 15 | from langchain_community.utilities import SQLDatabase |
16 | 16 | from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk |
17 | 17 | from sqlalchemy import select |
18 | | -from sqlalchemy.exc import DBAPIError |
19 | 18 | from sqlalchemy.orm import sessionmaker |
20 | 19 | from sqlmodel import create_engine, Session |
21 | 20 |
|
|
30 | 29 | from apps.datasource.crud.datasource import get_table_schema |
31 | 30 | from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user |
32 | 31 | from apps.datasource.models.datasource import CoreDatasource |
33 | | -from apps.db.db import exec_sql, get_version |
| 32 | +from apps.db.db import exec_sql, get_version, check_connection |
34 | 33 | from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds |
35 | 34 | from apps.system.schemas.system_schema import AssistantOutDsSchema |
36 | 35 | from apps.terminology.curd.terminology import get_terminology_template |
37 | 36 | from common.core.config import settings |
38 | 37 | from common.core.deps import CurrentAssistant, CurrentUser |
39 | | -from common.error import SingleMessageError |
| 38 | +from common.error import SingleMessageError, SQLBotDBError, ParseSQLResultError, SQLBotDBConnectionError |
40 | 39 | from common.utils.utils import SQLBotLogUtil, extract_nested_json, prepare_for_orjson |
41 | 40 |
|
42 | 41 | warnings.filterwarnings("ignore") |
@@ -868,7 +867,14 @@ def execute_sql(self, sql: str): |
868 | 867 | Query results |
869 | 868 | """ |
870 | 869 | SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}") |
871 | | - return exec_sql(self.ds, sql) |
| 870 | + try: |
| 871 | + return exec_sql(self.ds, sql) |
| 872 | + except Exception as e: |
| 873 | + if isinstance(e, ParseSQLResultError): |
| 874 | + raise e |
| 875 | + else: |
| 876 | + err = traceback.format_exc(limit=1, chain=True) |
| 877 | + raise SQLBotDBError(err) |
872 | 878 |
|
873 | 879 | def pop_chunk(self): |
874 | 880 | try: |
@@ -940,6 +946,12 @@ def run_task(self, in_chat: bool = True): |
940 | 946 | ds=self.ds) |
941 | 947 | else: |
942 | 948 | self.validate_history_ds() |
| 949 | + |
| 950 | + # check connection |
| 951 | + connected = check_connection(ds=self.ds, trans=None) |
| 952 | + if not connected: |
| 953 | + raise SQLBotDBConnectionError('Connect DB failed') |
| 954 | + |
943 | 955 | # generate sql |
944 | 956 | sql_res = self.generate_sql() |
945 | 957 | full_sql_text = '' |
@@ -1059,13 +1071,12 @@ def run_task(self, in_chat: bool = True): |
1059 | 1071 | error_msg: str |
1060 | 1072 | if isinstance(e, SingleMessageError): |
1061 | 1073 | error_msg = str(e) |
1062 | | - elif isinstance(e, ConnectionError): |
| 1074 | + elif isinstance(e, SQLBotDBConnectionError): |
1063 | 1075 | error_msg = orjson.dumps( |
1064 | | - {'message': str(e), 'traceback': traceback.format_exc(limit=1), |
1065 | | - 'type': 'db-connection-err'}).decode() |
1066 | | - elif isinstance(e, DBAPIError): |
| 1076 | + {'message': str(e), 'type': 'db-connection-err'}).decode() |
| 1077 | + elif isinstance(e, SQLBotDBError): |
1067 | 1078 | error_msg = orjson.dumps( |
1068 | | - {'message': str(e), 'traceback': traceback.format_exc(limit=1), 'type': 'exec-sql-err'}).decode() |
| 1079 | + {'message': 'Execute SQL Failed', 'traceback': str(e), 'type': 'exec-sql-err'}).decode() |
1069 | 1080 | else: |
1070 | 1081 | error_msg = orjson.dumps({'message': str(e), 'traceback': traceback.format_exc(limit=1)}).decode() |
1071 | 1082 | self.save_error(message=error_msg) |
|
0 commit comments