Skip to content

Commit 609ba2e

Browse files
Merge branch 'feriat-patch-1'
2 parents 207a5b1 + c8271e0 commit 609ba2e

File tree

4 files changed

+225
-175
lines changed

4 files changed

+225
-175
lines changed

run_tests.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/bash
2-
ipython -c "import nose; nose.run()"
3-
# Insert breakpoints with `from nose.tools import set_trace; set_trace()`
2+
python -c "import pytest; pytest.main(['.', '-x', '--pdb'])"
3+
# Insert breakpoints with `import pytest; pytest.set_trace()`

src/sql/run.py

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
import operator
2-
import csv
3-
import six
41
import codecs
2+
import csv
3+
import operator
54
import os.path
65
import re
6+
7+
import prettytable
8+
import six
79
import sqlalchemy
810
import sqlparse
9-
import prettytable
11+
12+
from .column_guesser import ColumnGuesserMixin
13+
1014
try:
1115
from pgspecial.main import PGSpecial
1216
except ImportError:
1317
PGSpecial = None
14-
from .column_guesser import ColumnGuesserMixin
1518

1619

1720
def unduplicate_field_names(field_names):
@@ -26,6 +29,7 @@ def unduplicate_field_names(field_names):
2629
res.append(k)
2730
return res
2831

32+
2933
class UnicodeWriter(object):
3034
"""
3135
A CSV writer which will write rows to CSV file "f",
@@ -41,19 +45,17 @@ def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
4145

4246
def writerow(self, row):
4347
if six.PY2:
44-
_row = [s.encode("utf-8")
45-
if hasattr(s, "encode")
46-
else s
48+
_row = [s.encode("utf-8") if hasattr(s, "encode") else s
4749
for s in row]
4850
else:
4951
_row = row
5052
self.writer.writerow(_row)
5153
# Fetch UTF-8 output from the queue ...
5254
data = self.queue.getvalue()
5355
if six.PY2:
54-
data = data.decode("utf-8")
55-
# ... and reencode it into the target encoding
56-
data = self.encoder.encode(data)
56+
data = data.decode("utf-8")
57+
# ... and reencode it into the target encoding
58+
data = self.encoder.encode(data)
5759
# write to the target stream
5860
self.stream.write(data)
5961
# empty queue
@@ -64,14 +66,20 @@ def writerows(self, rows):
6466
for row in rows:
6567
self.writerow(row)
6668

69+
6770
class CsvResultDescriptor(object):
6871
"""Provides IPython Notebook-friendly output for the feedback after a ``.csv`` called."""
72+
6973
def __init__(self, file_path):
7074
self.file_path = file_path
75+
7176
def __repr__(self):
72-
return 'CSV results at %s' % os.path.join(os.path.abspath('.'), self.file_path)
77+
return 'CSV results at %s' % os.path.join(
78+
os.path.abspath('.'), self.file_path)
79+
7380
def _repr_html_(self):
74-
return '<a href="%s">CSV results</a>' % os.path.join('.', 'files', self.file_path)
81+
return '<a href="%s">CSV results</a>' % os.path.join('.', 'files',
82+
self.file_path)
7583

7684

7785
def _nonbreaking_spaces(match_obj):
@@ -84,6 +92,7 @@ def _nonbreaking_spaces(match_obj):
8492
spaces = '&nbsp;' * len(match_obj.group(2))
8593
return '%s%s' % (match_obj.group(1), spaces)
8694

95+
8796
_cell_with_spaces_pattern = re.compile(r'(<td>)( {2,})')
8897

8998

@@ -93,6 +102,7 @@ class ResultSet(list, ColumnGuesserMixin):
93102
94103
Can access rows listwise, or by string value of leftmost column.
95104
"""
105+
96106
def __init__(self, sqlaproxy, sql, config):
97107
self.keys = sqlaproxy.keys()
98108
self.sql = sql
@@ -118,7 +128,8 @@ def _repr_html_(self):
118128
self.pretty.add_rows(self)
119129
result = self.pretty.get_html_string()
120130
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
121-
if self.config.displaylimit and len(self) > self.config.displaylimit:
131+
if self.config.displaylimit and len(
132+
self) > self.config.displaylimit:
122133
result = '%s\n<span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
123134
result, len(self), self.config.displaylimit)
124135
return result
@@ -143,6 +154,7 @@ def __getitem__(self, key):
143154
if len(result) > 1:
144155
raise KeyError('%d results for "%s"' % (len(result), key))
145156
return result[0]
157+
146158
def dict(self):
147159
"""Returns a single dict built from the result set
148160
@@ -217,7 +229,7 @@ def plot(self, title=None, **kwargs):
217229
plt.ylabel(ylabel)
218230
return plot
219231

220-
def bar(self, key_word_sep = " ", title=None, **kwargs):
232+
def bar(self, key_word_sep=" ", title=None, **kwargs):
221233
"""Generates a pylab bar plot from the result set.
222234
223235
``matplotlib`` must be installed, and in an
@@ -241,8 +253,7 @@ def bar(self, key_word_sep = " ", title=None, **kwargs):
241253
self.guess_pie_columns(xlabel_sep=key_word_sep)
242254
plot = plt.bar(range(len(self.ys[0])), self.ys[0], **kwargs)
243255
if self.xlabels:
244-
plt.xticks(range(len(self.xlabels)), self.xlabels,
245-
rotation=45)
256+
plt.xticks(range(len(self.xlabels)), self.xlabels, rotation=45)
246257
plt.xlabel(self.xlabel)
247258
plt.ylabel(self.ys[0].name)
248259
return plot
@@ -251,7 +262,7 @@ def csv(self, filename=None, **format_params):
251262
"""Generate results in comma-separated form. Write to ``filename`` if given.
252263
Any other parameters will be passed on to csv.writer."""
253264
if not self.pretty:
254-
return None # no results
265+
return None # no results
255266
self.pretty.add_rows(self)
256267
if filename:
257268
encoding = format_params.get('encoding', 'utf-8')
@@ -279,17 +290,37 @@ def interpret_rowcount(rowcount):
279290
result = '%d rows affected.' % rowcount
280291
return result
281292

293+
282294
class FakeResultProxy(object):
283295
"""A fake class that pretends to behave like the ResultProxy from
284296
SqlAlchemy.
285297
"""
298+
286299
def __init__(self, cursor, headers):
287300
self.fetchall = cursor.fetchall
288301
self.fetchmany = cursor.fetchmany
289302
self.rowcount = cursor.rowcount
290303
self.keys = lambda: headers
291304
self.returns_rows = True
292305

306+
# some dialects have autocommit
307+
# specific dialects break when commit is used:
308+
_COMMIT_BLACKLIST_DIALECTS = ('mssql', 'clickhouse')
309+
310+
311+
def _commit(conn, config):
312+
"""Issues a commit, if appropriate for current config and dialect"""
313+
314+
_should_commit = config.autocommit and all(
315+
dialect not in str(conn.dialect)
316+
for dialect in _COMMIT_BLACKLIST_DIALECTS)
317+
318+
if _should_commit:
319+
try:
320+
conn.session.execute('commit')
321+
except sqlalchemy.exc.OperationalError:
322+
pass # not all engines can commit
323+
293324

294325
def run(conn, sql, config, user_namespace):
295326
if sql.strip():
@@ -302,18 +333,12 @@ def run(conn, sql, config, user_namespace):
302333
raise ImportError('pgspecial not installed')
303334
pgspecial = PGSpecial()
304335
_, cur, headers, _ = pgspecial.execute(
305-
conn.session.connection.cursor(),
306-
statement)[0]
336+
conn.session.connection.cursor(), statement)[0]
307337
result = FakeResultProxy(cur, headers)
308338
else:
309339
txt = sqlalchemy.sql.text(statement)
310340
result = conn.session.execute(txt, user_namespace)
311-
try:
312-
# mssql has autocommit
313-
if config.autocommit and ('mssql' not in str(conn.dialect)):
314-
conn.session.execute('commit')
315-
except sqlalchemy.exc.OperationalError:
316-
pass # not all engines can commit
341+
_commit(conn=conn, config=config)
317342
if result and config.feedback:
318343
print(interpret_rowcount(result.rowcount))
319344
resultset = ResultSet(result, statement, config)
@@ -327,11 +352,10 @@ def run(conn, sql, config, user_namespace):
327352

328353

329354
class PrettyTable(prettytable.PrettyTable):
330-
331355
def __init__(self, *args, **kwargs):
332356
self.row_count = 0
333357
self.displaylimit = None
334-
return super(PrettyTable, self).__init__(*args, **kwargs)
358+
return super(PrettyTable, self).__init__(*args, **kwargs)
335359

336360
def add_rows(self, data):
337361
if self.row_count and (data.config.displaylimit == self.displaylimit):

src/tests/test_column_guesser.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,53 @@
11
import re
22
import sys
3-
from nose.tools import with_setup, raises
3+
4+
import pytest
5+
46
from sql.magic import SqlMagic
57

68
ip = get_ipython()
79

10+
811
class SqlEnv(object):
912
def __init__(self, connectstr):
1013
self.connectstr = connectstr
14+
1115
def query(self, txt):
1216
return ip.run_line_magic('sql', "%s %s" % (self.connectstr, txt))
1317

18+
1419
sql_env = SqlEnv('sqlite://')
1520

16-
def setup():
21+
22+
@pytest.fixture
23+
def tbl():
1724
sqlmagic = SqlMagic(shell=ip)
1825
ip.register_magics(sqlmagic)
1926
creator = """
2027
DROP TABLE IF EXISTS manycoltbl;
21-
CREATE TABLE manycoltbl
28+
CREATE TABLE manycoltbl
2229
(name TEXT, y1 REAL, y2 REAL, name2 TEXT, y3 INT);
23-
INSERT INTO manycoltbl VALUES
30+
INSERT INTO manycoltbl VALUES
2431
('r1-txt1', 1.01, 1.02, 'r1-txt2', 1.04);
25-
INSERT INTO manycoltbl VALUES
32+
INSERT INTO manycoltbl VALUES
2633
('r2-txt1', 2.01, 2.02, 'r2-txt2', 2.04);
2734
INSERT INTO manycoltbl VALUES ('r3-txt1', 3.01, 3.02, 'r3-txt2', 3.04);
2835
"""
2936
for qry in creator.split(";"):
3037
sql_env.query(qry)
31-
32-
def teardown():
38+
yield
3339
sql_env.query("DROP TABLE manycoltbl")
3440

41+
3542
class Harness(object):
3643
def run_query(self):
3744
return sql_env.query(self.query)
3845

46+
3947
class TestOneNum(Harness):
4048
query = "SELECT y1 FROM manycoltbl"
41-
42-
@with_setup(setup, teardown)
43-
def test_pie(self):
49+
50+
def test_pie(self, tbl):
4451
results = self.run_query()
4552
results.guess_pie_columns(xlabel_sep="//")
4653
assert results.ys[0].is_quantity
@@ -49,28 +56,26 @@ def test_pie(self):
4956
assert results.xlabels == []
5057
assert results.xlabel == ''
5158

52-
@with_setup(setup, teardown)
53-
def test_plot(self):
59+
def test_plot(self, tbl):
5460
results = self.run_query()
5561
results.guess_plot_columns()
5662
assert results.ys == [[1.01, 2.01, 3.01]]
5763
assert results.x == []
5864
assert results.x.name == ''
5965

66+
6067
class TestOneStrOneNum(Harness):
6168
query = "SELECT name, y1 FROM manycoltbl"
62-
63-
@with_setup(setup, teardown)
64-
def test_pie(self):
69+
70+
def test_pie(self, tbl):
6571
results = self.run_query()
6672
results.guess_pie_columns(xlabel_sep="//")
6773
assert results.ys[0].is_quantity
6874
assert results.ys == [[1.01, 2.01, 3.01]]
6975
assert results.xlabels == ['r1-txt1', 'r2-txt1', 'r3-txt1']
7076
assert results.xlabel == 'name'
7177

72-
@with_setup(setup, teardown)
73-
def test_plot(self):
78+
def test_plot(self, tbl):
7479
results = self.run_query()
7580
results.guess_plot_columns()
7681
assert results.ys == [[1.01, 2.01, 3.01]]
@@ -79,20 +84,19 @@ def test_plot(self):
7984

8085
class TestTwoStrTwoNum(Harness):
8186
query = "SELECT name2, y3, name, y1 FROM manycoltbl"
82-
83-
@with_setup(setup, teardown)
84-
def test_pie(self):
87+
88+
def test_pie(self, tbl):
8589
results = self.run_query()
8690
results.guess_pie_columns(xlabel_sep="//")
8791
assert results.ys[0].is_quantity
8892
assert results.ys == [[1.01, 2.01, 3.01]]
89-
assert results.xlabels == ['r1-txt2//1.04//r1-txt1',
90-
'r2-txt2//2.04//r2-txt1',
91-
'r3-txt2//3.04//r3-txt1']
93+
assert results.xlabels == [
94+
'r1-txt2//1.04//r1-txt1', 'r2-txt2//2.04//r2-txt1',
95+
'r3-txt2//3.04//r3-txt1'
96+
]
9297
assert results.xlabel == 'name2, y3, name'
9398

94-
@with_setup(setup, teardown)
95-
def test_plot(self):
99+
def test_plot(self, tbl):
96100
results = self.run_query()
97101
results.guess_plot_columns()
98102
assert results.ys == [[1.01, 2.01, 3.01]]
@@ -101,21 +105,19 @@ def test_plot(self):
101105

102106
class TestTwoStrThreeNum(Harness):
103107
query = "SELECT name, y1, name2, y2, y3 FROM manycoltbl"
104-
105-
@with_setup(setup, teardown)
106-
def test_pie(self):
108+
109+
def test_pie(self, tbl):
107110
results = self.run_query()
108111
results.guess_pie_columns(xlabel_sep="//")
109112
assert results.ys[0].is_quantity
110113
assert results.ys == [[1.04, 2.04, 3.04]]
111-
assert results.xlabels == ['r1-txt1//1.01//r1-txt2//1.02',
112-
'r2-txt1//2.01//r2-txt2//2.02',
113-
'r3-txt1//3.01//r3-txt2//3.02']
114+
assert results.xlabels == [
115+
'r1-txt1//1.01//r1-txt2//1.02', 'r2-txt1//2.01//r2-txt2//2.02',
116+
'r3-txt1//3.01//r3-txt2//3.02'
117+
]
114118

115-
@with_setup(setup, teardown)
116-
def test_plot(self):
119+
def test_plot(self, tbl):
117120
results = self.run_query()
118121
results.guess_plot_columns()
119122
assert results.ys == [[1.02, 2.02, 3.02], [1.04, 2.04, 3.04]]
120123
assert results.x == [1.01, 2.01, 3.01]
121-

0 commit comments

Comments
 (0)