Skip to content

Commit f4283c6

Browse files
Connect to $DATABASE_URL if no connection string given
1 parent 2c01a9d commit f4283c6

File tree

4 files changed

+51
-19
lines changed

4 files changed

+51
-19
lines changed

README.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ an existing connection by username@database
6464
======================
6565
Poet 733
6666
67+
If no connect string is supplied, ``5sql`` will provide a list of existing connections;
68+
however, if no connections have yet been made and the environment variable ``DATABASE_URL``
69+
is available, that will be used.
70+
6771
For secure access, you may dynamically access your credentials (e.g. from your system environment or `getpass.getpass`) to avoid storing your password in the notebook itself. Use the `$` before any variable to access it in your `%sql` command.
6872

6973
.. code-block:: python
@@ -250,10 +254,10 @@ specified) or in a file of the given name.
250254
PostgreSQL features
251255
-------------------
252256

253-
``psql``-style "backslash" `meta-commands`_ commands (``\d``, ``\dt``, etc.)
257+
``psql``-style "backslash" `meta-commands`_ commands (``\d``, ``\dt``, etc.)
254258
are provided by `PGSpecial`_.
255259

256-
.. _PGSpecial: https://pypi.python.org/pypi/pgspecial
260+
.. _PGSpecial: https://pypi.python.org/pypi/pgspecial
257261

258262
.. _meta-commands: https://www.postgresql.org/docs/9.6/static/app-psql.html#APP-PSQL-META-COMMANDS
259263

src/sql/connection.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
import sqlalchemy
2+
import os
3+
4+
class ConnectionError(Exception):
5+
pass
6+
27

38
class Connection(object):
49
current = None
@@ -23,27 +28,45 @@ def __init__(self, connect_str=None):
2328
self.connections[self.name] = self
2429
self.connections[str(self.metadata.bind.url)] = self
2530
Connection.current = self
31+
2632
@classmethod
27-
def get(cls, descriptor):
28-
if isinstance(descriptor, Connection):
29-
cls.current = descriptor
30-
elif descriptor:
31-
conn = cls.connections.get(descriptor) or \
32-
cls.connections.get(descriptor.lower())
33-
if conn:
34-
cls.current = conn
33+
def set(cls, descriptor):
34+
"Sets the current database connection"
35+
36+
if descriptor:
37+
if isinstance(descriptor, Connection):
38+
cls.current = descriptor
3539
else:
36-
cls.current = Connection(descriptor)
37-
if cls.current:
38-
return cls.current
40+
existing = cls.connections.get(descriptor) or \
41+
cls.connections.get(descriptor.lower())
42+
cls.current = existing or Connection(descriptor)
3943
else:
40-
raise Exception(cls.tell_format())
44+
if cls.connections:
45+
print(cls.connection_list())
46+
else:
47+
if os.getenv('DATABASE_URL'):
48+
cls.current = Connection(os.getenv('DATABASE_URL'))
49+
else:
50+
raise ConnectionError('Environment variable $DATABASE_URL not set, and no connect string given.')
51+
return cls.current
52+
4153
@classmethod
4254
def assign_name(cls, engine):
43-
core_name = '%s@%s' % (engine.url.username, engine.url.database)
55+
core_name = '%s@%s' % (engine.url.username or '', engine.url.database)
4456
incrementer = 1
4557
name = core_name
4658
while name in cls.connections:
4759
name = '%s_%d' % (core_name, incrementer)
4860
incrementer += 1
4961
return name
62+
63+
@classmethod
64+
def connection_list(cls):
65+
result = []
66+
for key in sorted(cls.connections):
67+
if cls.connections[key] == cls.current:
68+
template = ' * {}'
69+
else:
70+
template = ' {}'
71+
result.append(template.format(key))
72+
return '\n'.join(result)

src/sql/magic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,14 @@ def execute(self, line, cell='', local_ns={}):
8080

8181
parsed = sql.parse.parse('%s\n%s' % (line, cell), self)
8282
flags = parsed['flags']
83-
conn = sql.connection.Connection.get(parsed['connection'])
83+
try:
84+
conn = sql.connection.Connection.set(parsed['connection'])
85+
except Exception as e:
86+
print(e)
87+
print(sql.connection.Connection.tell_format())
88+
return None
8489

85-
if flags['persist']:
90+
if flags.get('persist'):
8691
return self._persist_dataframe(parsed['sql'], conn, user_ns)
8792

8893
try:
@@ -107,7 +112,7 @@ def execute(self, line, cell='', local_ns={}):
107112
return None
108113
else:
109114

110-
if flags['result_var']:
115+
if flags.get('result_var'):
111116
result_var = flags['result_var']
112117
print("Returning data to local variable {}".format(result_var))
113118
self.shell.user_ns.update({result_var: result})

src/sql/parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def parse(cell, config):
99

1010
parts = [part.strip() for part in cell.split(None, 1)]
1111
if not parts:
12-
return {'connection': '', 'sql': ''}
12+
return {'connection': '', 'sql': '', 'flags': {}}
1313
parts[0] = expandvars(parts[0]) # for environment variables
1414

1515
if parts[0].startswith('[') and parts[0].endswith(']'):

0 commit comments

Comments
 (0)