1
- import operator
2
- import csv
3
- import six
4
1
import codecs
2
+ import csv
3
+ import operator
5
4
import os .path
6
5
import re
6
+
7
+ import prettytable
8
+ import six
7
9
import sqlalchemy
8
10
import sqlparse
9
- import prettytable
10
- from pgspecial .main import PGSpecial
11
+
11
12
from .column_guesser import ColumnGuesserMixin
12
13
14
+ try :
15
+ from pgspecial .main import PGSpecial
16
+ except ImportError :
17
+ PGSpecial = None
18
+
13
19
14
20
def unduplicate_field_names (field_names ):
15
21
"""Append a number to duplicate field names to make them unique. """
@@ -23,6 +29,7 @@ def unduplicate_field_names(field_names):
23
29
res .append (k )
24
30
return res
25
31
32
+
26
33
class UnicodeWriter (object ):
27
34
"""
28
35
A CSV writer which will write rows to CSV file "f",
@@ -38,19 +45,17 @@ def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
38
45
39
46
def writerow (self , row ):
40
47
if six .PY2 :
41
- _row = [s .encode ("utf-8" )
42
- if hasattr (s , "encode" )
43
- else s
48
+ _row = [s .encode ("utf-8" ) if hasattr (s , "encode" ) else s
44
49
for s in row ]
45
50
else :
46
51
_row = row
47
52
self .writer .writerow (_row )
48
53
# Fetch UTF-8 output from the queue ...
49
54
data = self .queue .getvalue ()
50
55
if six .PY2 :
51
- data = data .decode ("utf-8" )
52
- # ... and reencode it into the target encoding
53
- data = self .encoder .encode (data )
56
+ data = data .decode ("utf-8" )
57
+ # ... and reencode it into the target encoding
58
+ data = self .encoder .encode (data )
54
59
# write to the target stream
55
60
self .stream .write (data )
56
61
# empty queue
@@ -61,14 +66,20 @@ def writerows(self, rows):
61
66
for row in rows :
62
67
self .writerow (row )
63
68
69
+
64
70
class CsvResultDescriptor (object ):
65
71
"""Provides IPython Notebook-friendly output for the feedback after a ``.csv`` called."""
72
+
66
73
def __init__ (self , file_path ):
67
74
self .file_path = file_path
75
+
68
76
def __repr__ (self ):
69
- return 'CSV results at %s' % os .path .join (os .path .abspath ('.' ), self .file_path )
77
+ return 'CSV results at %s' % os .path .join (
78
+ os .path .abspath ('.' ), self .file_path )
79
+
70
80
def _repr_html_ (self ):
71
- return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' , self .file_path )
81
+ return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' ,
82
+ self .file_path )
72
83
73
84
74
85
def _nonbreaking_spaces (match_obj ):
@@ -81,6 +92,7 @@ def _nonbreaking_spaces(match_obj):
81
92
spaces = ' ' * len (match_obj .group (2 ))
82
93
return '%s%s' % (match_obj .group (1 ), spaces )
83
94
95
+
84
96
_cell_with_spaces_pattern = re .compile (r'(<td>)( {2,})' )
85
97
86
98
@@ -90,6 +102,7 @@ class ResultSet(list, ColumnGuesserMixin):
90
102
91
103
Can access rows listwise, or by string value of leftmost column.
92
104
"""
105
+
93
106
def __init__ (self , sqlaproxy , sql , config ):
94
107
self .keys = sqlaproxy .keys ()
95
108
self .sql = sql
@@ -115,7 +128,8 @@ def _repr_html_(self):
115
128
self .pretty .add_rows (self )
116
129
result = self .pretty .get_html_string ()
117
130
result = _cell_with_spaces_pattern .sub (_nonbreaking_spaces , result )
118
- if self .config .displaylimit and len (self ) > self .config .displaylimit :
131
+ if self .config .displaylimit and len (
132
+ self ) > self .config .displaylimit :
119
133
result = '%s\n <span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
120
134
result , len (self ), self .config .displaylimit )
121
135
return result
@@ -140,6 +154,7 @@ def __getitem__(self, key):
140
154
if len (result ) > 1 :
141
155
raise KeyError ('%d results for "%s"' % (len (result ), key ))
142
156
return result [0 ]
157
+
143
158
def dict (self ):
144
159
"""Returns a single dict built from the result set
145
160
@@ -214,7 +229,7 @@ def plot(self, title=None, **kwargs):
214
229
plt .ylabel (ylabel )
215
230
return plot
216
231
217
- def bar (self , key_word_sep = " " , title = None , ** kwargs ):
232
+ def bar (self , key_word_sep = " " , title = None , ** kwargs ):
218
233
"""Generates a pylab bar plot from the result set.
219
234
220
235
``matplotlib`` must be installed, and in an
@@ -238,8 +253,7 @@ def bar(self, key_word_sep = " ", title=None, **kwargs):
238
253
self .guess_pie_columns (xlabel_sep = key_word_sep )
239
254
plot = plt .bar (range (len (self .ys [0 ])), self .ys [0 ], ** kwargs )
240
255
if self .xlabels :
241
- plt .xticks (range (len (self .xlabels )), self .xlabels ,
242
- rotation = 45 )
256
+ plt .xticks (range (len (self .xlabels )), self .xlabels , rotation = 45 )
243
257
plt .xlabel (self .xlabel )
244
258
plt .ylabel (self .ys [0 ].name )
245
259
return plot
@@ -248,7 +262,7 @@ def csv(self, filename=None, **format_params):
248
262
"""Generate results in comma-separated form. Write to ``filename`` if given.
249
263
Any other parameters will be passed on to csv.writer."""
250
264
if not self .pretty :
251
- return None # no results
265
+ return None # no results
252
266
self .pretty .add_rows (self )
253
267
if filename :
254
268
encoding = format_params .get ('encoding' , 'utf-8' )
@@ -276,17 +290,37 @@ def interpret_rowcount(rowcount):
276
290
result = '%d rows affected.' % rowcount
277
291
return result
278
292
293
+
279
294
class FakeResultProxy (object ):
280
295
"""A fake class that pretends to behave like the ResultProxy from
281
296
SqlAlchemy.
282
297
"""
298
+
283
299
def __init__ (self , cursor , headers ):
284
300
self .fetchall = cursor .fetchall
285
301
self .fetchmany = cursor .fetchmany
286
302
self .rowcount = cursor .rowcount
287
303
self .keys = lambda : headers
288
304
self .returns_rows = True
289
305
306
+ # some dialects have autocommit
307
+ # specific dialects break when commit is used:
308
+ _COMMIT_BLACKLIST_DIALECTS = ('mssql' , 'clickhouse' )
309
+
310
+
311
+ def _commit (conn , config ):
312
+ """Issues a commit, if appropriate for current config and dialect"""
313
+
314
+ _should_commit = config .autocommit and all (
315
+ dialect not in str (conn .dialect )
316
+ for dialect in _COMMIT_BLACKLIST_DIALECTS )
317
+
318
+ if _should_commit :
319
+ try :
320
+ conn .session .execute ('commit' )
321
+ except sqlalchemy .exc .OperationalError :
322
+ pass # not all engines can commit
323
+
290
324
291
325
def run (conn , sql , config , user_namespace ):
292
326
if sql .strip ():
@@ -295,20 +329,16 @@ def run(conn, sql, config, user_namespace):
295
329
if first_word == 'begin' :
296
330
raise Exception ("ipython_sql does not support transactions" )
297
331
if first_word .startswith ('\\ ' ) and 'postgres' in str (conn .dialect ):
332
+ if not PGSpecial :
333
+ raise ImportError ('pgspecial not installed' )
298
334
pgspecial = PGSpecial ()
299
335
_ , cur , headers , _ = pgspecial .execute (
300
- conn .session .connection .cursor (),
301
- statement )[0 ]
336
+ conn .session .connection .cursor (), statement )[0 ]
302
337
result = FakeResultProxy (cur , headers )
303
338
else :
304
339
txt = sqlalchemy .sql .text (statement )
305
340
result = conn .session .execute (txt , user_namespace )
306
- try :
307
- # mssql has autocommit
308
- if config .autocommit and ('mssql' not in str (conn .dialect )):
309
- conn .session .execute ('commit' )
310
- except sqlalchemy .exc .OperationalError :
311
- pass # not all engines can commit
341
+ _commit (conn = conn , config = config )
312
342
if result and config .feedback :
313
343
print (interpret_rowcount (result .rowcount ))
314
344
resultset = ResultSet (result , statement , config )
@@ -322,11 +352,10 @@ def run(conn, sql, config, user_namespace):
322
352
323
353
324
354
class PrettyTable (prettytable .PrettyTable ):
325
-
326
355
def __init__ (self , * args , ** kwargs ):
327
356
self .row_count = 0
328
357
self .displaylimit = None
329
- return super (PrettyTable , self ).__init__ (* args , ** kwargs )
358
+ return super (PrettyTable , self ).__init__ (* args , ** kwargs )
330
359
331
360
def add_rows (self , data ):
332
361
if self .row_count and (data .config .displaylimit == self .displaylimit ):
0 commit comments