Skip to content

Commit 2e6e0f5

Browse files
brianmaissyPliner
andauthored
Don't run ROLLBACK when the connection is closed. (#778)
* Don't run ROLLBACK when the connection is closed. This can be caused when a query times out while running, for example, and the connection is closed as a result (as opposed to cancelling the query, since PR #570). In this case, we would rather not emit the ROLLBACK (the connection is already closed, so the transaction is over anyway), rather than raising an exception when trying to use a connection which is already closed. See issue #777. * add tests for cancelled queries in transaction * Update test_transaction.py * Update test_sa_transaction.py * Update connection.py Co-authored-by: Yury Pliner <[email protected]>
1 parent 59671d3 commit 2e6e0f5

File tree

4 files changed

+138
-3
lines changed

4 files changed

+138
-3
lines changed

aiopg/sa/connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ async def _commit_impl(self):
211211
self._transaction = None
212212

213213
async def _rollback_impl(self):
214+
if self._connection.closed:
215+
self._transaction = None
216+
return
217+
214218
cur = await self._get_cursor()
215219
try:
216220
await cur.execute('ROLLBACK')
@@ -253,6 +257,10 @@ async def _savepoint_impl(self, name=None):
253257
cur.close()
254258

255259
async def _rollback_to_savepoint_impl(self, name, parent):
260+
if self._connection.closed:
261+
self._transaction = parent
262+
return
263+
256264
cur = await self._get_cursor()
257265
try:
258266
await cur.execute(f'ROLLBACK TO SAVEPOINT {name}')

aiopg/transaction.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,15 @@ async def commit(self):
127127

128128
async def rollback(self):
129129
self._check_commit_rollback()
130-
await self._cur.execute(self._isolation.rollback())
130+
if not self._cur.closed:
131+
await self._cur.execute(self._isolation.rollback())
131132
self._is_begin = False
132133

133134
async def rollback_savepoint(self):
134135
self._check_release_rollback()
135-
await self._cur.execute(
136-
self._isolation.rollback_savepoint(self._unique_id))
136+
if not self._cur.closed:
137+
await self._cur.execute(
138+
self._isolation.rollback_savepoint(self._unique_id))
137139
self._unique_id = None
138140

139141
async def release_savepoint(self):

tests/test_sa_transaction.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from unittest import mock
23

34
import pytest
@@ -411,3 +412,67 @@ async def test_transaction_mode(connect):
411412
res1 = await conn.scalar(select([func.count()]).select_from(tbl))
412413
assert 5 == res1
413414
await tr8.commit()
415+
416+
417+
async def test_timeout_in_transaction_context_manager(make_engine):
418+
engine = await make_engine(timeout=1)
419+
with pytest.raises(asyncio.TimeoutError):
420+
async with engine.acquire() as connection:
421+
async with connection.begin():
422+
await connection.execute("SELECT pg_sleep(10)")
423+
424+
engine.terminate()
425+
await engine.wait_closed()
426+
427+
428+
async def test_timeout_in_nested_transaction_context_manager(make_engine):
429+
engine = await make_engine(timeout=1)
430+
with pytest.raises(asyncio.TimeoutError):
431+
async with engine.acquire() as connection:
432+
async with connection.begin():
433+
async with connection.begin_nested():
434+
await connection.execute("SELECT pg_sleep(10)")
435+
436+
engine.terminate()
437+
await engine.wait_closed()
438+
439+
440+
async def test_cancel_in_transaction_context_manager(make_engine, loop):
441+
engine = await make_engine()
442+
443+
with pytest.raises(asyncio.CancelledError):
444+
async with engine.acquire() as connection:
445+
async with connection.begin():
446+
task = loop.create_task(
447+
connection.execute("SELECT pg_sleep(10)"))
448+
449+
async def cancel_soon():
450+
await asyncio.sleep(1)
451+
task.cancel()
452+
453+
loop.create_task(cancel_soon())
454+
await task
455+
456+
engine.terminate()
457+
await engine.wait_closed()
458+
459+
460+
async def test_cancel_in_savepoint_context_manager(make_engine, loop):
461+
engine = await make_engine()
462+
463+
with pytest.raises(asyncio.CancelledError):
464+
async with engine.acquire() as connection:
465+
async with connection.begin():
466+
async with connection.begin_nested():
467+
task = loop.create_task(
468+
connection.execute("SELECT pg_sleep(10)"))
469+
470+
async def cancel_soon():
471+
await asyncio.sleep(1)
472+
task.cancel()
473+
474+
loop.create_task(cancel_soon())
475+
await task
476+
477+
engine.terminate()
478+
await engine.wait_closed()

tests/test_transaction.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import psycopg2
24
import pytest
35

@@ -181,3 +183,61 @@ async def test_transaction_point_oldstyle(engine):
181183
(3, 'data')]
182184

183185
await tr.commit()
186+
187+
188+
async def test_timeout_in_transaction_context_manager(make_engine):
189+
engine = await make_engine(timeout=1)
190+
with pytest.raises(asyncio.TimeoutError):
191+
async with engine.acquire() as connection:
192+
async with Transaction(connection, IsolationLevel.read_committed):
193+
await connection.execute("SELECT pg_sleep(10)")
194+
195+
engine.terminate()
196+
await engine.wait_closed()
197+
198+
199+
async def test_timeout_in_savepoint_context_manager(make_engine):
200+
engine = await make_engine(timeout=1)
201+
with pytest.raises(asyncio.TimeoutError):
202+
async with engine.acquire() as connection:
203+
async with Transaction(
204+
connection, IsolationLevel.read_committed
205+
) as transaction:
206+
async with transaction.point():
207+
await connection.execute("SELECT pg_sleep(10)")
208+
209+
engine.terminate()
210+
await engine.wait_closed()
211+
212+
213+
async def test_cancel_in_transaction_context_manager(engine, loop):
214+
with pytest.raises(asyncio.CancelledError):
215+
async with engine.acquire() as connection:
216+
async with Transaction(connection, IsolationLevel.read_committed):
217+
task = loop.create_task(
218+
connection.execute("SELECT pg_sleep(10)"))
219+
220+
async def cancel_soon():
221+
await asyncio.sleep(1)
222+
task.cancel()
223+
224+
loop.create_task(cancel_soon())
225+
await task
226+
227+
228+
async def test_cancel_in_savepoint_context_manager(engine, loop):
229+
with pytest.raises(asyncio.CancelledError):
230+
async with engine.acquire() as connection:
231+
async with Transaction(
232+
connection, IsolationLevel.read_committed
233+
) as transaction:
234+
async with transaction.point():
235+
task = loop.create_task(
236+
connection.execute("SELECT pg_sleep(10)"))
237+
238+
async def cancel_soon():
239+
await asyncio.sleep(1)
240+
task.cancel()
241+
242+
loop.create_task(cancel_soon())
243+
await task

0 commit comments

Comments
 (0)