@@ -33,7 +33,6 @@ def get_rowcount(self):
33
33
class TrinoDialect (DefaultDialect ):
34
34
name = 'trino'
35
35
driver = 'rest'
36
- paramstyle = 'pyformat' # trino.dbapi.paramstyle
37
36
38
37
execution_ctx_cls = TrinoExecutionContext
39
38
statement_compiler = compiler .TrinoSQLCompiler
@@ -47,7 +46,7 @@ class TrinoDialect(DefaultDialect):
47
46
supports_native_decimal = True
48
47
49
48
# Column options
50
- supports_sequences = False # TODO: check
49
+ supports_sequences = False
51
50
supports_comments = True
52
51
inline_comments = True
53
52
supports_default_values = False
@@ -56,7 +55,7 @@ class TrinoDialect(DefaultDialect):
56
55
supports_alter = True
57
56
58
57
# DML
59
- supports_empty_insert = False # TODO: check
58
+ supports_empty_insert = False
60
59
supports_multivalues_insert = True
61
60
62
61
@classmethod
@@ -91,13 +90,13 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
91
90
def get_columns (self , connection : Connection ,
92
91
table_name : str , schema : str = None , ** kw ) -> List [Dict [str , Any ]]:
93
92
if not self .has_table (connection , table_name , schema ):
94
- raise exc .NoSuchTableError (f" schema={ schema } , table={ table_name } " )
93
+ raise exc .NoSuchTableError (f' schema={ schema } , table={ table_name } ' )
95
94
return self ._get_columns (connection , table_name , schema , ** kw )
96
95
97
96
def _get_columns (self , connection : Connection ,
98
97
table_name : str , schema : str = None , ** kw ) -> List [Dict [str , Any ]]:
99
98
schema = schema or self ._get_default_schema_name (connection )
100
- query = dedent ("""
99
+ query = dedent ('''
101
100
SELECT
102
101
"column_name",
103
102
"column_default",
@@ -106,7 +105,7 @@ def _get_columns(self, connection: Connection,
106
105
FROM "information_schema"."columns"
107
106
WHERE "table_schema" = :schema AND "table_name" = :table
108
107
ORDER BY "ordinal_position" ASC
109
- """ ).strip ()
108
+ ''' ).strip ()
110
109
res = connection .execute (sql .text (query ), schema = schema , table = table_name )
111
110
columns = []
112
111
for record in res :
@@ -135,14 +134,14 @@ def get_foreign_keys(self, connection: Connection,
135
134
return []
136
135
137
136
def get_schema_names (self , connection : Connection , ** kw ) -> List [str ]:
138
- query = " SHOW SCHEMAS"
137
+ query = ' SHOW SCHEMAS'
139
138
res = connection .execute (sql .text (query ))
140
139
return [row .Schema for row in res ]
141
140
142
141
def get_table_names (self , connection : Connection , schema : str = None , ** kw ) -> List [str ]:
143
- query = " SHOW TABLES"
142
+ query = ' SHOW TABLES'
144
143
if schema :
145
- query = f" { query } FROM { self .identifier_preparer .quote_identifier (schema )} "
144
+ query = f' { query } FROM { self .identifier_preparer .quote_identifier (schema )} '
146
145
res = connection .execute (sql .text (query ))
147
146
return [row .Table for row in res ]
148
147
@@ -153,12 +152,12 @@ def get_temp_table_names(self, connection: Connection, schema: str = None, **kw)
153
152
def get_view_names (self , connection : Connection , schema : str = None , ** kw ) -> List [str ]:
154
153
schema = schema or self ._get_default_schema_name (connection )
155
154
if schema is None :
156
- raise exc .NoSuchTableError (" schema is required" )
157
- query = dedent ("""
155
+ raise exc .NoSuchTableError (' schema is required' )
156
+ query = dedent ('''
158
157
SELECT "table_name"
159
158
FROM "information_schema"."views"
160
159
WHERE "table_schema" = :schema
161
- """ ).strip ()
160
+ ''' ).strip ()
162
161
res = connection .execute (sql .text (query ), schema = schema )
163
162
return [row .table_name for row in res ]
164
163
@@ -168,7 +167,7 @@ def get_temp_view_names(self, connection: Connection, schema: str = None, **kw)
168
167
169
168
def get_view_definition (self , connection : Connection , view_name : str , schema : str = None , ** kw ) -> str :
170
169
full_view = self ._get_full_table (view_name , schema )
171
- query = f" SHOW CREATE VIEW { full_view } "
170
+ query = f' SHOW CREATE VIEW { full_view } '
172
171
try :
173
172
res = connection .execute (sql .text (query ))
174
173
return res .first ()[0 ]
@@ -184,11 +183,11 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
184
183
def get_indexes (self , connection : Connection ,
185
184
table_name : str , schema : str = None , ** kw ) -> List [Dict [str , Any ]]:
186
185
if not self .has_table (connection , table_name , schema ):
187
- raise exc .NoSuchTableError (f" schema={ schema } , table={ table_name } " )
186
+ raise exc .NoSuchTableError (f' schema={ schema } , table={ table_name } ' )
188
187
189
- partitioned_columns = self ._get_columns (connection , f" { table_name } $partitions" , schema , ** kw )
188
+ partitioned_columns = self ._get_columns (connection , f' { table_name } $partitions' , schema , ** kw )
190
189
partition_index = dict (
191
- name = " partition" ,
190
+ name = ' partition' ,
192
191
column_names = [col ['name' ] for col in partitioned_columns ],
193
192
unique = False
194
193
)
@@ -224,9 +223,9 @@ def has_schema(self, connection: Connection, schema: str) -> bool:
224
223
225
224
def has_table (self , connection : Connection ,
226
225
table_name : str , schema : str = None ) -> bool :
227
- query = " SHOW TABLES"
226
+ query = ' SHOW TABLES'
228
227
if schema :
229
- query = f" { query } FROM { self .identifier_preparer .quote_identifier (schema )} "
228
+ query = f' { query } FROM { self .identifier_preparer .quote_identifier (schema )} '
230
229
query = f"{ query } LIKE '{ table_name } '"
231
230
try :
232
231
res = connection .execute (sql .text (query ))
@@ -247,11 +246,11 @@ def has_sequence(self, connection: Connection,
247
246
return False
248
247
249
248
def _get_server_version_info (self , connection : Connection ) -> Tuple [int , ...]:
250
- query = dedent ("""
249
+ query = dedent ('''
251
250
SELECT *
252
251
FROM system.runtime.nodes
253
252
WHERE coordinator = true AND state = 'active'
254
- """ ).strip ()
253
+ ''' ).strip ()
255
254
res = connection .execute (sql .text (query )).first ()
256
255
version = int (res .node_version )
257
256
return tuple ([version ])
@@ -285,11 +284,11 @@ def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level) -> None
285
284
dbapi_conn ._isolation_level = getattr (trino_dbapi .IsolationLevel , level )
286
285
287
286
def get_isolation_level (self , dbapi_conn : trino_dbapi .Connection ) -> str :
288
- level_names = [" AUTOCOMMIT" ,
289
- " READ_UNCOMMITTED" ,
290
- " READ_COMMITTED" ,
291
- " REPEATABLE_READ" ,
292
- " SERIALIZABLE" ]
287
+ level_names = [' AUTOCOMMIT' ,
288
+ ' READ_UNCOMMITTED' ,
289
+ ' READ_COMMITTED' ,
290
+ ' REPEATABLE_READ' ,
291
+ ' SERIALIZABLE' ]
293
292
return level_names [dbapi_conn .isolation_level ]
294
293
295
294
def _get_full_table (self , table_name : str , schema : str = None , quote : bool = True ) -> str :
0 commit comments