Skip to content

Commit 3668ccb

Browse files
authored
Merge pull request #45 from febus982/fail_operations_for_other_models
Fail operations on models not belonging to repository
2 parents 7bf4193 + e6e7e71 commit 3668ccb

File tree

4 files changed

+35
-0
lines changed

4 files changed

+35
-0
lines changed

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ async def _get_session(self, commit: bool = True) -> AsyncIterator[AsyncSession]
6666
yield self._external_session
6767

6868
async def save(self, instance: MODEL) -> MODEL:
69+
self._fail_if_invalid_models([instance])
6970
async with self._get_session() as session:
7071
session.add(instance)
7172
return instance
@@ -74,6 +75,7 @@ async def save_many(
7475
self,
7576
instances: Iterable[MODEL],
7677
) -> Iterable[MODEL]:
78+
self._fail_if_invalid_models(instances)
7779
async with self._get_session() as session:
7880
session.add_all(instances)
7981
return instances
@@ -94,10 +96,12 @@ async def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
9496
return [x for x in (await session.execute(stmt)).scalars()]
9597

9698
async def delete(self, instance: MODEL) -> None:
99+
self._fail_if_invalid_models([instance])
97100
async with self._get_session() as session:
98101
await session.delete(instance)
99102

100103
async def delete_many(self, instances: Iterable[MODEL]) -> None:
104+
self._fail_if_invalid_models(instances)
101105
async with self._get_session() as session:
102106
for instance in instances:
103107
await session.delete(instance)

sqlalchemy_bind_manager/_repository/base_repository.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,7 @@ def _model_pk(self) -> str:
318318
raise NotImplementedError("Composite primary keys are not supported.")
319319

320320
return primary_keys[0].name
321+
322+
def _fail_if_invalid_models(self, objects: Iterable[MODEL]) -> None:
323+
if [x for x in objects if not isinstance(x, self._model)]:
324+
raise InvalidModel("Cannot handle models not belonging to this repository")

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@ def _get_session(self, commit: bool = True) -> Iterator[Session]:
5858
yield self._external_session
5959

6060
def save(self, instance: MODEL) -> MODEL:
61+
self._fail_if_invalid_models([instance])
6162
with self._get_session() as session:
6263
session.add(instance)
6364
return instance
6465

6566
def save_many(self, instances: Iterable[MODEL]) -> Iterable[MODEL]:
67+
self._fail_if_invalid_models(instances)
6668
with self._get_session() as session:
6769
session.add_all(instances)
6870
return instances
@@ -83,10 +85,12 @@ def get_many(self, identifiers: Iterable[PRIMARY_KEY]) -> List[MODEL]:
8385
return [x for x in session.execute(stmt).scalars()]
8486

8587
def delete(self, instance: MODEL) -> None:
88+
self._fail_if_invalid_models([instance])
8689
with self._get_session() as session:
8790
session.delete(instance)
8891

8992
def delete_many(self, instances: Iterable[MODEL]) -> None:
93+
self._fail_if_invalid_models(instances)
9094
with self._get_session() as session:
9195
for model in instances:
9296
session.delete(model)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from sqlalchemy_bind_manager.exceptions import InvalidModel
4+
5+
6+
async def test_fails_when_saving_models_not_belonging_to_repository(
7+
repository_class, model_classes, sa_bind, sync_async_wrapper
8+
):
9+
repo = repository_class(bind=sa_bind, model_class=model_classes[0])
10+
11+
invalid_model = model_classes[1](name="A Child")
12+
13+
with pytest.raises(InvalidModel):
14+
await sync_async_wrapper(repo.save(invalid_model))
15+
16+
with pytest.raises(InvalidModel):
17+
await sync_async_wrapper(repo.save_many([invalid_model]))
18+
19+
with pytest.raises(InvalidModel):
20+
await sync_async_wrapper(repo.delete(invalid_model))
21+
22+
with pytest.raises(InvalidModel):
23+
await sync_async_wrapper(repo.delete_many([invalid_model]))

0 commit comments

Comments
 (0)