Skip to content

Commit 000e20b

Browse files
authored
Merge pull request #33 from febus982/add_repository_methods_for_many_models
Add repository methods for collections of models
2 parents d57c79a + bfc2d1a commit 000e20b

File tree

9 files changed

+346
-69
lines changed

9 files changed

+346
-69
lines changed

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Union,
1313
)
1414

15+
from sqlalchemy import select
1516
from sqlalchemy.ext.asyncio import AsyncSession
1617

1718
from .._bind_manager import SQLAlchemyAsyncBind
@@ -84,17 +85,22 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
8485
raise ModelNotFound("No rows found for provided primary key.")
8586
return model
8687

87-
async def delete(
88-
self,
89-
entity: Union[MODEL, PRIMARY_KEY],
90-
) -> None:
91-
# TODO: delete without loading the model
92-
if isinstance(entity, self._model):
93-
obj = entity
94-
else:
95-
obj = await self.get(entity) # type: ignore
88+
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
89+
stmt = select(self._model).where(
90+
getattr(self._model, self._model_pk()).in_(identifiers)
91+
)
92+
93+
async with self._get_session(commit=False) as session:
94+
return [x for x in (await session.execute(stmt)).scalars()]
95+
96+
async def delete(self, instance: MODEL) -> None:
97+
async with self._get_session() as session:
98+
await session.delete(instance)
99+
100+
async def delete_many(self, instances: Iterable[MODEL]) -> None:
96101
async with self._get_session() as session:
97-
await session.delete(obj)
102+
for instance in instances:
103+
await session.delete(instance)
98104

99105
async def find(
100106
self,

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Union,
1313
)
1414

15+
from sqlalchemy import select
1516
from sqlalchemy.orm import Session
1617

1718
from .._bind_manager import SQLAlchemyBind
@@ -73,14 +74,22 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
7374
raise ModelNotFound("No rows found for provided primary key.")
7475
return model
7576

76-
def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
77-
# TODO: delete without loading the model
78-
if isinstance(entity, self._model):
79-
obj = entity
80-
else:
81-
obj = self.get(entity) # type: ignore
77+
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
78+
stmt = select(self._model).where(
79+
getattr(self._model, self._model_pk()).in_(identifiers)
80+
)
81+
82+
with self._get_session(commit=False) as session:
83+
return [x for x in session.execute(stmt).scalars()]
84+
85+
def delete(self, instance: MODEL) -> None:
86+
with self._get_session() as session:
87+
session.delete(instance)
88+
89+
def delete_many(self, instances: Iterable[MODEL]) -> None:
8290
with self._get_session() as session:
83-
session.delete(obj)
91+
for model in instances:
92+
session.delete(model)
8493

8594
def find(
8695
self,

sqlalchemy_bind_manager/protocols.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,29 @@ async def get(self, identifier: PRIMARY_KEY) -> MODEL:
4747
:return: A model instance
4848
:raises ModelNotFound: No model has been found using the primary key
4949
"""
50-
# TODO: implement get_many()
5150
...
5251

53-
async def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
52+
async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
53+
"""Get a list of models by primary keys.
54+
55+
:param identifiers: A list of primary keys
56+
:type identifiers: List
57+
:return: A list of models
58+
:rtype: List
59+
"""
60+
...
61+
62+
async def delete(self, instance: MODEL) -> None:
5463
"""Deletes a model.
5564
56-
:param entity: The model instance or the primary key
57-
:type entity: Union[MODEL, PRIMARY_KEY]
65+
:param instance: The model instance
66+
"""
67+
...
68+
69+
async def delete_many(self, instances: Iterable[MODEL]) -> None:
70+
"""Deletes a collection of models in a single transaction.
71+
72+
:param instances: The model instances
5873
"""
5974
...
6075

@@ -188,14 +203,29 @@ def get(self, identifier: PRIMARY_KEY) -> MODEL:
188203
:return: A model instance
189204
:raises ModelNotFound: No model has been found using the primary key
190205
"""
191-
# TODO: implement get_many()
192206
...
193207

194-
def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
208+
def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
209+
"""Get a list of models by primary keys.
210+
211+
:param identifiers: A list of primary keys
212+
:type identifiers: List
213+
:return: A list of models
214+
:rtype: List
215+
"""
216+
...
217+
218+
def delete(self, instance: MODEL) -> None:
195219
"""Deletes a model.
196220
197-
:param entity: The model instance or the primary key
198-
:type entity: Union[MODEL, PRIMARY_KEY]
221+
:param instance: The model instance
222+
"""
223+
...
224+
225+
async def delete_many(self, instances: Iterable[MODEL]) -> None:
226+
"""Deletes a collection of models in a single transaction.
227+
228+
:param instances: The model instances
199229
"""
200230
...
201231

tests/repository/async_/test_delete.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,5 @@
11
import pytest
2-
3-
4-
async def test_can_delete_by_pk(repository_class, model_class, sa_manager):
5-
model = model_class(
6-
model_id=1,
7-
name="Someone",
8-
)
9-
model2 = model_class(
10-
model_id=2,
11-
name="SomeoneElse",
12-
)
13-
repo = repository_class(sa_manager.get_bind())
14-
await repo.save_many({model, model2})
15-
16-
results = [x for x in await repo.find()]
17-
assert len(results) == 2
18-
19-
await repo.delete(1)
20-
results = [x for x in await repo.find()]
21-
assert len(results) == 1
22-
assert results[0].model_id == 2
23-
assert results[0].name == "SomeoneElse"
2+
from sqlalchemy import select
243

254

265
async def test_can_delete_by_instance(repository_class, model_class, sa_manager):
@@ -55,3 +34,37 @@ async def test_delete_inexistent_raises_exception(
5534

5635
with pytest.raises(Exception):
5736
await repo.delete(4)
37+
38+
with pytest.raises(Exception):
39+
await repo.delete(
40+
model_class(
41+
model_id=823,
42+
name="Someone",
43+
)
44+
)
45+
46+
47+
async def test_relationships_are_respected(
48+
related_repository_class, related_model_classes, sa_manager
49+
):
50+
parent = related_model_classes[0](
51+
name="A Parent",
52+
)
53+
child = related_model_classes[1](name="A Child")
54+
child2 = related_model_classes[1](name="Another Child")
55+
parent.children.append(child)
56+
parent.children.append(child2)
57+
repo = related_repository_class(sa_manager.get_bind())
58+
await repo.save(parent)
59+
60+
retrieved_parent = await repo.get(parent.parent_model_id)
61+
assert len(retrieved_parent.children) == 2
62+
63+
await repo.delete(retrieved_parent)
64+
65+
async with repo._get_session() as session:
66+
result = [
67+
x
68+
for x in (await session.execute(select(related_model_classes[1]))).scalars()
69+
]
70+
assert len(result) == 0
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
from sqlalchemy import select
3+
4+
5+
async def test_can_delete_by_instance(repository_class, model_class, sa_manager):
6+
model = model_class(
7+
model_id=1,
8+
name="Someone",
9+
)
10+
model2 = model_class(
11+
model_id=2,
12+
name="SomeoneElse",
13+
)
14+
repo = repository_class(sa_manager.get_bind())
15+
await repo.save_many({model, model2})
16+
17+
results = [x for x in await repo.find()]
18+
assert len(results) == 2
19+
20+
await repo.delete_many([model])
21+
results = [x for x in await repo.find()]
22+
assert len(results) == 1
23+
assert results[0].model_id == 2
24+
assert results[0].name == "SomeoneElse"
25+
26+
27+
async def test_delete_inexistent_raises_exception(
28+
repository_class, model_class, sa_manager
29+
):
30+
repo = repository_class(sa_manager.get_bind())
31+
32+
results = [x for x in await repo.find()]
33+
assert len(results) == 0
34+
35+
with pytest.raises(Exception):
36+
await repo.delete_many([4])
37+
38+
with pytest.raises(Exception):
39+
await repo.delete_many(
40+
[
41+
model_class(
42+
model_id=823,
43+
name="Someone",
44+
)
45+
]
46+
)
47+
48+
49+
async def test_relationships_are_respected(
50+
related_repository_class, related_model_classes, sa_manager
51+
):
52+
parent = related_model_classes[0](
53+
name="A Parent",
54+
)
55+
child = related_model_classes[1](name="A Child")
56+
child2 = related_model_classes[1](name="Another Child")
57+
parent.children.append(child)
58+
parent.children.append(child2)
59+
repo = related_repository_class(sa_manager.get_bind())
60+
await repo.save(parent)
61+
62+
retrieved_parent = await repo.get(parent.parent_model_id)
63+
assert len(retrieved_parent.children) == 2
64+
65+
await repo.delete_many([retrieved_parent])
66+
67+
async with repo._get_session() as session:
68+
result = [
69+
x
70+
for x in (await session.execute(select(related_model_classes[1]))).scalars()
71+
]
72+
assert len(result) == 0

tests/repository/async_/test_get.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,39 @@ async def test_get_returns_model(repository_class, model_class, sa_manager):
1919
assert isinstance(result, model_class)
2020

2121

22+
async def test_get_many_returns_models(repository_class, model_class, sa_manager):
23+
model = model_class(
24+
model_id=1,
25+
name="Someone",
26+
)
27+
model2 = model_class(
28+
model_id=2,
29+
name="SomeoneElse",
30+
)
31+
model3 = model_class(
32+
model_id=3,
33+
name="StillSomeoneElse",
34+
)
35+
repo = repository_class(sa_manager.get_bind())
36+
await repo.save_many({model, model2, model3})
37+
38+
result = await repo.get_many([1, 2])
39+
assert isinstance(result, list)
40+
assert len(result) == 2
41+
assert result[0].model_id == 1
42+
assert result[1].model_id == 2
43+
44+
45+
async def test_get_many_returns_empty_list_if_nothing_found(
46+
repository_class, model_class, sa_manager
47+
):
48+
repo = repository_class(sa_manager.get_bind())
49+
50+
result = await repo.get_many([1, 2])
51+
assert isinstance(result, list)
52+
assert len(result) == 0
53+
54+
2255
async def test_get_raises_exception_if_not_found(repository_class, sa_manager):
2356
repo = repository_class(sa_manager.get_bind())
2457

tests/repository/sync/test_delete.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,5 @@
11
import pytest
2-
3-
4-
def test_can_delete_by_pk(repository_class, model_class, sa_manager):
5-
model = model_class(
6-
model_id=1,
7-
name="Someone",
8-
)
9-
model2 = model_class(
10-
model_id=2,
11-
name="SomeoneElse",
12-
)
13-
repo = repository_class(sa_manager.get_bind())
14-
repo.save_many({model, model2})
15-
16-
results = [x for x in repo.find()]
17-
assert len(results) == 2
18-
19-
repo.delete(1)
20-
results = [x for x in repo.find()]
21-
assert len(results) == 1
22-
assert results[0].model_id == 2
23-
assert results[0].name == "SomeoneElse"
2+
from sqlalchemy import select
243

254

265
def test_can_delete_by_instance(repository_class, model_class, sa_manager):
@@ -53,3 +32,36 @@ def test_delete_inexistent_raises_exception(repository_class, model_class, sa_ma
5332

5433
with pytest.raises(Exception):
5534
repo.delete(4)
35+
36+
with pytest.raises(Exception):
37+
repo.delete(
38+
model_class(
39+
model_id=823,
40+
name="Someone",
41+
)
42+
)
43+
44+
45+
def test_relationships_are_respected(
46+
related_repository_class, related_model_classes, sa_manager
47+
):
48+
parent = related_model_classes[0](
49+
name="A Parent",
50+
)
51+
child = related_model_classes[1](name="A Child")
52+
child2 = related_model_classes[1](name="Another Child")
53+
parent.children.append(child)
54+
parent.children.append(child2)
55+
repo = related_repository_class(sa_manager.get_bind())
56+
repo.save(parent)
57+
58+
retrieved_parent = repo.get(parent.parent_model_id)
59+
assert len(retrieved_parent.children) == 2
60+
61+
repo.delete(retrieved_parent)
62+
63+
with repo._get_session() as session:
64+
result = [
65+
x for x in session.execute(select(related_model_classes[1])).scalars()
66+
]
67+
assert len(result) == 0

0 commit comments

Comments
 (0)