Skip to content

Commit 0fbeffe

Browse files
author
James Robinson
authored
Harden versus Athena dialect foibles. (#95)
1 parent 56f4296 commit 0fbeffe

File tree

2 files changed

+81
-16
lines changed

2 files changed

+81
-16
lines changed

noteable_magics/sql/meta_commands.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import time
77
from concurrent.futures import ThreadPoolExecutor, as_completed
88
from datetime import datetime
9-
from typing import Any, Dict, Iterable, List, Optional, Tuple
9+
from functools import wraps
10+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
1011
from uuid import UUID
1112

1213
import requests
@@ -780,14 +781,19 @@ def introspect_primary_key(
780781
"""
781782
primary_index_dict = inspector.get_pk_constraint(relation_name, schema_name)
782783

783-
# MySQL at least can have unnamed primary keys. The returned dict will have 'name' -> None.
784-
# Sigh.
785-
pkey_name = primary_index_dict.get('name') or '(unnamed primary key)'
784+
# Athena dialect returns ... an empty _list_ instead of a dict, contrary to what
785+
# https://docs.sqlalchemy.org/en/14/core/reflection.html#sqlalchemy.engine.reflection.Inspector.get_pk_constraint
786+
# specifies for the return result from inspector.get_pk_constraint().
787+
if isinstance(primary_index_dict, dict):
788+
# MySQL at least can have unnamed primary keys. The returned dict will have 'name' -> None.
789+
# Sigh.
790+
pkey_name = primary_index_dict.get('name') or '(unnamed primary key)'
786791

787-
if primary_index_dict['constrained_columns']:
788-
return pkey_name, primary_index_dict['constrained_columns']
789-
else:
790-
return None, []
792+
if primary_index_dict['constrained_columns']:
793+
return pkey_name, primary_index_dict['constrained_columns']
794+
795+
# No primary key to be returned.
796+
return None, []
791797

792798
def introspect_columns(
793799
self, inspector: SchemaStrippingInspector, schema_name: str, relation_name: str
@@ -1197,6 +1203,30 @@ def run_meta_command(
11971203
instance.do_run(invoker, args)
11981204

11991205

1206+
def handle_not_implemented(default: Any = None, default_factory: Callable[[], Any] = None):
1207+
"""Decorator to catch NotImplementedError, return either default constant or
1208+
whatever default_factory() returns."""
1209+
assert default or default_factory, 'must provide one of default or default_factory'
1210+
assert not (
1211+
default and default_factory
1212+
), 'only provide one of either default or default_factory'
1213+
1214+
def wrapper(func):
1215+
@wraps(func)
1216+
def wrapped(*args, **kwargs):
1217+
try:
1218+
return func(*args, **kwargs)
1219+
except NotImplementedError:
1220+
if default_factory:
1221+
return default_factory()
1222+
else:
1223+
return default
1224+
1225+
return wrapped
1226+
1227+
return wrapper
1228+
1229+
12001230
class SchemaStrippingInspector:
12011231
"""Proxy implementation that removes 'schema.' prefixing from results of underlying
12021232
get_table_names() and get_view_names(). BigQuery dialect inspector seems to include
@@ -1218,6 +1248,7 @@ def get_schema_names(self) -> List[str]:
12181248
def get_columns(self, relation_name: str, schema: Optional[str] = None) -> List[dict]:
12191249
return self.underlying_inspector.get_columns(relation_name, schema=schema)
12201250

1251+
@handle_not_implemented('(unobtainable)')
12211252
def get_view_definition(self, view_name: str, schema: Optional[str] = None) -> str:
12221253
return self.underlying_inspector.get_view_definition(view_name, schema=schema)
12231254

@@ -1227,20 +1258,16 @@ def get_pk_constraint(self, table_name: str, schema: Optional[str] = None) -> di
12271258
def get_foreign_keys(self, table_name: str, schema: Optional[str] = None) -> List[dict]:
12281259
return self.underlying_inspector.get_foreign_keys(table_name, schema=schema)
12291260

1261+
@handle_not_implemented(default_factory=list)
12301262
def get_check_constraints(self, table_name: str, schema: Optional[str] = None) -> List[dict]:
1231-
try:
1232-
return self.underlying_inspector.get_check_constraints(table_name, schema=schema)
1233-
except NotImplementedError:
1234-
return []
1263+
return self.underlying_inspector.get_check_constraints(table_name, schema=schema)
12351264

12361265
def get_indexes(self, table_name: str, schema: Optional[str] = None) -> List[dict]:
12371266
return self.underlying_inspector.get_indexes(table_name, schema=schema)
12381267

1268+
@handle_not_implemented(default_factory=list)
12391269
def get_unique_constraints(self, table_name: str, schema: Optional[str] = None) -> List[dict]:
1240-
try:
1241-
return self.underlying_inspector.get_unique_constraints(table_name, schema=schema)
1242-
except NotImplementedError:
1243-
return []
1270+
return self.underlying_inspector.get_unique_constraints(table_name, schema=schema)
12441271

12451272
# Now the value-adding filtering methods.
12461273
def get_table_names(self, schema: Optional[str] = None) -> List[str]:

tests/test_sql_magic_meta_commands.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
SchemaStrippingInspector,
1818
_all_command_classes,
1919
convert_relation_glob_to_regex,
20+
handle_not_implemented,
2021
parse_schema_and_relation_glob,
2122
)
2223
from tests.conftest import COCKROACH_HANDLE, COCKROACH_UUID, KNOWN_TABLES, KNOWN_TABLES_AND_KINDS
@@ -1032,3 +1033,40 @@ def test_convert_relation_glob_to_regex(
10321033
inp: str, imply_prefix, expected_result: Tuple[Optional[str], Optional[str]], mocker
10331034
):
10341035
assert convert_relation_glob_to_regex(inp, imply_prefix=imply_prefix) == expected_result
1036+
1037+
1038+
class TestHandleNotImplemented:
1039+
def test_returns_underlying_when_implemented(self):
1040+
@handle_not_implemented(default='no')
1041+
def func():
1042+
return 12
1043+
1044+
assert func() == 12
1045+
1046+
def test_returns_default_when_not_implemented(self):
1047+
@handle_not_implemented(default='no')
1048+
def func():
1049+
raise NotImplementedError
1050+
1051+
assert func() == 'no'
1052+
1053+
def test_returns_default_factory_when_not_implemented(self):
1054+
@handle_not_implemented(default_factory=list)
1055+
def func():
1056+
raise NotImplementedError
1057+
1058+
assert func() == []
1059+
1060+
def test_hates_both_default_and_default_factory(self):
1061+
with pytest.raises(AssertionError, match='only provide one'):
1062+
1063+
@handle_not_implemented(default='no', default_factory=list)
1064+
def func():
1065+
raise NotImplementedError
1066+
1067+
def test_requires_either_default_or_default_factory(self):
1068+
with pytest.raises(AssertionError, match='must provide one of'):
1069+
1070+
@handle_not_implemented()
1071+
def func():
1072+
raise NotImplementedError

0 commit comments

Comments
 (0)