6
6
import time
7
7
from concurrent .futures import ThreadPoolExecutor , as_completed
8
8
from datetime import datetime
9
- from typing import Any , Dict , Iterable , List , Optional , Tuple
9
+ from functools import wraps
10
+ from typing import Any , Callable , Dict , Iterable , List , Optional , Tuple
10
11
from uuid import UUID
11
12
12
13
import requests
@@ -780,14 +781,19 @@ def introspect_primary_key(
780
781
"""
781
782
primary_index_dict = inspector .get_pk_constraint (relation_name , schema_name )
782
783
783
- # MySQL at least can have unnamed primary keys. The returned dict will have 'name' -> None.
784
- # Sigh.
785
- pkey_name = primary_index_dict .get ('name' ) or '(unnamed primary key)'
784
+ # Athena dialect returns ... an empty _list_ instead of a dict, contrary to what
785
+ # https://docs.sqlalchemy.org/en/14/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_pk_constraint
786
+ # specifies for the return result from inspector.get_pk_constraint().
787
+ if isinstance (primary_index_dict , dict ):
788
+ # MySQL at least can have unnamed primary keys. The returned dict will have 'name' -> None.
789
+ # Sigh.
790
+ pkey_name = primary_index_dict .get ('name' ) or '(unnamed primary key)'
786
791
787
- if primary_index_dict ['constrained_columns' ]:
788
- return pkey_name , primary_index_dict ['constrained_columns' ]
789
- else :
790
- return None , []
792
+ if primary_index_dict ['constrained_columns' ]:
793
+ return pkey_name , primary_index_dict ['constrained_columns' ]
794
+
795
+ # No primary key to be returned.
796
+ return None , []
791
797
792
798
def introspect_columns (
793
799
self , inspector : SchemaStrippingInspector , schema_name : str , relation_name : str
@@ -1197,6 +1203,30 @@ def run_meta_command(
1197
1203
instance .do_run (invoker , args )
1198
1204
1199
1205
1206
+ def handle_not_implemented (default : Any = None , default_factory : Callable [[], Any ] = None ):
1207
+ """Decorator to catch NotImplementedError, return either default constant or
1208
+ whatever default_factory() returns."""
1209
+ assert default or default_factory , 'must provide one of default or default_factory'
1210
+ assert not (
1211
+ default and default_factory
1212
+ ), 'only provide one of either default or default_factory'
1213
+
1214
+ def wrapper (func ):
1215
+ @wraps (func )
1216
+ def wrapped (* args , ** kwargs ):
1217
+ try :
1218
+ return func (* args , ** kwargs )
1219
+ except NotImplementedError :
1220
+ if default_factory :
1221
+ return default_factory ()
1222
+ else :
1223
+ return default
1224
+
1225
+ return wrapped
1226
+
1227
+ return wrapper
1228
+
1229
+
1200
1230
class SchemaStrippingInspector :
1201
1231
"""Proxy implementation that removes 'schema.' prefixing from results of underlying
1202
1232
get_table_names() and get_view_names(). BigQuery dialect inspector seems to include
@@ -1218,6 +1248,7 @@ def get_schema_names(self) -> List[str]:
1218
1248
def get_columns (self , relation_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1219
1249
return self .underlying_inspector .get_columns (relation_name , schema = schema )
1220
1250
1251
+ @handle_not_implemented ('(unobtainable)' )
1221
1252
def get_view_definition (self , view_name : str , schema : Optional [str ] = None ) -> str :
1222
1253
return self .underlying_inspector .get_view_definition (view_name , schema = schema )
1223
1254
@@ -1227,20 +1258,16 @@ def get_pk_constraint(self, table_name: str, schema: Optional[str] = None) -> di
1227
1258
def get_foreign_keys (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1228
1259
return self .underlying_inspector .get_foreign_keys (table_name , schema = schema )
1229
1260
1261
+ @handle_not_implemented (default_factory = list )
1230
1262
def get_check_constraints (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1231
- try :
1232
- return self .underlying_inspector .get_check_constraints (table_name , schema = schema )
1233
- except NotImplementedError :
1234
- return []
1263
+ return self .underlying_inspector .get_check_constraints (table_name , schema = schema )
1235
1264
1236
1265
def get_indexes (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1237
1266
return self .underlying_inspector .get_indexes (table_name , schema = schema )
1238
1267
1268
+ @handle_not_implemented (default_factory = list )
1239
1269
def get_unique_constraints (self , table_name : str , schema : Optional [str ] = None ) -> List [dict ]:
1240
- try :
1241
- return self .underlying_inspector .get_unique_constraints (table_name , schema = schema )
1242
- except NotImplementedError :
1243
- return []
1270
+ return self .underlying_inspector .get_unique_constraints (table_name , schema = schema )
1244
1271
1245
1272
# Now the value-adding filtering methods.
1246
1273
def get_table_names (self , schema : Optional [str ] = None ) -> List [str ]:
0 commit comments