Skip to content
This repository was archived by the owner on May 5, 2022. It is now read-only.

Commit 111e02c

Browse files
committed
refactor: re-implement get_columns function by get info from information_schema.columns table
1 parent 565ef0f commit 111e02c

File tree

2 files changed

+46
-24
lines changed

2 files changed

+46
-24
lines changed

sqlalchemy_trino/dialect.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,30 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
9090

9191
def get_columns(self, connection: Connection,
9292
table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
93-
full_table = self._get_full_table(table_name, schema)
94-
try:
95-
rows = self._get_table_columns(connection, full_table)
96-
columns = []
97-
for row in rows:
98-
columns.append(dict(
99-
name=row.Column,
100-
type=datatype.parse_sqltype(row.Type),
101-
nullable=getattr(row, 'Null', True),
102-
default=None,
103-
))
104-
return columns
105-
except error.TrinoQueryError as e:
106-
if e.error_name in (error.TABLE_NOT_FOUND, error.SCHEMA_NOT_FOUND, error.CATALOG_NOT_FOUND):
107-
raise exc.NoSuchTableError(full_table) from e
108-
raise
93+
if not self.has_table(connection, table_name, schema):
94+
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
95+
schema = schema or self._get_default_schema_name(connection)
96+
query = dedent("""
97+
SELECT
98+
"column_name",
99+
"column_default",
100+
"is_nullable",
101+
"data_type"
102+
FROM "information_schema"."columns"
103+
WHERE "table_schema" = :schema AND "table_name" = :table
104+
ORDER BY "ordinal_position" ASC
105+
""").strip()
106+
res = connection.execute(sql.text(query), schema=schema, table=table_name)
107+
columns = []
108+
for record in res:
109+
column = dict(
110+
name=record.column_name,
111+
type=datatype.parse_sqltype(record.data_type),
112+
nullable=(record.is_nullable or '').upper() == 'YES',
113+
default=record.column_default,
114+
)
115+
columns.append(column)
116+
return columns
109117

110118
def get_pk_constraint(self, connection: Connection,
111119
table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
@@ -158,7 +166,11 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
158166
res = connection.execute(sql.text(query))
159167
return res.first()[0]
160168
except error.TrinoQueryError as e:
161-
if e.error_name in (error.TABLE_NOT_FOUND, error.SCHEMA_NOT_FOUND, error.CATALOG_NOT_FOUND):
169+
if e.error_name in (
170+
error.TABLE_NOT_FOUND,
171+
error.SCHEMA_NOT_FOUND,
172+
error.CATALOG_NOT_FOUND,
173+
):
162174
raise exc.NoSuchTableError(full_view) from e
163175
raise
164176

@@ -186,7 +198,11 @@ def has_schema(self, connection: Connection, schema: str) -> bool:
186198
res = connection.execute(sql.text(query))
187199
return res.first() is not None
188200
except error.TrinoQueryError as e:
189-
if e.error_name in (error.TABLE_NOT_FOUND, error.SCHEMA_NOT_FOUND, error.CATALOG_NOT_FOUND):
201+
if e.error_name in (
202+
error.TABLE_NOT_FOUND,
203+
error.SCHEMA_NOT_FOUND,
204+
error.CATALOG_NOT_FOUND,
205+
):
190206
return False
191207
raise
192208

@@ -200,7 +216,12 @@ def has_table(self, connection: Connection,
200216
res = connection.execute(sql.text(query))
201217
return res.first() is not None
202218
except error.TrinoQueryError as e:
203-
if e.error_name in (error.TABLE_NOT_FOUND, error.SCHEMA_NOT_FOUND, error.CATALOG_NOT_FOUND):
219+
if e.error_name in (
220+
error.TABLE_NOT_FOUND,
221+
error.SCHEMA_NOT_FOUND,
222+
error.CATALOG_NOT_FOUND,
223+
error.MISSING_SCHEMA_NAME,
224+
):
204225
return False
205226
raise
206227

@@ -255,11 +276,6 @@ def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str:
255276
"SERIALIZABLE"]
256277
return level_names[dbapi_conn.isolation_level]
257278

258-
@staticmethod
259-
def _get_table_columns(connection: Connection, full_table: str):
260-
stmt = sql.text(f'SHOW COLUMNS FROM {full_table}')
261-
return connection.execute(stmt)
262-
263279
def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str:
264280
table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name
265281
if schema:

sqlalchemy_trino/error.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
TrinoQueryError
33
)
44

5+
# ref: https://github.com/trinodb/trino/blob/master/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java
56
TABLE_NOT_FOUND = 'TABLE_NOT_FOUND'
67
SCHEMA_NOT_FOUND = 'SCHEMA_NOT_FOUND'
78
CATALOG_NOT_FOUND = 'CATALOG_NOT_FOUND'
9+
10+
MISSING_TABLE = 'MISSING_TABLE'
11+
MISSING_COLUMN_NAME = 'MISSING_COLUMN_NAME'
12+
MISSING_SCHEMA_NAME = 'MISSING_SCHEMA_NAME'
13+
MISSING_CATALOG_NAME = 'MISSING_CATALOG_NAME'

0 commit comments

Comments
 (0)