Skip to content

Commit cf4710c

Browse files
authored
feat: add shortcuts for transaction (#355)
1 parent 35431ec commit cf4710c

File tree

6 files changed

+78
-13
lines changed

6 files changed

+78
-13
lines changed

docs/peewee_async/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ Databases
1919

2020
.. automethod:: peewee_async.databases.AioDatabase.allow_sync
2121

22+
.. automethod:: peewee_async.databases.AioDatabase.aio_begin
23+
24+
.. automethod:: peewee_async.databases.AioDatabase.aio_savepoint
25+
2226
.. automethod:: peewee_async.databases.AioDatabase.aio_atomic
2327

2428
.. automethod:: peewee_async.databases.AioDatabase.aio_transaction

docs/peewee_async/transaction.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ If you want to manage transactions manually you have to acquire a connection by
3535

3636
.. code-block:: python
3737
38-
from peewee_async import Transaction
3938
async with db.aio_connection() as connection:
40-
tr = Transaction(connection)
41-
await tr.begin() # BEGIN
39+
tr = await db.aio_begin() # BEGIN
4240
await TestModel.aio_create(text='FOO')
4341
try:
4442
await TestModel.aio_create(text='FOO')

peewee_async/databases.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,47 @@ async def aio_close(self) -> None:
101101

102102
await self.pool_backend.close()
103103

104+
async def _aio_begin(self, use_savepoint: bool = False) -> Transaction:
105+
_connection_context = connection_context.get()
106+
if _connection_context is None:
107+
raise peewee.OperationalError("This method can only be called within the aio_connection context manager")
108+
tr = Transaction(_connection_context.connection, is_savepoint=use_savepoint)
109+
await tr.begin()
110+
return tr
111+
112+
async def aio_begin(self) -> Transaction:
113+
"""
114+
Start a new database transaction.
115+
116+
This method executes the SQL `BEGIN` statement and returns a
117+
`Transaction` object representing the started transaction.
118+
119+
Notes:
120+
- This method must be called within an active :meth:`aio_connection)` context manager.
121+
- The returned :meth:`Transaction` object should be used to manage commit or rollback operations.
122+
123+
Returns:
124+
Transaction: An instance representing the active transaction.
125+
"""
126+
return await self._aio_begin()
127+
128+
async def aio_savepoint(self) -> Transaction:
129+
"""
130+
Start a new transaction savepoint.
131+
132+
This method executes the SQL `SAVEPOINT` statement and returns
133+
a `Transaction` object representing the created savepoint.
134+
135+
Notes:
136+
- This method must be called within an active :meth:`aio_connection` context manager.
137+
138+
- The returned :meth:`Transaction` object should be used to manage commit or rollback operations.
139+
140+
Returns:
141+
Transaction: An instance representing the active savepoint.
142+
"""
143+
return await self._aio_begin(use_savepoint=True)
144+
104145
def aio_atomic(self) -> AbstractAsyncContextManager[None]:
105146
"""Create an async context-manager which runs any queries in the wrapped block
106147
in a transaction (or save-point if blocks are nested).

peewee_async/testing.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,16 @@ def __init__(self, database: AioDatabase) -> None:
3838
async def _disable_transactions(self) -> AsyncIterator[None]:
3939
@asynccontextmanager
4040
async def patched__aio_atomic(use_savepoint: bool = False) -> AsyncIterator[None]:
41-
raise ValueError("Using transactions 'aio_atomic' and 'aio_transcation' is disabled.")
41+
raise ValueError("Using 'aio_atomic' and 'aio_transcation' is disabled.")
4242
yield
4343

44-
with mock.patch.object(self.database, "_aio_atomic", patched__aio_atomic):
44+
async def patched___aio_begin(use_savepoint: bool = False) -> AsyncIterator[None]:
45+
raise ValueError("Using 'aio_begin' and 'aio_savepoint' is disabled.")
46+
47+
with (
48+
mock.patch.object(self.database, "_aio_atomic", patched__aio_atomic),
49+
mock.patch.object(self.database, "_aio_begin", patched___aio_begin),
50+
):
4551
yield
4652

4753
@asynccontextmanager

tests/test_testing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ async def test_transcations_disabled(db: AioDatabase, transaction_method: str) -
1818
pass
1919

2020

21+
@dbs_all
22+
async def test_aio_begin_savepoint_disabled(db: AioDatabase) -> None:
23+
async with TransactionTestCase(db):
24+
async with db.aio_connection():
25+
with pytest.raises(ValueError):
26+
await db.aio_begin()
27+
with pytest.raises(ValueError):
28+
await db.aio_savepoint()
29+
30+
2131
@dbs_all
2232
async def test_integration(db: AioDatabase) -> None:
2333

tests/test_transaction.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,8 @@ async def t3() -> None:
110110

111111
@dbs_all
112112
async def test_transaction_manual_work(db: AioDatabase) -> None:
113-
async with db.aio_connection() as connection:
114-
tr = Transaction(connection)
115-
await tr.begin()
113+
async with db.aio_connection():
114+
tr = await db.aio_begin()
116115
await TestModel.aio_create(text="FOO")
117116
assert await TestModel.aio_get_or_none(text="FOO") is not None
118117
try:
@@ -126,6 +125,15 @@ async def test_transaction_manual_work(db: AioDatabase) -> None:
126125
assert db.pool_backend.has_acquired_connections() is False
127126

128127

128+
@dbs_all
129+
async def test_aio_begin_savepoint_error(db: AioDatabase) -> None:
130+
with pytest.raises(OperationalError):
131+
await db.aio_begin()
132+
133+
with pytest.raises(OperationalError):
134+
await db.aio_savepoint()
135+
136+
129137
@pytest.mark.parametrize(
130138
("method1", "method2"),
131139
[
@@ -180,14 +188,12 @@ async def test_savepoint_rollback(db: AioDatabase) -> None:
180188

181189
@dbs_all
182190
async def test_savepoint_manual_work(db: AioDatabase) -> None:
183-
async with db.aio_connection() as connection:
184-
tr = Transaction(connection)
185-
await tr.begin()
191+
async with db.aio_connection():
192+
tr = await db.aio_begin()
186193
await TestModel.aio_create(text="FOO")
187194
assert await TestModel.aio_get_or_none(text="FOO") is not None
188195

189-
savepoint = Transaction(connection, is_savepoint=True)
190-
await savepoint.begin()
196+
savepoint = await db.aio_savepoint()
191197
try:
192198
await TestModel.aio_create(text="FOO")
193199
except: # noqa: E722

0 commit comments

Comments
 (0)