Skip to content

Commit 0db7044

Browse files
author
James Robinson
authored
Refactor and make getting at SQL connections easy from python (#102)
Refactor, expose "noteable.sql.get_sqla_connection()" and "noteable.sql.get_sqla_engine()"
1 parent f56bf2c commit 0db7044

File tree

10 files changed

+127
-135
lines changed

10 files changed

+127
-135
lines changed

noteable/__init__.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,22 @@
22

33
__version__ = pkg_resources.get_distribution("noteable").version
44

5-
from .data_loader import LOCAL_DB_CONN_HANDLE, NoteableDataLoaderMagic, get_db_connection
5+
from .data_loader import NoteableDataLoaderMagic
66
from .datasources import bootstrap_datasources
77
from .logging import configure_logging
88
from .ntbl import NTBLMagic
9+
from .sql.connection import bootstrap_duckdb
910
from .sql.magic import SqlMagic
10-
from .sql.run import add_commit_blacklist_dialect
1111

1212

1313
def load_ipython_extension(ipython):
14-
# Initialize any remote datasource connections
15-
bootstrap_datasources()
16-
17-
# Always prevent sql-magic from trying to autocommit bigquery,
18-
# for the legacy datasource support for Expel and whomever.
19-
add_commit_blacklist_dialect('bigquery')
14+
configure_logging(False, "INFO", "DEBUG")
2015

21-
# Initialize the noteable local (duck_db) database connection
22-
get_db_connection(LOCAL_DB_CONN_HANDLE)
16+
# Initialize any remote datasource connections.
17+
bootstrap_datasources()
2318

24-
configure_logging(False, "INFO", "DEBUG")
19+
# Initialize the noteable local (duck_db) database connection.
20+
bootstrap_duckdb()
2521

22+
# Register all of our magics.
2623
ipython.register_magics(NoteableDataLoaderMagic, NTBLMagic, SqlMagic)

noteable/data_loader.py

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import mimetypes
2-
from typing import Optional
32

43
import pandas as pd
54
from IPython.core.magic import Magics, line_cell_magic, magics_class
@@ -8,46 +7,12 @@
87
from traitlets import Bool, Int
98
from traitlets.config import Configurable
109

11-
from noteable.sql.connection import Connection
10+
from noteable.sql.connection import LOCAL_DB_CONN_HANDLE, get_db_connection
1211

1312
EXCEL_MIMETYPES = {
1413
"application/vnd.ms-excel", # .xls
1514
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", # .xlsx
1615
}
17-
LOCAL_DB_CONN_HANDLE = "@noteable"
18-
duckdb_location = "duckdb:///:memory:"
19-
20-
21-
def get_db_connection(sql_cell_handle_or_human_name: str) -> Optional['Connection']:
22-
"""Return the sql.connection.Connection corresponding to the requested
23-
datasource sql_cell_handle or human name.
24-
25-
If the cell handle happens to correspond to the 'local database' DuckDB database,
26-
then we will bootstrap it upon demand. Otherwise, try to find and return
27-
the connection.
28-
29-
Will return None if the given handle isn't @noteable and isn't present in
30-
the connections dict already (created after this kernel was launched?)
31-
"""
32-
if (
33-
sql_cell_handle_or_human_name == LOCAL_DB_CONN_HANDLE
34-
and sql_cell_handle_or_human_name not in Connection.connections
35-
):
36-
# Bootstrap the DuckDB database if asked and needed.
37-
return Connection.set(
38-
duckdb_location,
39-
human_name="Local Database",
40-
name=LOCAL_DB_CONN_HANDLE,
41-
)
42-
else:
43-
# If, say, they created the datasource *after* this kernel was launched, then
44-
# this will come up empty and the caller should handle gracefully.
45-
for conn in Connection.connections.values():
46-
if (
47-
conn.name == sql_cell_handle_or_human_name
48-
or conn.human_name == sql_cell_handle_or_human_name
49-
):
50-
return conn
5116

5217

5318
@magics_class
@@ -56,7 +21,6 @@ class NoteableDataLoaderMagic(Magics, Configurable):
5621
True, config=True, help="Return the first N rows from the loaded pandas dataframe"
5722
)
5823
display_example = Bool(True, config=True, help="Show example SQL query")
59-
display_connection_str = Bool(False, config=True, help="Show connection string after execute")
6024
pandas_limit = Int(10, config=True, help="The limit of rows to returns in the pandas dataframe")
6125

6226
@line_cell_magic("create_or_replace_data_view")
@@ -116,19 +80,13 @@ def execute(self, line="", cell=""):
11680
f"Could not find datasource identified by {args.connection!r}. Perhaps restart the kernel?"
11781
)
11882

119-
tmp_df.to_sql(tablename, conn.session, if_exists="replace", index=args.include_index)
120-
121-
if self.display_connection_str:
122-
print(f"Connect with: %sql {conn.name}")
83+
tmp_df.to_sql(
84+
tablename, conn.sqla_connection, if_exists="replace", index=args.include_index
85+
)
12386

12487
if self.display_example:
125-
if conn.human_name:
126-
noun = f'{conn.human_name!r}'
127-
else:
128-
# Hmm. "Legacy" created datasource. Err on the engine's dialect name?
129-
noun = conn._engine.dialect.name
13088
print(
131-
f"""Create a {noun} SQL cell and then input query. """
89+
f"""Create a {conn.human_name!r} SQL cell and then input query. """
13290
f"Example: 'SELECT * FROM \"{tablename}\" LIMIT 10'"
13391
)
13492

noteable/sql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .magic import *
1+
from .connection import get_sqla_connection, get_sqla_engine # noqa

noteable/sql/connection.py

Lines changed: 84 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,16 @@
88
logger = structlog.get_logger(__name__)
99

1010

11-
class UnknownConnectionError(Exception):
12-
pass
13-
14-
15-
def rough_dict_get(dct, sought, default=None):
16-
"""
17-
Like dct.get(sought), but any key containing sought will do.
11+
LOCAL_DB_CONN_HANDLE = "@noteable"
12+
LOCAL_DB_CONN_NAME = "Local Database"
13+
DUCKDB_LOCATION = "duckdb:///:memory:"
1814

19-
If there is a `@` in sought, seek each piece separately.
20-
This lets `me@server` match `me:***@myserver/db`
21-
"""
2215

23-
sought = sought.split("@")
24-
for key, val in dct.items():
25-
if not any(s.lower() not in key.lower() for s in sought):
26-
return val
27-
return default
16+
class UnknownConnectionError(Exception):
17+
pass
2818

2919

30-
class Connection(object):
20+
class Connection:
3121
current = None
3222
connections: Dict[str, 'Connection'] = {}
3323
bootstrapping_failures: Dict[str, str] = {}
@@ -107,23 +97,30 @@ def __init__(self, connect_str=None, name=None, human_name=None, **create_engine
10797
self.metadata = sqlalchemy.MetaData(bind=self._engine)
10898
self.name = name or self.assign_name(self._engine)
10999
self.human_name = human_name
110-
self._session = None
100+
self._sqla_connection = None
111101
self.connections[name or repr(self.metadata.bind.url)] = self
112102

113103
Connection.current = self
114104

115105
@property
116-
def session(self) -> sqlalchemy.engine.base.Connection:
117-
"""Lazily connect to the database.
106+
def engine(self) -> sqlalchemy.engine.base.Engine:
107+
return self._engine
118108

119-
Despite the name, this is a SQLA Connection, not a Session. And 'Connection'
120-
is highly overused term around here.
121-
"""
109+
@property
110+
def sqla_connection(self) -> sqlalchemy.engine.base.Connection:
111+
"""Lazily connect to the database. Return a SQLA Connection object, or die trying."""
112+
113+
if not self._sqla_connection:
114+
self._sqla_connection = self._engine.connect()
122115

123-
if not self._session:
124-
self._session = self._engine.connect()
116+
return self._sqla_connection
125117

126-
return self._session
118+
def reset_connection_pool(self):
119+
"""Reset the SQLA connection pool, such as after an exception suspected to indicate
120+
a broken connection has been raised.
121+
"""
122+
self._engine.dispose()
123+
self._sqla_connection = None
127124

128125
@classmethod
129126
def set(
@@ -164,14 +161,23 @@ def connection_list(cls):
164161
result.append(template.format(engine_url.__repr__()))
165162
return "\n".join(result)
166163

164+
@classmethod
165+
def find(cls, name: str) -> Optional['Connection']:
166+
"""Find a connection by SQL cell handle or by human assigned name"""
167+
# TODO: Capt. Obvious says to double-register the instance by both of these keys
168+
# to then be able to do lookups properly in this dict?
169+
for c in cls.connections.values():
170+
if c.name == name or c.human_name == name:
171+
return c
172+
167173
@classmethod
168174
def get_engine(cls, name: str) -> Optional[Engine]:
169175
"""Return the SQLAlchemy Engine given either the sql_cell_handle or
170176
end-user assigned name for the connection.
171177
"""
172-
for c in cls.connections.values():
173-
if c.name == name or c.human_name == name:
174-
return c._engine
178+
maybe_conn = cls.find(name)
179+
if maybe_conn:
180+
return maybe_conn.engine
175181

176182
@classmethod
177183
def add_bootstrapping_failure(cls, name: str, human_name: Optional[str], error_message: str):
@@ -204,7 +210,56 @@ def _close(cls, descriptor):
204210
)
205211
cls.connections.pop(conn.name, None)
206212
cls.connections.pop(str(conn.metadata.bind.url), None)
207-
conn.session.close()
213+
conn.sqla_connection.close()
208214

209215
def close(self):
210216
self.__class__._close(self)
217+
218+
219+
def rough_dict_get(dct, sought, default=None):
220+
"""
221+
Like dct.get(sought), but any key containing sought will do.
222+
223+
If there is a `@` in sought, seek each piece separately.
224+
This lets `me@server` match `me:***@myserver/db`
225+
"""
226+
227+
sought = sought.split("@")
228+
for key, val in dct.items():
229+
if not any(s.lower() not in key.lower() for s in sought):
230+
return val
231+
return default
232+
233+
234+
def get_db_connection(name_or_handle: str) -> Optional[Connection]:
235+
"""Return the noteable.sql.connection.Connection corresponding to the requested
236+
datasource name or handle.
237+
238+
Will return None if the given handle isn't present in
239+
the connections dict already (created after this kernel was launched?)
240+
"""
241+
return Connection.find(name_or_handle)
242+
243+
244+
def get_sqla_connection(name_or_handle: str) -> Optional[sqlalchemy.engine.base.Connection]:
245+
"""Return a SQLAlchemy connection given a name or handle.
246+
Returns None if no connection can be found by this string.
247+
"""
248+
nconn = get_db_connection(name_or_handle)
249+
if nconn:
250+
return nconn.sqla_connection
251+
252+
253+
def get_sqla_engine(name_or_handle: str) -> Optional[Engine]:
254+
"""Return a SQLAlchemy Engine given a name or handle.
255+
Returns None if no engine can be found by this string.
256+
"""
257+
return Connection.get_engine(name_or_handle)
258+
259+
260+
def bootstrap_duckdb():
261+
Connection.set(
262+
DUCKDB_LOCATION,
263+
human_name=LOCAL_DB_CONN_NAME,
264+
name=LOCAL_DB_CONN_HANDLE,
265+
)

noteable/sql/magic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def execute(self, line="", cell="", local_ns={}): # noqa: C901
149149
OperationalError,
150150
MetaCommandException,
151151
) as e:
152-
153152
# Normal syntax errors, missing table, etc. should come back as
154153
# ProgrammingError. And the rest indicate something fundamentally
155154
# broken at the DBAPI layer.
@@ -185,8 +184,8 @@ def execute(self, line="", cell="", local_ns={}): # noqa: C901
185184
#
186185
# "Restart Kernel" is too big of a hammer here.
187186
#
188-
conn._engine.dispose()
189-
conn._session = None
187+
conn.reset_connection_pool()
188+
190189
eprint(
191190
"Encoutered the following unexpected exception while trying to run the statement."
192191
" Closed the connection just to be safe. Re-run the cell to try again!\n\n"

noteable/sql/run.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _commit(conn, config):
107107

108108
if _should_commit:
109109
try:
110-
conn.session.execute("commit")
110+
conn.sqla_connection.execute("commit")
111111
except sqlalchemy.exc.OperationalError:
112112
pass # not all engines can commit
113113

@@ -140,7 +140,6 @@ def _commit(conn, config):
140140

141141

142142
def run(conn, sql, config, user_namespace, skip_boxing_scalar_result: bool):
143-
144143
if sql.strip():
145144
for statement in sqlparse.split(sql):
146145
first_word = sql.strip().split()[0].lower()
@@ -154,7 +153,7 @@ def run(conn, sql, config, user_namespace, skip_boxing_scalar_result: bool):
154153
bind_dict = {str(idx + 1): elem for (idx, elem) in enumerate(bind_list)}
155154

156155
txt = sqlalchemy.sql.text(query)
157-
result = conn.session.execute(txt, bind_dict)
156+
result = conn.sqla_connection.execute(txt, bind_dict)
158157

159158
_commit(conn=conn, config=config)
160159

tests/conftest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
FileProgressUpdateContent,
1919
FileProgressUpdateMessage,
2020
)
21-
from noteable.sql.connection import Connection
21+
from noteable.sql.connection import Connection, bootstrap_duckdb
2222
from noteable.sql.magic import SqlMagic
2323
from noteable.sql.run import add_commit_blacklist_dialect
2424

@@ -117,6 +117,14 @@ def with_empty_connections() -> None:
117117
Connection.connections = preexisting_connections
118118

119119

120+
@pytest.fixture
121+
def with_duckdb_bootstrapped(with_empty_connections) -> None:
122+
# Normal magics bootstrapping will leave us with DuckDB connection populated.
123+
bootstrap_duckdb()
124+
125+
yield
126+
127+
120128
@pytest.fixture
121129
def ipython_shell() -> InteractiveShell:
122130
return InteractiveShell()

0 commit comments

Comments
 (0)