Skip to content

Commit 4e3d34a

Browse files
committed
fix: Fix errors that may be caused by variable exceptions
1 parent 312bacb commit 4e3d34a

File tree

1 file changed

+13
-29
lines changed

1 file changed

+13
-29
lines changed

backend/apps/db/db.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
import json
33
import urllib.parse
44
from decimal import Decimal
5-
from typing import Any
65

7-
from sqlalchemy import create_engine, text, Result, Engine
6+
from sqlalchemy import create_engine, text, Engine
87
from sqlalchemy.orm import sessionmaker
98

109
from apps.datasource.models.datasource import DatasourceConf, CoreDatasource, TableSchema, ColumnSchema
@@ -82,10 +81,8 @@ def get_session(ds: CoreDatasource | AssistantOutDsSchema):
8281

8382
def get_tables(ds: CoreDatasource):
8483
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
85-
session = get_session(ds)
86-
result: Result[Any]
87-
sql: str = ''
88-
try:
84+
with get_session(ds) as session:
85+
sql: str = ''
8986
if ds.type == "mysql":
9087
sql = f"""
9188
SELECT
@@ -147,24 +144,16 @@ def get_tables(ds: CoreDatasource):
147144
AND c.OWNER = '{conf.dbSchema}'
148145
ORDER BY t.TABLE_NAME
149146
"""
150-
151-
result = session.execute(text(sql))
152-
res = result.fetchall()
153-
res_list = [TableSchema(*item) for item in res]
154-
return res_list
155-
finally:
156-
if result is not None:
157-
result.close()
158-
if session is not None:
159-
session.close()
147+
with session.execute(text(sql)) as result:
148+
res = result.fetchall()
149+
res_list = [TableSchema(*item) for item in res]
150+
return res_list
160151

161152

162153
def get_fields(ds: CoreDatasource, table_name: str = None):
163154
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config()
164-
session = get_session(ds)
165-
result: Result[Any]
166-
sql: str = ''
167-
try:
155+
with get_session(ds) as session:
156+
sql: str = ''
168157
if ds.type == "mysql":
169158
sql1 = f"""
170159
SELECT
@@ -238,15 +227,10 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
238227
sql2 = f" AND col.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else ""
239228
sql = sql1 + sql2
240229

241-
result = session.execute(text(sql))
242-
res = result.fetchall()
243-
res_list = [ColumnSchema(*item) for item in res]
244-
return res_list
245-
finally:
246-
if result is not None:
247-
result.close()
248-
if session is not None:
249-
session.close()
230+
with session.execute(text(sql)) as result:
231+
res = result.fetchall()
232+
res_list = [ColumnSchema(*item) for item in res]
233+
return res_list
250234

251235

252236
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str):

0 commit comments

Comments
 (0)