Skip to content

Commit 1643610

Browse files
committed
Merge UOW lifcycle tests
1 parent f82ccdc commit 1643610

File tree

5 files changed

+63
-34
lines changed

5 files changed

+63
-34
lines changed

tests/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import os
3+
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
34
from typing import Tuple, Type
45
from uuid import uuid4
56

@@ -47,6 +48,26 @@ async def f(call):
4748
return f
4849

4950

51+
@pytest.fixture()
52+
def sync_async_cm_wrapper():
53+
"""
54+
Tiny wrapper to allow calling sync and async methods using await.
55+
56+
:return:
57+
"""
58+
59+
@asynccontextmanager
60+
async def f(cm):
61+
if isinstance(cm, _AsyncGeneratorContextManager):
62+
async with cm as c:
63+
yield c
64+
else:
65+
with cm as c:
66+
yield c
67+
68+
return f
69+
70+
5071
@pytest.fixture
5172
def sa_manager() -> SQLAlchemyBindManager:
5273
test_sync_db_path = f"./{uuid4()}.db"

tests/session_handler/test_session_lifecycle.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
21
from unittest.mock import AsyncMock, MagicMock, patch
32

43
import pytest
@@ -8,16 +7,6 @@
87
from sqlalchemy_bind_manager._bind_manager import SQLAlchemyAsyncBind
98

109

11-
@asynccontextmanager
12-
async def cm_wrapper(cm):
13-
if isinstance(cm, _AsyncGeneratorContextManager):
14-
async with cm as c:
15-
yield c
16-
else:
17-
with cm as c:
18-
yield c
19-
20-
2110
async def test_session_is_removed_on_cleanup(session_handler_class, sa_bind):
2211
uow = session_handler_class(sa_bind)
2312
original_session_remove = uow._session_class.remove
@@ -64,6 +53,7 @@ async def test_commit_is_called_only_if_not_read_only(
6453
session_handler_class,
6554
model_class,
6655
sa_bind,
56+
sync_async_cm_wrapper,
6757
):
6858
uow = session_handler_class(sa_bind)
6959

@@ -75,7 +65,9 @@ async def test_commit_is_called_only_if_not_read_only(
7565
with patch.object(
7666
session_handler_class, "commit", return_value=None
7767
) as mocked_uow_commit:
78-
async with cm_wrapper(uow.get_session(read_only=read_only_flag)) as _session:
68+
async with sync_async_cm_wrapper(
69+
uow.get_session(read_only=read_only_flag)
70+
) as _session:
7971
_session.add(model1)
8072

8173
assert mocked_uow_commit.call_count == int(not read_only_flag)

tests/unit_of_work/conftest.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Type, Union
2+
3+
import pytest
4+
5+
from sqlalchemy_bind_manager._bind_manager import SQLAlchemyAsyncBind, SQLAlchemyBind
6+
from sqlalchemy_bind_manager._repository import (
7+
SQLAlchemyAsyncRepository,
8+
SQLAlchemyRepository,
9+
)
10+
from sqlalchemy_bind_manager.repository import AsyncUnitOfWork, UnitOfWork
11+
12+
13+
@pytest.fixture
14+
def uow_class(sa_bind):
15+
return AsyncUnitOfWork if isinstance(sa_bind, SQLAlchemyAsyncBind) else UnitOfWork
16+
17+
18+
@pytest.fixture
19+
def repository_class(
20+
sa_bind: Union[SQLAlchemyBind, SQLAlchemyAsyncBind]
21+
) -> Type[Union[SQLAlchemyAsyncRepository, SQLAlchemyRepository]]:
22+
base_class = (
23+
SQLAlchemyRepository
24+
if isinstance(sa_bind, SQLAlchemyBind)
25+
else SQLAlchemyAsyncRepository
26+
)
27+
28+
return base_class

tests/unit_of_work/sync/test_lifecycle.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tests/unit_of_work/async_/test_lifecycle.py renamed to tests/unit_of_work/test_lifecycle.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
from sqlalchemy_bind_manager._unit_of_work import AsyncUnitOfWork
2-
3-
41
async def test_repositories_are_initialised_with_uow_session(
5-
sa_manager, repository_classes
2+
sa_bind, repository_class, model_classes, uow_class
63
):
7-
uow = AsyncUnitOfWork(
8-
bind=sa_manager.get_bind(),
4+
class RepoClass(repository_class):
5+
_model = model_classes[0]
6+
7+
class ChildRepoClass(repository_class):
8+
_model = model_classes[1]
9+
10+
repository_classes = [RepoClass, ChildRepoClass]
11+
uow = uow_class(
12+
bind=sa_bind,
913
repositories=repository_classes,
1014
)
1115
for repo_class in repository_classes:

0 commit comments

Comments
 (0)