Skip to content

Commit cab3914

Browse files
authored
Merge pull request #23 from febus982/fix_model_validation
Use class to validate if the model is handled by SQLAlchemy
2 parents afc5181 + 628943d commit cab3914

File tree

3 files changed

+13
-19
lines changed

3 files changed

+13
-19
lines changed

sqlalchemy_bind_manager/_repository/async_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def delete(
101101
:type entity: Union[MODEL, PRIMARY_KEY]
102102
"""
103103
# TODO: delete without loading the model
104-
obj = entity if self._is_mapped_object(entity) else await self.get(entity) # type: ignore
104+
obj = entity if isinstance(entity, self._model) else await self.get(entity) # type: ignore
105105
async with self._get_session() as session:
106106
await session.delete(obj)
107107

sqlalchemy_bind_manager/_repository/common.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
from pydantic.generics import GenericModel
1919
from sqlalchemy import asc, desc, select, func
20-
from sqlalchemy.orm import object_mapper, class_mapper, Mapper, lazyload
21-
from sqlalchemy.orm.exc import UnmappedInstanceError
20+
from sqlalchemy.orm import class_mapper, Mapper, lazyload
21+
from sqlalchemy.orm.exc import UnmappedClassError
2222
from sqlalchemy.sql import Select
2323

2424
from sqlalchemy_bind_manager.exceptions import InvalidModel, UnmappedProperty
@@ -48,31 +48,25 @@ def __init__(self, model_class: Union[Type[MODEL], None] = None) -> None:
4848
if getattr(self, "_model", None) is None and model_class is not None:
4949
self._model = model_class
5050

51-
if getattr(self, "_model", None) is None or not self._is_mapped_object(
52-
self._model()
51+
if getattr(self, "_model", None) is None or not self._is_mapped_class(
52+
self._model
5353
):
5454
raise InvalidModel(
5555
"You need to supply a valid model class either in the `model_class` parameter"
5656
" or in the `_model` class property."
5757
)
5858

59-
def _is_mapped_object(self, obj: object) -> bool:
60-
"""Checks if the object is handled by the repository and is mapped in SQLAlchemy.
59+
def _is_mapped_class(self, class_: Type[MODEL]) -> bool:
60+
"""Checks if the class is mapped in SQLAlchemy.
6161
62-
:param obj: a mapped object instance
63-
:return: True if the object is mapped and matches self._model type, False if it's not a mapped object
62+
:param class_: the model class
63+
:return: True if the Type is mapped, False otherwise
6464
:rtype: bool
65-
:raises InvalidModel: when the object is mapped but doesn't match self._model type
6665
"""
67-
# TODO: This is probably redundant, we could do these checks once in __init__
6866
try:
69-
object_mapper(obj)
70-
if isinstance(obj, self._model):
71-
return True
72-
raise InvalidModel(
73-
f"This repository can handle only `{self._model}` models. `{type(obj)}` has been passed."
74-
)
75-
except UnmappedInstanceError:
67+
class_mapper(class_)
68+
return True
69+
except UnmappedClassError:
7670
return False
7771

7872
def _validate_mapped_property(self, property_name: str) -> None:

sqlalchemy_bind_manager/_repository/sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def delete(self, entity: Union[MODEL, PRIMARY_KEY]) -> None:
9696
:type entity: Union[MODEL, PRIMARY_KEY]
9797
"""
9898
# TODO: delete without loading the model
99-
obj = entity if self._is_mapped_object(entity) else self.get(entity) # type: ignore
99+
obj = entity if isinstance(entity, self._model) else self.get(entity) # type: ignore
100100
with self._get_session() as session:
101101
session.delete(obj)
102102

0 commit comments

Comments
 (0)