|
11 | 11 | import os
|
12 | 12 | import platform
|
13 | 13 | import random
|
| 14 | +import sys |
| 15 | +import textwrap |
14 | 16 | import time
|
15 | 17 | import unittest
|
16 | 18 |
|
@@ -195,6 +197,7 @@ async def test_pool_11(self):
|
195 | 197 | self.assertIn(repr(con._con), repr(con)) # Test __repr__.
|
196 | 198 |
|
197 | 199 | ps = await con.prepare('SELECT 1')
|
| 200 | + txn = con.transaction() |
198 | 201 | async with con.transaction():
|
199 | 202 | cur = await con.cursor('SELECT 1')
|
200 | 203 | ps_cur = await ps.cursor()
|
@@ -233,6 +236,14 @@ async def test_pool_11(self):
|
233 | 236 |
|
234 | 237 | c.forward(1)
|
235 | 238 |
|
| 239 | + for meth in ('start', 'commit', 'rollback'): |
| 240 | + with self.assertRaisesRegex( |
| 241 | + asyncpg.InterfaceError, |
| 242 | + r'cannot call Transaction\.{meth}.*released ' |
| 243 | + r'back to the pool'.format(meth=meth)): |
| 244 | + |
| 245 | + getattr(txn, meth)() |
| 246 | + |
236 | 247 | await pool.close()
|
237 | 248 |
|
238 | 249 | async def test_pool_12(self):
|
@@ -661,6 +672,75 @@ async def test_pool_handles_inactive_connection_errors(self):
|
661 | 672 | await con.close()
|
662 | 673 | await pool.close()
|
663 | 674 |
|
| 675 | + @unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support') |
| 676 | + async def test_pool_handles_transaction_exit_in_asyncgen_1(self): |
| 677 | + pool = await self.create_pool(database='postgres', |
| 678 | + min_size=1, max_size=1) |
| 679 | + |
| 680 | + locals_ = {} |
| 681 | + exec(textwrap.dedent('''\ |
| 682 | + async def iterate(con): |
| 683 | + async with con.transaction(): |
| 684 | + for record in await con.fetch("SELECT 1"): |
| 685 | + yield record |
| 686 | + '''), globals(), locals_) |
| 687 | + iterate = locals_['iterate'] |
| 688 | + |
| 689 | + class MyException(Exception): |
| 690 | + pass |
| 691 | + |
| 692 | + with self.assertRaises(MyException): |
| 693 | + async with pool.acquire() as con: |
| 694 | + async for _ in iterate(con): # noqa |
| 695 | + raise MyException() |
| 696 | + |
| 697 | + @unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support') |
| 698 | + async def test_pool_handles_transaction_exit_in_asyncgen_2(self): |
| 699 | + pool = await self.create_pool(database='postgres', |
| 700 | + min_size=1, max_size=1) |
| 701 | + |
| 702 | + locals_ = {} |
| 703 | + exec(textwrap.dedent('''\ |
| 704 | + async def iterate(con): |
| 705 | + async with con.transaction(): |
| 706 | + for record in await con.fetch("SELECT 1"): |
| 707 | + yield record |
| 708 | + '''), globals(), locals_) |
| 709 | + iterate = locals_['iterate'] |
| 710 | + |
| 711 | + class MyException(Exception): |
| 712 | + pass |
| 713 | + |
| 714 | + with self.assertRaises(MyException): |
| 715 | + async with pool.acquire() as con: |
| 716 | + iterator = iterate(con) |
| 717 | + async for _ in iterator: # noqa |
| 718 | + raise MyException() |
| 719 | + |
| 720 | + del iterator |
| 721 | + |
| 722 | + @unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support') |
| 723 | + async def test_pool_handles_asyncgen_finalization(self): |
| 724 | + pool = await self.create_pool(database='postgres', |
| 725 | + min_size=1, max_size=1) |
| 726 | + |
| 727 | + locals_ = {} |
| 728 | + exec(textwrap.dedent('''\ |
| 729 | + async def iterate(con): |
| 730 | + for record in await con.fetch("SELECT 1"): |
| 731 | + yield record |
| 732 | + '''), globals(), locals_) |
| 733 | + iterate = locals_['iterate'] |
| 734 | + |
| 735 | + class MyException(Exception): |
| 736 | + pass |
| 737 | + |
| 738 | + with self.assertRaises(MyException): |
| 739 | + async with pool.acquire() as con: |
| 740 | + async with con.transaction(): |
| 741 | + async for _ in iterate(con): # noqa |
| 742 | + raise MyException() |
| 743 | + |
664 | 744 |
|
665 | 745 | @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
|
666 | 746 | class TestHotStandby(tb.ConnectedTestCase):
|
|
0 commit comments