Skip to content

Commit 5e842be

Browse files
committed
Rewrite Unit of Work interface methods
1 parent 4660c45 commit 5e842be

File tree

4 files changed

+70
-38
lines changed

4 files changed

+70
-38
lines changed

sqlalchemy_bind_manager/_unit_of_work/__init__.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from abc import ABC
12
from contextlib import asynccontextmanager, contextmanager
2-
from typing import AsyncIterator, Iterable, Iterator, Type
3+
from typing import AsyncIterator, Dict, Generic, Iterator, Type, TypeVar, Union
34

45
from sqlalchemy.ext.asyncio import AsyncSession
56
from sqlalchemy.orm import Session
@@ -9,41 +10,57 @@
910
AsyncSessionHandler,
1011
SessionHandler,
1112
)
13+
from sqlalchemy_bind_manager.exceptions import RepositoryNotFound
1214
from sqlalchemy_bind_manager.repository import (
1315
SQLAlchemyAsyncRepository,
1416
SQLAlchemyRepository,
1517
)
1618

19+
REPOSITORY = TypeVar("REPOSITORY", SQLAlchemyRepository, SQLAlchemyAsyncRepository)
20+
SESSION_HANDLER = TypeVar("SESSION_HANDLER", SessionHandler, AsyncSessionHandler)
1721

18-
class UnitOfWork:
19-
_session_handler: SessionHandler
2022

21-
def __init__(
22-
self, bind: SQLAlchemyBind, repositories: Iterable[Type[SQLAlchemyRepository]]
23-
) -> None:
23+
class BaseUnitOfWork(Generic[REPOSITORY, SESSION_HANDLER], ABC):
24+
_session_handler: SESSION_HANDLER
25+
_repositories: Dict[str, REPOSITORY] = {}
26+
27+
def register_repository(
28+
self,
29+
name: str,
30+
repository_class: Type[REPOSITORY],
31+
model_class: Union[Type, None] = None,
32+
):
33+
self._repositories[name] = repository_class(
34+
session=self._session_handler.scoped_session(), model_class=model_class
35+
)
36+
37+
def repository(self, name: str) -> REPOSITORY:
38+
try:
39+
return self._repositories[name]
40+
except KeyError:
41+
raise RepositoryNotFound(
42+
"The repository has not been initialised in this unit of work"
43+
)
44+
45+
46+
class UnitOfWork(BaseUnitOfWork[SQLAlchemyRepository, SessionHandler]):
47+
def __init__(self, bind: SQLAlchemyBind) -> None:
2448
super().__init__()
2549
self._session_handler = SessionHandler(bind)
26-
for r in repositories:
27-
setattr(self, r.__name__, r(session=self._session_handler.scoped_session()))
2850

2951
@contextmanager
3052
def transaction(self, read_only: bool = False) -> Iterator[Session]:
3153
with self._session_handler.get_session(read_only=read_only) as _s:
3254
yield _s
3355

3456

35-
class AsyncUnitOfWork:
36-
_session_handler: AsyncSessionHandler
37-
57+
class AsyncUnitOfWork(BaseUnitOfWork[SQLAlchemyAsyncRepository, AsyncSessionHandler]):
3858
def __init__(
3959
self,
4060
bind: SQLAlchemyAsyncBind,
41-
repositories: Iterable[Type[SQLAlchemyAsyncRepository]],
4261
) -> None:
4362
super().__init__()
4463
self._session_handler = AsyncSessionHandler(bind)
45-
for r in repositories:
46-
setattr(self, r.__name__, r(session=self._session_handler.scoped_session()))
4764

4865
@asynccontextmanager
4966
async def transaction(self, read_only: bool = False) -> AsyncIterator[AsyncSession]:

sqlalchemy_bind_manager/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ class UnmappedProperty(Exception):
2424

2525
class SessionNotFound(Exception):
2626
pass
27+
28+
29+
class RepositoryNotFound(Exception):
30+
pass
Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import pytest
2+
3+
from sqlalchemy_bind_manager.exceptions import RepositoryNotFound
4+
5+
16
async def test_repositories_are_initialised_with_uow_session(
27
sa_bind, repository_class, model_classes, uow_class
38
):
@@ -8,15 +13,21 @@ class ChildRepoClass(repository_class):
813
_model = model_classes[1]
914

1015
repository_classes = [RepoClass, ChildRepoClass]
11-
uow = uow_class(
12-
bind=sa_bind,
13-
repositories=repository_classes,
14-
)
16+
uow = uow_class(bind=sa_bind)
17+
uow.register_repository(RepoClass.__name__, RepoClass)
18+
uow.register_repository("ChildRepoClass", repository_class, model_classes[1])
19+
1520
for repo_class in repository_classes:
16-
assert hasattr(uow, repo_class.__name__)
17-
repo = getattr(uow, repo_class.__name__)
21+
repo = uow.repository(repo_class.__name__)
22+
assert repo is not None
1823
assert not hasattr(repo, "_session_handler")
1924
assert hasattr(repo, "_external_session")
2025
assert (
2126
getattr(repo, "_external_session") is uow._session_handler.scoped_session()
2227
)
28+
29+
30+
async def test_raises_exception_if_repository_not_found(sa_bind, uow_class):
31+
uow = uow_class(bind=sa_bind)
32+
with pytest.raises(RepositoryNotFound):
33+
uow.repository("Not existing")

tests/unit_of_work/test_session_behaviour.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ class RepoClass(repository_class):
1919
class ChildRepoClass(repository_class):
2020
_model = model_classes[1]
2121

22-
repository_classes = [RepoClass, ChildRepoClass]
23-
2422
with patch.object(
2523
session_handler_class, "commit", return_value=None
2624
) as mocked_uow_commit:
27-
uow = uow_class(sa_bind, repository_classes)
28-
repo1 = getattr(uow, repository_classes[0].__name__)
29-
repo2 = getattr(uow, repository_classes[1].__name__)
25+
uow = uow_class(sa_bind)
26+
uow.register_repository(RepoClass.__name__, RepoClass)
27+
uow.register_repository(ChildRepoClass.__name__, ChildRepoClass)
28+
repo1 = uow.repository(RepoClass.__name__)
29+
repo2 = uow.repository(ChildRepoClass.__name__)
3030

3131
# Populate a database entry to be used for tests
3232
model1 = model_classes[0](
@@ -56,10 +56,11 @@ class RepoClass(repository_class):
5656
class OtherRepoClass(repository_class):
5757
_model = model_classes[0]
5858

59-
repository_classes = [RepoClass, OtherRepoClass]
60-
uow = uow_class(sa_bind, repository_classes)
61-
repo1 = getattr(uow, repository_classes[0].__name__)
62-
repo2 = getattr(uow, repository_classes[1].__name__)
59+
uow = uow_class(sa_bind)
60+
uow.register_repository(RepoClass.__name__, RepoClass)
61+
uow.register_repository(OtherRepoClass.__name__, OtherRepoClass)
62+
repo1 = uow.repository(RepoClass.__name__)
63+
repo2 = uow.repository(OtherRepoClass.__name__)
6364

6465
# Populate a database entry to be used for tests
6566
model1 = model_classes[0](
@@ -87,10 +88,9 @@ async def test_uow_repository_operations_fail_without_transaction(
8788
class RepoClass(repository_class):
8889
_model = model_classes[0]
8990

90-
repository_classes = [RepoClass]
91-
92-
uow = uow_class(sa_bind, repository_classes)
93-
repo1 = getattr(uow, repository_classes[0].__name__)
91+
uow = uow_class(sa_bind)
92+
uow.register_repository(RepoClass.__name__, RepoClass)
93+
repo1 = uow.repository(RepoClass.__name__)
9494

9595
# Populate a database entry to be used for tests
9696
model1 = model_classes[0](
@@ -115,11 +115,11 @@ class RepoClass(repository_class):
115115
class OtherRepoClass(repository_class):
116116
_model = model_classes[0]
117117

118-
repository_classes = [RepoClass, OtherRepoClass]
119-
120-
uow = uow_class(sa_bind, repository_classes)
121-
repo1 = getattr(uow, repository_classes[0].__name__)
122-
repo2 = getattr(uow, repository_classes[1].__name__)
118+
uow = uow_class(sa_bind)
119+
uow.register_repository(RepoClass.__name__, RepoClass)
120+
uow.register_repository(OtherRepoClass.__name__, OtherRepoClass)
121+
repo1 = uow.repository(RepoClass.__name__)
122+
repo2 = uow.repository(OtherRepoClass.__name__)
123123

124124
# Populate a database entry to be used for tests
125125
model1 = model_classes[0](

0 commit comments

Comments
 (0)