Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/peewee_async/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ Databases

.. automethod:: peewee_async.databases.AioDatabase.allow_sync

.. automethod:: peewee_async.databases.AioDatabase.aio_begin

.. automethod:: peewee_async.databases.AioDatabase.aio_savepoint

.. automethod:: peewee_async.databases.AioDatabase.aio_atomic

.. automethod:: peewee_async.databases.AioDatabase.aio_transaction
Expand Down
4 changes: 1 addition & 3 deletions docs/peewee_async/transaction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ If you want to manage transactions manually you have to acquire a connection by

.. code-block:: python

from peewee_async import Transaction
async with db.aio_connection() as connection:
tr = Transaction(connection)
await tr.begin() # BEGIN
tr = await db.aio_begin() # BEGIN
await TestModel.aio_create(text='FOO')
try:
await TestModel.aio_create(text='FOO')
Expand Down
41 changes: 41 additions & 0 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,47 @@ async def aio_close(self) -> None:

await self.pool_backend.close()

async def _aio_begin(self, use_savepoint: bool = False) -> Transaction:
_connection_context = connection_context.get()
if _connection_context is None:
raise peewee.OperationalError("This method can only be called within the aio_connection context manager")
tr = Transaction(_connection_context.connection, is_savepoint=use_savepoint)
await tr.begin()
return tr

async def aio_begin(self) -> Transaction:
"""
Start a new database transaction.

This method executes the SQL `BEGIN` statement and returns a
`Transaction` object representing the started transaction.

Notes:
- This method must be called within an active :meth:`aio_connection)` context manager.
- The returned :meth:`Transaction` object should be used to manage commit or rollback operations.

Returns:
Transaction: An instance representing the active transaction.
"""
return await self._aio_begin()

async def aio_savepoint(self) -> Transaction:
"""
Start a new transaction savepoint.

This method executes the SQL `SAVEPOINT` statement and returns
a `Transaction` object representing the created savepoint.

Notes:
- This method must be called within an active :meth:`aio_connection` context manager.

- The returned :meth:`Transaction` object should be used to manage commit or rollback operations.

Returns:
Transaction: An instance representing the active savepoint.
"""
return await self._aio_begin(use_savepoint=True)

def aio_atomic(self) -> AbstractAsyncContextManager[None]:
"""Create an async context-manager which runs any queries in the wrapped block
in a transaction (or save-point if blocks are nested).
Expand Down
10 changes: 8 additions & 2 deletions peewee_async/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,16 @@ def __init__(self, database: AioDatabase) -> None:
async def _disable_transactions(self) -> AsyncIterator[None]:
@asynccontextmanager
async def patched__aio_atomic(use_savepoint: bool = False) -> AsyncIterator[None]:
raise ValueError("Using transactions 'aio_atomic' and 'aio_transcation' is disabled.")
raise ValueError("Using 'aio_atomic' and 'aio_transcation' is disabled.")
yield

with mock.patch.object(self.database, "_aio_atomic", patched__aio_atomic):
async def patched___aio_begin(use_savepoint: bool = False) -> AsyncIterator[None]:
raise ValueError("Using 'aio_begin' and 'aio_savepoint' is disabled.")

with (
mock.patch.object(self.database, "_aio_atomic", patched__aio_atomic),
mock.patch.object(self.database, "_aio_begin", patched___aio_begin),
):
yield

@asynccontextmanager
Expand Down
10 changes: 10 additions & 0 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ async def test_transcations_disabled(db: AioDatabase, transaction_method: str) -
pass


@dbs_all
async def test_aio_begin_savepoint_disabled(db: AioDatabase) -> None:
async with TransactionTestCase(db):
async with db.aio_connection():
with pytest.raises(ValueError):
await db.aio_begin()
with pytest.raises(ValueError):
await db.aio_savepoint()


@dbs_all
async def test_integration(db: AioDatabase) -> None:

Expand Down
22 changes: 14 additions & 8 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ async def t3() -> None:

@dbs_all
async def test_transaction_manual_work(db: AioDatabase) -> None:
async with db.aio_connection() as connection:
tr = Transaction(connection)
await tr.begin()
async with db.aio_connection():
tr = await db.aio_begin()
await TestModel.aio_create(text="FOO")
assert await TestModel.aio_get_or_none(text="FOO") is not None
try:
Expand All @@ -126,6 +125,15 @@ async def test_transaction_manual_work(db: AioDatabase) -> None:
assert db.pool_backend.has_acquired_connections() is False


@dbs_all
async def test_aio_begin_savepoint_error(db: AioDatabase) -> None:
with pytest.raises(OperationalError):
await db.aio_begin()

with pytest.raises(OperationalError):
await db.aio_savepoint()


@pytest.mark.parametrize(
("method1", "method2"),
[
Expand Down Expand Up @@ -180,14 +188,12 @@ async def test_savepoint_rollback(db: AioDatabase) -> None:

@dbs_all
async def test_savepoint_manual_work(db: AioDatabase) -> None:
async with db.aio_connection() as connection:
tr = Transaction(connection)
await tr.begin()
async with db.aio_connection():
tr = await db.aio_begin()
await TestModel.aio_create(text="FOO")
assert await TestModel.aio_get_or_none(text="FOO") is not None

savepoint = Transaction(connection, is_savepoint=True)
await savepoint.begin()
savepoint = await db.aio_savepoint()
try:
await TestModel.aio_create(text="FOO")
except: # noqa: E722
Expand Down
Loading