|
9 | 9 | import sqlalchemy
|
10 | 10 | import sqlparse
|
11 | 11 | import prettytable
|
| 12 | +from pgspecial.main import PGSpecial |
12 | 13 | from .column_guesser import ColumnGuesserMixin
|
13 | 14 |
|
14 | 15 |
|
@@ -141,7 +142,7 @@ def __getitem__(self, key):
|
141 | 142 | return result[0]
|
142 | 143 | def dict(self):
|
143 | 144 | """Returns a single dict built from the result set
|
144 |
| - |
| 145 | +
|
145 | 146 | Keys are column names; values are a tuple"""
|
146 | 147 | return dict(zip(self.keys, zip(*self)))
|
147 | 148 |
|
@@ -274,19 +275,38 @@ def interpret_rowcount(rowcount):
|
274 | 275 | result = '%d rows affected.' % rowcount
|
275 | 276 | return result
|
276 | 277 |
|
| 278 | +class FakeResultProxy(object): |
| 279 | + """A fake class that pretends to behave like the ResultProxy from |
| 280 | + SqlAlchemy. |
| 281 | + """ |
| 282 | + def __init__(self, cursor, headers): |
| 283 | + self.fetchall = cursor.fetchall |
| 284 | + self.fetchmany = cursor.fetchmany |
| 285 | + self.rowcount = cursor.rowcount |
| 286 | + self.keys = lambda: headers |
| 287 | + self.returns_rows = True |
| 288 | + |
277 | 289 |
|
278 | 290 | def run(conn, sql, config, user_namespace):
|
279 | 291 | if sql.strip():
|
280 | 292 | for statement in sqlparse.split(sql):
|
281 |
| - if sql.strip().split()[0].lower() == 'begin': |
| 293 | + first_word = sql.strip().split()[0].lower() |
| 294 | + if first_word == 'begin': |
282 | 295 | raise Exception("ipython_sql does not support transactions")
|
283 |
| - txt = sqlalchemy.sql.text(statement) |
284 |
| - result = conn.session.execute(txt, user_namespace) |
| 296 | + if first_word.startswith('\\') and 'postgres' in str(conn.dialect): |
| 297 | + pgspecial = PGSpecial() |
| 298 | + _, cur, headers, _ = pgspecial.execute( |
| 299 | + conn.session.connection.cursor(), |
| 300 | + statement)[0] |
| 301 | + result = FakeResultProxy(cur, headers) |
| 302 | + else: |
| 303 | + txt = sqlalchemy.sql.text(statement) |
| 304 | + result = conn.session.execute(txt, user_namespace) |
285 | 305 | try:
|
286 | 306 | # mssql has autocommit
|
287 | 307 | if 'mssql' not in str(conn.dialect):
|
288 | 308 | conn.session.execute('commit')
|
289 |
| - except sqlalchemy.exc.OperationalError: |
| 309 | + except sqlalchemy.exc.OperationalError: |
290 | 310 | pass # not all engines can commit
|
291 | 311 | if result and config.feedback:
|
292 | 312 | print(interpret_rowcount(result.rowcount))
|
|
0 commit comments