Skip to content

Commit 476fe83

Browse files
committed
Allow to submit additional args and kwargs to repositories
1 parent 5e842be commit 476fe83

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

sqlalchemy_bind_manager/_unit_of_work/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,15 @@ def register_repository(
2929
name: str,
3030
repository_class: Type[REPOSITORY],
3131
model_class: Union[Type, None] = None,
32+
*args,
33+
**kwargs,
3234
):
35+
kwargs.pop("session", None)
3336
self._repositories[name] = repository_class(
34-
session=self._session_handler.scoped_session(), model_class=model_class
37+
*args,
38+
session=self._session_handler.scoped_session(),
39+
model_class=model_class,
40+
**kwargs,
3541
)
3642

3743
def repository(self, name: str) -> REPOSITORY:

tests/unit_of_work/test_lifecycle.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import MagicMock
2+
13
import pytest
24

35
from sqlalchemy_bind_manager.exceptions import RepositoryNotFound
@@ -31,3 +33,54 @@ async def test_raises_exception_if_repository_not_found(sa_bind, uow_class):
3133
uow = uow_class(bind=sa_bind)
3234
with pytest.raises(RepositoryNotFound):
3335
uow.repository("Not existing")
36+
37+
38+
@pytest.mark.parametrize(
39+
["submitted_args", "submitted_kwargs", "received_args", "received_kwargs"],
40+
[
41+
pytest.param(
42+
("1", "2"),
43+
dict(a="b"),
44+
("2",),
45+
dict(model_class="1", a="b"),
46+
id="first_arg_model_class_if_no_kwarg",
47+
),
48+
pytest.param(
49+
tuple([]),
50+
dict(a="b", model_class="c"),
51+
tuple([]),
52+
dict(model_class="c", a="b"),
53+
id="model_class_rearranged_if_in_kwargs",
54+
),
55+
pytest.param(
56+
tuple([]),
57+
dict(a="b"),
58+
tuple([]),
59+
dict(model_class=None, a="b"),
60+
id="model_class_default_to_none",
61+
),
62+
pytest.param(
63+
tuple([]),
64+
dict(a="b", session="c"),
65+
tuple([]),
66+
dict(model_class=None, a="b"),
67+
id="session_removed_from_kwargs",
68+
),
69+
],
70+
)
71+
async def test_additional_arguments_are_forwarded(
72+
sa_bind,
73+
uow_class,
74+
submitted_args: tuple,
75+
submitted_kwargs: dict,
76+
received_args: tuple,
77+
received_kwargs: dict,
78+
):
79+
repo = MagicMock()
80+
81+
uow = uow_class(bind=sa_bind)
82+
uow.register_repository("r", repo, *submitted_args, **submitted_kwargs)
83+
84+
repo.assert_called_once_with(
85+
*received_args, session=uow._session_handler.scoped_session(), **received_kwargs
86+
)

0 commit comments

Comments
 (0)