1
- import operator
2
- import csv
3
- import six
4
1
import codecs
2
+ import csv
3
+ import operator
5
4
import os .path
6
5
import re
6
+
7
+ import prettytable
8
+ import six
7
9
import sqlalchemy
8
10
import sqlparse
9
- import prettytable
11
+
12
+ from .column_guesser import ColumnGuesserMixin
13
+
10
14
try :
11
15
from pgspecial .main import PGSpecial
12
16
except ImportError :
13
17
PGSpecial = None
14
- from .column_guesser import ColumnGuesserMixin
15
18
16
19
17
20
def unduplicate_field_names (field_names ):
@@ -26,6 +29,7 @@ def unduplicate_field_names(field_names):
26
29
res .append (k )
27
30
return res
28
31
32
+
29
33
class UnicodeWriter (object ):
30
34
"""
31
35
A CSV writer which will write rows to CSV file "f",
@@ -41,19 +45,17 @@ def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds):
41
45
42
46
def writerow (self , row ):
43
47
if six .PY2 :
44
- _row = [s .encode ("utf-8" )
45
- if hasattr (s , "encode" )
46
- else s
48
+ _row = [s .encode ("utf-8" ) if hasattr (s , "encode" ) else s
47
49
for s in row ]
48
50
else :
49
51
_row = row
50
52
self .writer .writerow (_row )
51
53
# Fetch UTF-8 output from the queue ...
52
54
data = self .queue .getvalue ()
53
55
if six .PY2 :
54
- data = data .decode ("utf-8" )
55
- # ... and reencode it into the target encoding
56
- data = self .encoder .encode (data )
56
+ data = data .decode ("utf-8" )
57
+ # ... and reencode it into the target encoding
58
+ data = self .encoder .encode (data )
57
59
# write to the target stream
58
60
self .stream .write (data )
59
61
# empty queue
@@ -64,14 +66,20 @@ def writerows(self, rows):
64
66
for row in rows :
65
67
self .writerow (row )
66
68
69
+
67
70
class CsvResultDescriptor (object ):
68
71
"""Provides IPython Notebook-friendly output for the feedback after a ``.csv`` called."""
72
+
69
73
def __init__ (self , file_path ):
70
74
self .file_path = file_path
75
+
71
76
def __repr__ (self ):
72
- return 'CSV results at %s' % os .path .join (os .path .abspath ('.' ), self .file_path )
77
+ return 'CSV results at %s' % os .path .join (
78
+ os .path .abspath ('.' ), self .file_path )
79
+
73
80
def _repr_html_ (self ):
74
- return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' , self .file_path )
81
+ return '<a href="%s">CSV results</a>' % os .path .join ('.' , 'files' ,
82
+ self .file_path )
75
83
76
84
77
85
def _nonbreaking_spaces (match_obj ):
@@ -84,6 +92,7 @@ def _nonbreaking_spaces(match_obj):
84
92
spaces = ' ' * len (match_obj .group (2 ))
85
93
return '%s%s' % (match_obj .group (1 ), spaces )
86
94
95
+
87
96
_cell_with_spaces_pattern = re .compile (r'(<td>)( {2,})' )
88
97
89
98
@@ -93,6 +102,7 @@ class ResultSet(list, ColumnGuesserMixin):
93
102
94
103
Can access rows listwise, or by string value of leftmost column.
95
104
"""
105
+
96
106
def __init__ (self , sqlaproxy , sql , config ):
97
107
self .keys = sqlaproxy .keys ()
98
108
self .sql = sql
@@ -118,7 +128,8 @@ def _repr_html_(self):
118
128
self .pretty .add_rows (self )
119
129
result = self .pretty .get_html_string ()
120
130
result = _cell_with_spaces_pattern .sub (_nonbreaking_spaces , result )
121
- if self .config .displaylimit and len (self ) > self .config .displaylimit :
131
+ if self .config .displaylimit and len (
132
+ self ) > self .config .displaylimit :
122
133
result = '%s\n <span style="font-style:italic;text-align:center;">%d rows, truncated to displaylimit of %d</span>' % (
123
134
result , len (self ), self .config .displaylimit )
124
135
return result
@@ -143,6 +154,7 @@ def __getitem__(self, key):
143
154
if len (result ) > 1 :
144
155
raise KeyError ('%d results for "%s"' % (len (result ), key ))
145
156
return result [0 ]
157
+
146
158
def dict (self ):
147
159
"""Returns a single dict built from the result set
148
160
@@ -217,7 +229,7 @@ def plot(self, title=None, **kwargs):
217
229
plt .ylabel (ylabel )
218
230
return plot
219
231
220
- def bar (self , key_word_sep = " " , title = None , ** kwargs ):
232
+ def bar (self , key_word_sep = " " , title = None , ** kwargs ):
221
233
"""Generates a pylab bar plot from the result set.
222
234
223
235
``matplotlib`` must be installed, and in an
@@ -241,8 +253,7 @@ def bar(self, key_word_sep = " ", title=None, **kwargs):
241
253
self .guess_pie_columns (xlabel_sep = key_word_sep )
242
254
plot = plt .bar (range (len (self .ys [0 ])), self .ys [0 ], ** kwargs )
243
255
if self .xlabels :
244
- plt .xticks (range (len (self .xlabels )), self .xlabels ,
245
- rotation = 45 )
256
+ plt .xticks (range (len (self .xlabels )), self .xlabels , rotation = 45 )
246
257
plt .xlabel (self .xlabel )
247
258
plt .ylabel (self .ys [0 ].name )
248
259
return plot
@@ -251,7 +262,7 @@ def csv(self, filename=None, **format_params):
251
262
"""Generate results in comma-separated form. Write to ``filename`` if given.
252
263
Any other parameters will be passed on to csv.writer."""
253
264
if not self .pretty :
254
- return None # no results
265
+ return None # no results
255
266
self .pretty .add_rows (self )
256
267
if filename :
257
268
encoding = format_params .get ('encoding' , 'utf-8' )
@@ -279,17 +290,37 @@ def interpret_rowcount(rowcount):
279
290
result = '%d rows affected.' % rowcount
280
291
return result
281
292
293
+
282
294
class FakeResultProxy (object ):
283
295
"""A fake class that pretends to behave like the ResultProxy from
284
296
SqlAlchemy.
285
297
"""
298
+
286
299
def __init__ (self , cursor , headers ):
287
300
self .fetchall = cursor .fetchall
288
301
self .fetchmany = cursor .fetchmany
289
302
self .rowcount = cursor .rowcount
290
303
self .keys = lambda : headers
291
304
self .returns_rows = True
292
305
306
+ # some dialects have autocommit
307
+ # specific dialects break when commit is used:
308
+ _COMMIT_BLACKLIST_DIALECTS = ('mssql' , 'clickhouse' )
309
+
310
+
311
+ def _commit (conn , config ):
312
+ """Issues a commit, if appropriate for current config and dialect"""
313
+
314
+ _should_commit = config .autocommit and all (
315
+ dialect not in str (conn .dialect )
316
+ for dialect in _COMMIT_BLACKLIST_DIALECTS )
317
+
318
+ if _should_commit :
319
+ try :
320
+ conn .session .execute ('commit' )
321
+ except sqlalchemy .exc .OperationalError :
322
+ pass # not all engines can commit
323
+
293
324
294
325
def run (conn , sql , config , user_namespace ):
295
326
if sql .strip ():
@@ -302,18 +333,12 @@ def run(conn, sql, config, user_namespace):
302
333
raise ImportError ('pgspecial not installed' )
303
334
pgspecial = PGSpecial ()
304
335
_ , cur , headers , _ = pgspecial .execute (
305
- conn .session .connection .cursor (),
306
- statement )[0 ]
336
+ conn .session .connection .cursor (), statement )[0 ]
307
337
result = FakeResultProxy (cur , headers )
308
338
else :
309
339
txt = sqlalchemy .sql .text (statement )
310
340
result = conn .session .execute (txt , user_namespace )
311
- try :
312
- # mssql has autocommit
313
- if config .autocommit and ('mssql' not in str (conn .dialect )):
314
- conn .session .execute ('commit' )
315
- except sqlalchemy .exc .OperationalError :
316
- pass # not all engines can commit
341
+ _commit (conn = conn , config = config )
317
342
if result and config .feedback :
318
343
print (interpret_rowcount (result .rowcount ))
319
344
resultset = ResultSet (result , statement , config )
@@ -327,11 +352,10 @@ def run(conn, sql, config, user_namespace):
327
352
328
353
329
354
class PrettyTable (prettytable .PrettyTable ):
330
-
331
355
def __init__ (self , * args , ** kwargs ):
332
356
self .row_count = 0
333
357
self .displaylimit = None
334
- return super (PrettyTable , self ).__init__ (* args , ** kwargs )
358
+ return super (PrettyTable , self ).__init__ (* args , ** kwargs )
335
359
336
360
def add_rows (self , data ):
337
361
if self .row_count and (data .config .displaylimit == self .displaylimit ):
0 commit comments