Skip to content

Commit 4e78ac4

Browse files
2 parents 8eff6ad + c5288db commit 4e78ac4

File tree

6 files changed

+93
-21
lines changed

6 files changed

+93
-21
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ nosetests.xml
3333
.mr.developer.cfg
3434
.project
3535
.pydevproject
36+
37+
# Pycharm
38+
/.idea

src/sql/connection.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
class Connection(object):
44
current = None
55
connections = {}
6+
67
@classmethod
78
def tell_format(cls):
8-
return "Format: (postgresql|mysql)://username:password@hostname/dbname, or one of %s" \
9-
% str(cls.connections.keys())
9+
return """Connection info needed in SQLAlchemy format, example:
10+
postgresql://username:password@hostname/dbname
11+
or an existing connection: %s""" % str(cls.connections.keys())
12+
1013
def __init__(self, connect_str=None):
1114
try:
1215
engine = sqlalchemy.create_engine(connect_str)
1316
except: # TODO: bare except; but what's an ArgumentError?
1417
print(self.tell_format())
15-
raise
18+
raise
1619
self.dialect = engine.url.get_dialect()
1720
self.metadata = sqlalchemy.MetaData(bind=engine)
1821
self.name = self.assign_name(engine)
19-
self.session = engine.connect()
22+
self.session = engine.connect()
2023
self.connections[self.name] = self
2124
self.connections[str(self.metadata.bind.url)] = self
2225
Connection.current = self
@@ -26,7 +29,7 @@ def get(cls, descriptor):
2629
cls.current = descriptor
2730
elif descriptor:
2831
conn = cls.connections.get(descriptor) or \
29-
cls.connections.get(descriptor.lower())
32+
cls.connections.get(descriptor.lower())
3033
if conn:
3134
cls.current = conn
3235
else:

src/sql/magic.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@ def execute(self, line, cell='', local_ns={}):
7878
user_ns.update(local_ns)
7979

8080
parsed = sql.parse.parse('%s\n%s' % (line, cell), self)
81+
flags = parsed['flags']
8182
conn = sql.connection.Connection.get(parsed['connection'])
82-
first_word = parsed['sql'].split(None, 1)[:1]
83-
if first_word and first_word[0].lower() == 'persist':
83+
84+
if flags['persist']:
8485
return self._persist_dataframe(parsed['sql'], conn, user_ns)
8586

8687
try:
@@ -104,6 +105,13 @@ def execute(self, line, cell='', local_ns={}):
104105

105106
return None
106107
else:
108+
109+
if flags['result_var']:
110+
result_var = flags['result_var']
111+
print("Returning data to local variable {}".format(result_var))
112+
self.shell.user_ns.update({result_var: result})
113+
return None
114+
107115
#Return results into the default ipython _ variable
108116
return result
109117

@@ -120,13 +128,11 @@ def _persist_dataframe(self, raw, conn, user_ns):
120128
if not DataFrame:
121129
raise ImportError("Must `pip install pandas` to use DataFrames")
122130

123-
# Parse input to get name of DataFrame
124-
pieces = raw.split()
125-
if len(pieces) != 2:
126-
raise SyntaxError("Format: %sql [connection] persist <DataFrameName>")
127-
frame_name = pieces[1].strip(';')
131+
frame_name = raw.strip(';')
128132

129-
# Get the DataFrame from the user namespace
133+
# Get the DataFrame from the user namespace
134+
if not frame_name:
135+
raise SyntaxError('Syntax: %sql PERSIST <name_of_data_frame>')
130136
frame = eval(frame_name, user_ns)
131137
if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
132138
raise TypeError('%s is not a Pandas DataFrame or Series' % frame_name)

src/sql/parse.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,27 @@ def parse(cell, config):
2727
else:
2828
connection = ''
2929
sql = cell
30+
flags, sql = parse_sql_flags(sql.strip())
3031
return {'connection': connection.strip(),
31-
'sql': sql.strip()}
32+
'sql': sql,
33+
'flags': flags}
34+
35+
36+
def parse_sql_flags(sql):
37+
words = sql.split()
38+
flags = {
39+
'persist': False,
40+
'result_var': None
41+
}
42+
if not words:
43+
return (flags, "")
44+
num_words = len(words)
45+
trimmed_sql = sql
46+
if words[0].lower() == 'persist':
47+
print("Persist parsed to True")
48+
flags['persist'] = True
49+
trimmed_sql = " ".join(words[1:])
50+
elif num_words >= 2 and words[1] == '<<':
51+
flags['result_var'] = words[0]
52+
trimmed_sql = " ".join(words[2:])
53+
return (flags, trimmed_sql.strip())

src/tests/test_magic.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from nose import with_setup
2+
from nose.tools import raises
23
from sql.magic import SqlMagic
34
from textwrap import dedent
45
import os.path
@@ -57,6 +58,16 @@ def test_multi_sql():
5758
""")
5859
assert 'Shakespeare' in str(result) and 'Brecht' in str(result)
5960

61+
@with_setup(_setup_writer, _teardown_writer)
62+
def test_result_var():
63+
ip.run_cell_magic('sql', '', """
64+
sqlite://
65+
x <<
66+
SELECT last_name FROM writer;
67+
""")
68+
result = ip.user_global_ns['x']
69+
assert 'Shakespeare' in str(result) and 'Brecht' in str(result)
70+
6071
@with_setup(_setup_writer, _teardown_writer)
6172
def test_access_results_by_keys():
6273
assert ip.run_line_magic('sql', "sqlite:// SELECT * FROM writer;")['William'] == (u'William', u'Shakespeare', 1616)
@@ -78,7 +89,6 @@ def test_autolimit():
7889
result = ip.run_line_magic('sql', "sqlite:// SELECT * FROM test;")
7990
assert len(result) == 1
8091

81-
8292
@with_setup(_setup, _teardown)
8393
def test_persist():
8494
ip.run_cell("results = %sql SELECT * FROM test;")
@@ -87,13 +97,37 @@ def test_persist():
8797
persisted = ip.run_line_magic('sql', 'SELECT * FROM results_dframe')
8898
assert 'foo' in str(persisted)
8999

100+
@raises(NameError)
101+
def test_persist_nonexistent_raises():
102+
ip.run_line_magic('sql', "sqlite://")
103+
ip.run_line_magic('sql', 'PERSIST no_such_dataframe')
104+
105+
@raises(TypeError)
106+
def test_persist_non_frame_raises():
107+
ip.run_cell("not_a_dataframe = 22")
108+
ip.run_line_magic('sql', "sqlite://")
109+
ip.run_line_magic('sql', 'PERSIST not_a_dataframe')
110+
111+
@raises(SyntaxError)
112+
def test_persist_bare():
113+
ip.run_line_magic('sql', "sqlite://")
114+
ip.run_line_magic('sql', 'PERSIST')
115+
90116
@with_setup(_setup_writer, _teardown_writer)
91-
def test_unnamed_persist():
117+
def test_persist_frame_at_its_creation():
92118
ip.run_cell("results = %sql SELECT * FROM writer;")
93119
ip.run_line_magic('sql', 'PERSIST results.DataFrame()')
94120
persisted = ip.run_line_magic('sql', 'SELECT * FROM results')
95121
assert 'Shakespeare' in str(persisted)
96122

123+
# TODO: support
124+
# @with_setup(_setup_writer, _teardown_writer)
125+
# def test_persist_with_connection_info():
126+
# ip.run_cell("results = %sql SELECT * FROM writer;")
127+
# ip.run_line_magic('sql', 'sqlite:// PERSIST results.DataFrame()')
128+
# persisted = ip.run_line_magic('sql', 'SELECT * FROM results')
129+
# assert 'Shakespeare' in str(persisted)
130+
97131
@with_setup(_setup_writer, _teardown_writer)
98132
def test_displaylimit():
99133
ip.run_line_magic('config', "SqlMagic.autolimit = 0")

src/tests/test_parse.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,28 @@
66
from IPython.config.configurable import Configurable
77

88
empty_config = Configurable()
9-
9+
default_flags = {'persist': False, 'result_var': None}
1010
def test_parse_no_sql():
1111
assert parse("will:longliveliz@localhost/shakes", empty_config) == \
1212
{'connection': "will:longliveliz@localhost/shakes",
13-
'sql': ''}
13+
'sql': '',
14+
'flags': default_flags}
1415

1516
def test_parse_with_sql():
1617
assert parse("postgresql://will:longliveliz@localhost/shakes SELECT * FROM work",
1718
empty_config) == \
1819
{'connection': "postgresql://will:longliveliz@localhost/shakes",
19-
'sql': 'SELECT * FROM work'}
20+
'sql': 'SELECT * FROM work',
21+
'flags': default_flags}
2022

2123
def test_parse_sql_only():
2224
assert parse("SELECT * FROM work", empty_config) == \
2325
{'connection': "",
24-
'sql': 'SELECT * FROM work'}
26+
'sql': 'SELECT * FROM work',
27+
'flags': default_flags}
2528

2629
def test_parse_postgresql_socket_connection():
2730
assert parse("postgresql:///shakes SELECT * FROM work", empty_config) == \
2831
{'connection': "postgresql:///shakes",
29-
'sql': 'SELECT * FROM work'}
32+
'sql': 'SELECT * FROM work',
33+
'flags': default_flags}

0 commit comments

Comments
 (0)