|
2 | 2 | import json |
3 | 3 | import urllib.parse |
4 | 4 | from decimal import Decimal |
5 | | -from typing import Any |
6 | 5 |
|
7 | | -from sqlalchemy import create_engine, text, Result, Engine |
| 6 | +from sqlalchemy import create_engine, text, Engine |
8 | 7 | from sqlalchemy.orm import sessionmaker |
9 | 8 |
|
10 | 9 | from apps.datasource.models.datasource import DatasourceConf, CoreDatasource, TableSchema, ColumnSchema |
@@ -82,10 +81,8 @@ def get_session(ds: CoreDatasource | AssistantOutDsSchema): |
82 | 81 |
|
83 | 82 | def get_tables(ds: CoreDatasource): |
84 | 83 | conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() |
85 | | - session = get_session(ds) |
86 | | - result: Result[Any] |
87 | | - sql: str = '' |
88 | | - try: |
| 84 | + with get_session(ds) as session: |
| 85 | + sql: str = '' |
89 | 86 | if ds.type == "mysql": |
90 | 87 | sql = f""" |
91 | 88 | SELECT |
@@ -147,24 +144,16 @@ def get_tables(ds: CoreDatasource): |
147 | 144 | AND c.OWNER = '{conf.dbSchema}' |
148 | 145 | ORDER BY t.TABLE_NAME |
149 | 146 | """ |
150 | | - |
151 | | - result = session.execute(text(sql)) |
152 | | - res = result.fetchall() |
153 | | - res_list = [TableSchema(*item) for item in res] |
154 | | - return res_list |
155 | | - finally: |
156 | | - if result is not None: |
157 | | - result.close() |
158 | | - if session is not None: |
159 | | - session.close() |
| 147 | + with session.execute(text(sql)) as result: |
| 148 | + res = result.fetchall() |
| 149 | + res_list = [TableSchema(*item) for item in res] |
| 150 | + return res_list |
160 | 151 |
|
161 | 152 |
|
162 | 153 | def get_fields(ds: CoreDatasource, table_name: str = None): |
163 | 154 | conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if ds.type != "excel" else get_engine_config() |
164 | | - session = get_session(ds) |
165 | | - result: Result[Any] |
166 | | - sql: str = '' |
167 | | - try: |
| 155 | + with get_session(ds) as session: |
| 156 | + sql: str = '' |
168 | 157 | if ds.type == "mysql": |
169 | 158 | sql1 = f""" |
170 | 159 | SELECT |
@@ -238,15 +227,10 @@ def get_fields(ds: CoreDatasource, table_name: str = None): |
238 | 227 | sql2 = f" AND col.TABLE_NAME = '{table_name}'" if table_name is not None and table_name != "" else "" |
239 | 228 | sql = sql1 + sql2 |
240 | 229 |
|
241 | | - result = session.execute(text(sql)) |
242 | | - res = result.fetchall() |
243 | | - res_list = [ColumnSchema(*item) for item in res] |
244 | | - return res_list |
245 | | - finally: |
246 | | - if result is not None: |
247 | | - result.close() |
248 | | - if session is not None: |
249 | | - session.close() |
| 230 | + with session.execute(text(sql)) as result: |
| 231 | + res = result.fetchall() |
| 232 | + res_list = [ColumnSchema(*item) for item in res] |
| 233 | + return res_list |
250 | 234 |
|
251 | 235 |
|
252 | 236 | def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str): |
|
0 commit comments