|
17 | 17 |
|
18 | 18 | from pydantic.generics import GenericModel |
19 | 19 | from sqlalchemy import asc, desc, select, func |
20 | | -from sqlalchemy.orm import object_mapper, class_mapper, Mapper, lazyload |
21 | | -from sqlalchemy.orm.exc import UnmappedInstanceError |
| 20 | +from sqlalchemy.orm import class_mapper, Mapper, lazyload |
| 21 | +from sqlalchemy.orm.exc import UnmappedClassError |
22 | 22 | from sqlalchemy.sql import Select |
23 | 23 |
|
24 | 24 | from sqlalchemy_bind_manager.exceptions import InvalidModel, UnmappedProperty |
@@ -48,31 +48,25 @@ def __init__(self, model_class: Union[Type[MODEL], None] = None) -> None: |
48 | 48 | if getattr(self, "_model", None) is None and model_class is not None: |
49 | 49 | self._model = model_class |
50 | 50 |
|
51 | | - if getattr(self, "_model", None) is None or not self._is_mapped_object( |
52 | | - self._model() |
| 51 | + if getattr(self, "_model", None) is None or not self._is_mapped_class( |
| 52 | + self._model |
53 | 53 | ): |
54 | 54 | raise InvalidModel( |
55 | 55 | "You need to supply a valid model class either in the `model_class` parameter" |
56 | 56 | " or in the `_model` class property." |
57 | 57 | ) |
58 | 58 |
|
59 | | - def _is_mapped_object(self, obj: object) -> bool: |
60 | | - """Checks if the object is handled by the repository and is mapped in SQLAlchemy. |
| 59 | + def _is_mapped_class(self, class_: Type[MODEL]) -> bool: |
| 60 | + """Checks if the class is mapped in SQLAlchemy. |
61 | 61 |
|
62 | | - :param obj: a mapped object instance |
63 | | - :return: True if the object is mapped and matches self._model type, False if it's not a mapped object |
| 62 | + :param class_: the model class |
| 63 | + :return: True if the Type is mapped, False otherwise |
64 | 64 | :rtype: bool |
65 | | - :raises InvalidModel: when the object is mapped but doesn't match self._model type |
66 | 65 | """ |
67 | | - # TODO: This is probably redundant, we could do these checks once in __init__ |
68 | 66 | try: |
69 | | - object_mapper(obj) |
70 | | - if isinstance(obj, self._model): |
71 | | - return True |
72 | | - raise InvalidModel( |
73 | | - f"This repository can handle only `{self._model}` models. `{type(obj)}` has been passed." |
74 | | - ) |
75 | | - except UnmappedInstanceError: |
| 67 | + class_mapper(class_) |
| 68 | + return True |
| 69 | + except UnmappedClassError: |
76 | 70 | return False |
77 | 71 |
|
78 | 72 | def _validate_mapped_property(self, property_name: str) -> None: |
|
0 commit comments