Skip to content

Commit 24e177a

Browse files
authored
Merge pull request #38 from febus982/unit_of_work_rewrite
Unit of work rewrite
2 parents 4660c45 + 97443cd commit 24e177a

File tree

7 files changed

+148
-52
lines changed

7 files changed

+148
-52
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,13 @@ be used for repository operations, **assumed the same bind is used for all the r
182182
```python
183183
class MyRepo(SQLAlchemyRepository):
184184
_model = MyModel
185-
class MyOtherRepo(SQLAlchemyRepository):
186-
_model = MyOtherModel
187185

188186
bind = sa_manager.get_bind()
189-
uow = UnitOfWork(bind, (MyRepo, MyOtherRepo))
187+
uow = UnitOfWork(bind)
188+
uow.register_repository("repo_a", MyRepo)
189+
uow.register_repository("repo_b", SQLAlchemyRepository, MyOtherModel)
190190

191191
with uow.transaction():
192-
uow.MyRepo.save(some_model)
193-
uow.MyOtherRepo.save(some_other_model)
192+
uow.repository("repo_a").save(some_model)
193+
uow.repository("repo_b").save(some_other_model)
194194
```

docs/lifecycle.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,16 @@ What you can do is:
3434
* Save the repositories in global variables and start a thread / asyncio task to handle
3535
a scoped request (e.g. one thread per HTTP request)
3636

37-
What you cannot do is:
37+
What you should not do is:
3838

3939
* Get a list of models
4040
* Save the models using `save()` in parallel threads / tasks (each task will have a different session)
4141

42-
/// tip | The recommendation is of course to try to use a single repository instance, where possible.
42+
/// warning | Remember: Concurrent writes to the db can cause undesired scenarios like locks and deadlocks!
43+
44+
///
45+
46+
/// tip | The recommendation is to try to use a single repository instance, where possible.
4347

4448
For example a strategy similar to this would be optimal, if possible:
4549

docs/repository/uow.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@ be used for repository operations.
1010
```python
1111
class MyRepo(SQLAlchemyRepository):
1212
_model = MyModel
13-
class MyOtherRepo(SQLAlchemyRepository):
14-
_model = MyOtherModel
1513

1614
bind = sa_manager.get_bind()
17-
uow = UnitOfWork(bind, (MyRepo, MyOtherRepo))
15+
uow = UnitOfWork(bind)
16+
uow.register_repository("repo_a", MyRepo)
17+
# args and kwargs are forwarded so we can also use directly `SQLAlchemyRepository` class
18+
uow.register_repository("repo_b", SQLAlchemyRepository, MyOtherModel)
1819

1920
with uow.transaction():
20-
uow.MyRepo.save(some_model)
21-
uow.MyOtherRepo.save(some_other_model)
21+
uow.repository("repo_a").save(some_model)
22+
uow.repository("repo_b").save(some_other_model)
2223

2324
# Optionally disable the commit/rollback handling
2425
with uow.transaction(read_only=True):
25-
model1 = uow.MyRepo.get(1)
26-
model2 = uow.MyOtherRepo.get(2)
26+
model1 = uow.repository("repo_a").get(1)
27+
model2 = uow.repository("repo_b").get(2)
2728
```
2829

2930
/// admonition | The unit of work implementation is still experimental.

sqlalchemy_bind_manager/_unit_of_work/__init__.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from abc import ABC
12
from contextlib import asynccontextmanager, contextmanager
2-
from typing import AsyncIterator, Iterable, Iterator, Type
3+
from typing import AsyncIterator, Dict, Generic, Iterator, Type, TypeVar, Union
34

45
from sqlalchemy.ext.asyncio import AsyncSession
56
from sqlalchemy.orm import Session
@@ -9,41 +10,63 @@
910
AsyncSessionHandler,
1011
SessionHandler,
1112
)
13+
from sqlalchemy_bind_manager.exceptions import RepositoryNotFound
1214
from sqlalchemy_bind_manager.repository import (
1315
SQLAlchemyAsyncRepository,
1416
SQLAlchemyRepository,
1517
)
1618

19+
REPOSITORY = TypeVar("REPOSITORY", SQLAlchemyRepository, SQLAlchemyAsyncRepository)
20+
SESSION_HANDLER = TypeVar("SESSION_HANDLER", SessionHandler, AsyncSessionHandler)
1721

18-
class UnitOfWork:
19-
_session_handler: SessionHandler
2022

21-
def __init__(
22-
self, bind: SQLAlchemyBind, repositories: Iterable[Type[SQLAlchemyRepository]]
23-
) -> None:
23+
class BaseUnitOfWork(Generic[REPOSITORY, SESSION_HANDLER], ABC):
24+
_session_handler: SESSION_HANDLER
25+
_repositories: Dict[str, REPOSITORY] = {}
26+
27+
def register_repository(
28+
self,
29+
name: str,
30+
repository_class: Type[REPOSITORY],
31+
model_class: Union[Type, None] = None,
32+
*args,
33+
**kwargs,
34+
):
35+
kwargs.pop("session", None)
36+
self._repositories[name] = repository_class( # type: ignore
37+
*args,
38+
session=self._session_handler.scoped_session(),
39+
model_class=model_class,
40+
**kwargs,
41+
)
42+
43+
def repository(self, name: str) -> REPOSITORY:
44+
try:
45+
return self._repositories[name]
46+
except KeyError:
47+
raise RepositoryNotFound(
48+
"The repository has not been initialised in this unit of work"
49+
)
50+
51+
52+
class UnitOfWork(BaseUnitOfWork[SQLAlchemyRepository, SessionHandler]):
53+
def __init__(self, bind: SQLAlchemyBind) -> None:
2454
super().__init__()
2555
self._session_handler = SessionHandler(bind)
26-
for r in repositories:
27-
setattr(self, r.__name__, r(session=self._session_handler.scoped_session()))
2856

2957
@contextmanager
3058
def transaction(self, read_only: bool = False) -> Iterator[Session]:
3159
with self._session_handler.get_session(read_only=read_only) as _s:
3260
yield _s
3361

3462

35-
class AsyncUnitOfWork:
36-
_session_handler: AsyncSessionHandler
37-
63+
class AsyncUnitOfWork(BaseUnitOfWork[SQLAlchemyAsyncRepository, AsyncSessionHandler]):
3864
def __init__(
3965
self,
4066
bind: SQLAlchemyAsyncBind,
41-
repositories: Iterable[Type[SQLAlchemyAsyncRepository]],
4267
) -> None:
4368
super().__init__()
4469
self._session_handler = AsyncSessionHandler(bind)
45-
for r in repositories:
46-
setattr(self, r.__name__, r(session=self._session_handler.scoped_session()))
4770

4871
@asynccontextmanager
4972
async def transaction(self, read_only: bool = False) -> AsyncIterator[AsyncSession]:

sqlalchemy_bind_manager/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ class UnmappedProperty(Exception):
2424

2525
class SessionNotFound(Exception):
2626
pass
27+
28+
29+
class RepositoryNotFound(Exception):
30+
pass
Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from sqlalchemy_bind_manager.exceptions import RepositoryNotFound
6+
7+
18
async def test_repositories_are_initialised_with_uow_session(
29
sa_bind, repository_class, model_classes, uow_class
310
):
@@ -8,15 +15,72 @@ class ChildRepoClass(repository_class):
815
_model = model_classes[1]
916

1017
repository_classes = [RepoClass, ChildRepoClass]
11-
uow = uow_class(
12-
bind=sa_bind,
13-
repositories=repository_classes,
14-
)
18+
uow = uow_class(bind=sa_bind)
19+
uow.register_repository(RepoClass.__name__, RepoClass)
20+
uow.register_repository("ChildRepoClass", repository_class, model_classes[1])
21+
1522
for repo_class in repository_classes:
16-
assert hasattr(uow, repo_class.__name__)
17-
repo = getattr(uow, repo_class.__name__)
23+
repo = uow.repository(repo_class.__name__)
24+
assert repo is not None
1825
assert not hasattr(repo, "_session_handler")
1926
assert hasattr(repo, "_external_session")
2027
assert (
2128
getattr(repo, "_external_session") is uow._session_handler.scoped_session()
2229
)
30+
31+
32+
async def test_raises_exception_if_repository_not_found(sa_bind, uow_class):
33+
uow = uow_class(bind=sa_bind)
34+
with pytest.raises(RepositoryNotFound):
35+
uow.repository("Not existing")
36+
37+
38+
@pytest.mark.parametrize(
39+
["submitted_args", "submitted_kwargs", "received_args", "received_kwargs"],
40+
[
41+
pytest.param(
42+
("1", "2"),
43+
dict(a="b"),
44+
("2",),
45+
dict(model_class="1", a="b"),
46+
id="first_arg_model_class_if_no_kwarg",
47+
),
48+
pytest.param(
49+
tuple([]),
50+
dict(a="b", model_class="c"),
51+
tuple([]),
52+
dict(model_class="c", a="b"),
53+
id="model_class_rearranged_if_in_kwargs",
54+
),
55+
pytest.param(
56+
tuple([]),
57+
dict(a="b"),
58+
tuple([]),
59+
dict(model_class=None, a="b"),
60+
id="model_class_default_to_none",
61+
),
62+
pytest.param(
63+
tuple([]),
64+
dict(a="b", session="c"),
65+
tuple([]),
66+
dict(model_class=None, a="b"),
67+
id="session_removed_from_kwargs",
68+
),
69+
],
70+
)
71+
async def test_additional_arguments_are_forwarded(
72+
sa_bind,
73+
uow_class,
74+
submitted_args: tuple,
75+
submitted_kwargs: dict,
76+
received_args: tuple,
77+
received_kwargs: dict,
78+
):
79+
repo = MagicMock()
80+
81+
uow = uow_class(bind=sa_bind)
82+
uow.register_repository("r", repo, *submitted_args, **submitted_kwargs)
83+
84+
repo.assert_called_once_with(
85+
*received_args, session=uow._session_handler.scoped_session(), **received_kwargs
86+
)

tests/unit_of_work/test_session_behaviour.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ class RepoClass(repository_class):
1919
class ChildRepoClass(repository_class):
2020
_model = model_classes[1]
2121

22-
repository_classes = [RepoClass, ChildRepoClass]
23-
2422
with patch.object(
2523
session_handler_class, "commit", return_value=None
2624
) as mocked_uow_commit:
27-
uow = uow_class(sa_bind, repository_classes)
28-
repo1 = getattr(uow, repository_classes[0].__name__)
29-
repo2 = getattr(uow, repository_classes[1].__name__)
25+
uow = uow_class(sa_bind)
26+
uow.register_repository(RepoClass.__name__, RepoClass)
27+
uow.register_repository(ChildRepoClass.__name__, ChildRepoClass)
28+
repo1 = uow.repository(RepoClass.__name__)
29+
repo2 = uow.repository(ChildRepoClass.__name__)
3030

3131
# Populate a database entry to be used for tests
3232
model1 = model_classes[0](
@@ -56,10 +56,11 @@ class RepoClass(repository_class):
5656
class OtherRepoClass(repository_class):
5757
_model = model_classes[0]
5858

59-
repository_classes = [RepoClass, OtherRepoClass]
60-
uow = uow_class(sa_bind, repository_classes)
61-
repo1 = getattr(uow, repository_classes[0].__name__)
62-
repo2 = getattr(uow, repository_classes[1].__name__)
59+
uow = uow_class(sa_bind)
60+
uow.register_repository(RepoClass.__name__, RepoClass)
61+
uow.register_repository(OtherRepoClass.__name__, OtherRepoClass)
62+
repo1 = uow.repository(RepoClass.__name__)
63+
repo2 = uow.repository(OtherRepoClass.__name__)
6364

6465
# Populate a database entry to be used for tests
6566
model1 = model_classes[0](
@@ -87,10 +88,9 @@ async def test_uow_repository_operations_fail_without_transaction(
8788
class RepoClass(repository_class):
8889
_model = model_classes[0]
8990

90-
repository_classes = [RepoClass]
91-
92-
uow = uow_class(sa_bind, repository_classes)
93-
repo1 = getattr(uow, repository_classes[0].__name__)
91+
uow = uow_class(sa_bind)
92+
uow.register_repository(RepoClass.__name__, RepoClass)
93+
repo1 = uow.repository(RepoClass.__name__)
9494

9595
# Populate a database entry to be used for tests
9696
model1 = model_classes[0](
@@ -115,11 +115,11 @@ class RepoClass(repository_class):
115115
class OtherRepoClass(repository_class):
116116
_model = model_classes[0]
117117

118-
repository_classes = [RepoClass, OtherRepoClass]
119-
120-
uow = uow_class(sa_bind, repository_classes)
121-
repo1 = getattr(uow, repository_classes[0].__name__)
122-
repo2 = getattr(uow, repository_classes[1].__name__)
118+
uow = uow_class(sa_bind)
119+
uow.register_repository(RepoClass.__name__, RepoClass)
120+
uow.register_repository(OtherRepoClass.__name__, OtherRepoClass)
121+
repo1 = uow.repository(RepoClass.__name__)
122+
repo2 = uow.repository(OtherRepoClass.__name__)
123123

124124
# Populate a database entry to be used for tests
125125
model1 = model_classes[0](

0 commit comments

Comments
 (0)