Skip to content

Commit 8eff6ad

Browse files
Add tests for .csv()
1 parent 2ae029f commit 8eff6ad

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

src/sql/magic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,25 @@ def execute(self, line, cell='', local_ns={}):
116116

117117
legal_sql_identifier = re.compile(r'^[A-Za-z0-9#_$]+')
118118
def _persist_dataframe(self, raw, conn, user_ns):
119+
"""Implements PERSIST, which writes a DataFrame to the RDBMS"""
119120
if not DataFrame:
120121
raise ImportError("Must `pip install pandas` to use DataFrames")
122+
123+
# Parse input to get name of DataFrame
121124
pieces = raw.split()
122125
if len(pieces) != 2:
123126
raise SyntaxError("Format: %sql [connection] persist <DataFrameName>")
124127
frame_name = pieces[1].strip(';')
128+
129+
# Get the DataFrame from the user namespace
125130
frame = eval(frame_name, user_ns)
126131
if not isinstance(frame, DataFrame) and not isinstance(frame, Series):
127132
raise TypeError('%s is not a Pandas DataFrame or Series' % frame_name)
133+
134+
# Make a suitable name for the resulting database table
128135
table_name = frame_name.lower()
129136
table_name = self.legal_sql_identifier.search(table_name).group(0)
137+
130138
frame.to_sql(table_name, conn.session.engine)
131139
return 'Persisted %s' % table_name
132140

src/sql/parse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44

55

66
def parse(cell, config):
7+
"""Separate input into (connection info, SQL statement)"""
8+
79
parts = [part.strip() for part in cell.split(None, 1)]
810
if not parts:
911
return {'connection': '', 'sql': ''}
12+
1013
if parts[0].startswith('[') and parts[0].endswith(']'):
1114
section = parts[0].lstrip('[').rstrip(']')
1215
parser = CP.ConfigParser()

src/tests/test_magic.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from nose import with_setup
22
from sql.magic import SqlMagic
33
from textwrap import dedent
4+
import os.path
45
import re
6+
import tempfile
57

68
ip = get_ipython()
79

@@ -132,3 +134,26 @@ def test_autopandas():
132134
dframe = ip.run_cell("%sql SELECT * FROM test;")
133135
assert dframe.success
134136
assert dframe.result.name[0] == 'foo'
137+
138+
@with_setup(_setup, _teardown)
139+
def test_csv():
140+
ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh
141+
result = ip.run_line_magic('sql', "sqlite:// SELECT * FROM test;")
142+
result = result.csv()
143+
for row in result.splitlines():
144+
assert row.count(',') == 1
145+
assert len(result.splitlines()) == 3
146+
147+
@with_setup(_setup, _teardown)
148+
def test_csv_to_file():
149+
ip.run_line_magic('config', "SqlMagic.autopandas = False") # uh-oh
150+
result = ip.run_line_magic('sql', "sqlite:// SELECT * FROM test;")
151+
with tempfile.TemporaryDirectory() as tempdir:
152+
fname = os.path.join(tempdir, 'test.csv')
153+
output = result.csv(fname)
154+
assert os.path.exists(output.file_path)
155+
with open(output.file_path) as csvfile:
156+
content = csvfile.read()
157+
for row in content.splitlines():
158+
assert row.count(',') == 1
159+
assert len(content.splitlines()) == 3

0 commit comments

Comments
 (0)