Skip to content

Commit bfc2d1a

Browse files
committed
Implement delete_many
1 parent 415a7aa commit bfc2d1a

File tree

7 files changed

+190
-5
lines changed

7 files changed

+190
-5
lines changed

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ async def delete(self, instance: MODEL) -> None:
9797
async with self._get_session() as session:
9898
await session.delete(instance)
9999

100+
async def delete_many(self, instances: Iterable[MODEL]) -> None:
101+
async with self._get_session() as session:
102+
for instance in instances:
103+
await session.delete(instance)
104+
100105
async def find(
101106
self,
102107
search_params: Union[None, Mapping[str, Any]] = None,

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ def delete(self, instance: MODEL) -> None:
8686
with self._get_session() as session:
8787
session.delete(instance)
8888

89+
def delete_many(self, instances: Iterable[MODEL]) -> None:
90+
with self._get_session() as session:
91+
for model in instances:
92+
session.delete(model)
93+
8994
def find(
9095
self,
9196
search_params: Union[None, Mapping[str, Any]] = None,

sqlalchemy_bind_manager/protocols.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@ async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
6262
async def delete(self, instance: MODEL) -> None:
6363
"""Deletes a model.
6464
65-
:param instance: The model instance or the primary key
65+
:param instance: The model instance
66+
"""
67+
...
68+
69+
async def delete_many(self, instances: Iterable[MODEL]) -> None:
70+
"""Deletes a collection of models in a single transaction.
71+
72+
:param instances: The model instances
6673
"""
6774
...
6875

@@ -211,7 +218,14 @@ def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
211218
def delete(self, instance: MODEL) -> None:
212219
"""Deletes a model.
213220
214-
:param instance: The model instance or the primary key
221+
:param instance: The model instance
222+
"""
223+
...
224+
225+
async def delete_many(self, instances: Iterable[MODEL]) -> None:
226+
"""Deletes a collection of models in a single transaction.
227+
228+
:param instances: The model instances
215229
"""
216230
...
217231

tests/repository/async_/test_delete.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ async def test_delete_inexistent_raises_exception(
3535
with pytest.raises(Exception):
3636
await repo.delete(4)
3737

38+
with pytest.raises(Exception):
39+
await repo.delete(
40+
model_class(
41+
model_id=823,
42+
name="Someone",
43+
)
44+
)
45+
3846

3947
async def test_relationships_are_respected(
4048
related_repository_class, related_model_classes, sa_manager
@@ -55,5 +63,8 @@ async def test_relationships_are_respected(
5563
await repo.delete(retrieved_parent)
5664

5765
async with repo._get_session() as session:
58-
result = [x for x in (await session.execute(select(related_model_classes[1]))).scalars()]
66+
result = [
67+
x
68+
for x in (await session.execute(select(related_model_classes[1]))).scalars()
69+
]
5970
assert len(result) == 0
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
from sqlalchemy import select
3+
4+
5+
async def test_can_delete_by_instance(repository_class, model_class, sa_manager):
6+
model = model_class(
7+
model_id=1,
8+
name="Someone",
9+
)
10+
model2 = model_class(
11+
model_id=2,
12+
name="SomeoneElse",
13+
)
14+
repo = repository_class(sa_manager.get_bind())
15+
await repo.save_many({model, model2})
16+
17+
results = [x for x in await repo.find()]
18+
assert len(results) == 2
19+
20+
await repo.delete_many([model])
21+
results = [x for x in await repo.find()]
22+
assert len(results) == 1
23+
assert results[0].model_id == 2
24+
assert results[0].name == "SomeoneElse"
25+
26+
27+
async def test_delete_inexistent_raises_exception(
28+
repository_class, model_class, sa_manager
29+
):
30+
repo = repository_class(sa_manager.get_bind())
31+
32+
results = [x for x in await repo.find()]
33+
assert len(results) == 0
34+
35+
with pytest.raises(Exception):
36+
await repo.delete_many([4])
37+
38+
with pytest.raises(Exception):
39+
await repo.delete_many(
40+
[
41+
model_class(
42+
model_id=823,
43+
name="Someone",
44+
)
45+
]
46+
)
47+
48+
49+
async def test_relationships_are_respected(
50+
related_repository_class, related_model_classes, sa_manager
51+
):
52+
parent = related_model_classes[0](
53+
name="A Parent",
54+
)
55+
child = related_model_classes[1](name="A Child")
56+
child2 = related_model_classes[1](name="Another Child")
57+
parent.children.append(child)
58+
parent.children.append(child2)
59+
repo = related_repository_class(sa_manager.get_bind())
60+
await repo.save(parent)
61+
62+
retrieved_parent = await repo.get(parent.parent_model_id)
63+
assert len(retrieved_parent.children) == 2
64+
65+
await repo.delete_many([retrieved_parent])
66+
67+
async with repo._get_session() as session:
68+
result = [
69+
x
70+
for x in (await session.execute(select(related_model_classes[1]))).scalars()
71+
]
72+
assert len(result) == 0

tests/repository/sync/test_delete.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
32
from sqlalchemy import select
43

54

@@ -34,6 +33,14 @@ def test_delete_inexistent_raises_exception(repository_class, model_class, sa_ma
3433
with pytest.raises(Exception):
3534
repo.delete(4)
3635

36+
with pytest.raises(Exception):
37+
repo.delete(
38+
model_class(
39+
model_id=823,
40+
name="Someone",
41+
)
42+
)
43+
3744

3845
def test_relationships_are_respected(
3946
related_repository_class, related_model_classes, sa_manager
@@ -54,5 +61,7 @@ def test_relationships_are_respected(
5461
repo.delete(retrieved_parent)
5562

5663
with repo._get_session() as session:
57-
result = [x for x in session.execute(select(related_model_classes[1])).scalars()]
64+
result = [
65+
x for x in session.execute(select(related_model_classes[1])).scalars()
66+
]
5867
assert len(result) == 0
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
from sqlalchemy import select
3+
4+
5+
def test_can_delete_by_instance(repository_class, model_class, sa_manager):
6+
model = model_class(
7+
model_id=1,
8+
name="Someone",
9+
)
10+
model2 = model_class(
11+
model_id=2,
12+
name="SomeoneElse",
13+
)
14+
repo = repository_class(sa_manager.get_bind())
15+
repo.save_many({model, model2})
16+
17+
results = [x for x in repo.find()]
18+
assert len(results) == 2
19+
20+
repo.delete_many([model])
21+
results = [x for x in repo.find()]
22+
assert len(results) == 1
23+
assert results[0].model_id == 2
24+
assert results[0].name == "SomeoneElse"
25+
26+
27+
def test_delete_inexistent_raises_exception(repository_class, model_class, sa_manager):
28+
repo = repository_class(sa_manager.get_bind())
29+
30+
results = [x for x in repo.find()]
31+
assert len(results) == 0
32+
33+
with pytest.raises(Exception):
34+
repo.delete_many([4])
35+
36+
with pytest.raises(Exception):
37+
repo.delete_many(
38+
[
39+
model_class(
40+
model_id=823,
41+
name="Someone",
42+
)
43+
]
44+
)
45+
46+
47+
def test_relationships_are_respected(
48+
related_repository_class, related_model_classes, sa_manager
49+
):
50+
parent = related_model_classes[0](
51+
name="A Parent",
52+
)
53+
child = related_model_classes[1](name="A Child")
54+
child2 = related_model_classes[1](name="Another Child")
55+
parent.children.append(child)
56+
parent.children.append(child2)
57+
repo = related_repository_class(sa_manager.get_bind())
58+
repo.save(parent)
59+
60+
retrieved_parent = repo.get(parent.parent_model_id)
61+
assert len(retrieved_parent.children) == 2
62+
63+
repo.delete_many([retrieved_parent])
64+
65+
with repo._get_session() as session:
66+
result = [
67+
x for x in session.execute(select(related_model_classes[1])).scalars()
68+
]
69+
assert len(result) == 0

0 commit comments

Comments
 (0)