diff --git a/README.md b/README.md index 01f13bc..96da25a 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ class Address: number: int zip_code: int city: str - + class PersonInfo: def __init__(self, name: str, age: int, address: Address): self.name = name @@ -181,6 +181,42 @@ print("Target public_info.address is same as source address: ", address is publi * [TortoiseORM](https://github.com/tortoise/tortoise-orm) * [SQLAlchemy](https://www.sqlalchemy.org/) +## Complexer mapping registration + +Support for defining mappings using `lambda`s. + +```python +class AgeGroup(Enum): + CHILD = "child" + TEENAGER = "teenager" + ADULT = "adult" + SENIOR = "senior" + +class UserInfo: + def __init__(self, name: str, profession: str, age: int): + self.name = name + self.profession = profession + self.age = age + +class PublicUserInfo: + def __init__(self, name: str, profession: str, age_group: AgeGroup): + self.name = name + self.profession = profession + self.age_group + +mapper.add(UserInfo, PublicUserInfo, fields_mapping={) + "age_group": lambda user: ( + AgeGroup.CHILD if user.age < 13 else + AgeGroup.TEENAGER if user.age < 20 else + AgeGroup.ADULT if user.age < 65 else + AgeGroup.SENIOR + ) +}) + +mapper.map(UserInfo("John Malkovich", "engineer", 35)) +# {'name': 'John Malkovich', 'profession': 'engineer', 'age_group': } +``` + ## Pydantic/FastAPI Support Out of the box Pydantic models support: ```python @@ -273,7 +309,7 @@ class PublicUserInfo(Base): id = Column(Integer, primary_key=True) public_name = Column(String) hobbies = Column(String) - + obj = UserInfo( id=2, full_name="Danny DeVito", @@ -304,7 +340,7 @@ class TargetClass: def __init__(self, **kwargs): self.name = kwargs["name"] self.age = kwargs["age"] - + @staticmethod def get_fields(cls): return ["name", "age"] @@ -358,7 +394,7 @@ T = TypeVar("T") def class_has_fields_property(target_cls: Type[T]) -> bool: return callable(getattr(target_cls, "fields", None)) - + mapper.add_spec(class_has_fields_property, lambda t: getattr(t, "fields")()) target_obj = mapper.to(TargetClass).map(source_obj) diff --git a/automapper/mapper.py b/automapper/mapper.py index ea1be9a..7fc658d 100644 --- a/automapper/mapper.py +++ b/automapper/mapper.py @@ -28,13 +28,15 @@ T = TypeVar("T") ClassifierFunction = Callable[[Type[T]], bool] SpecFunction = Callable[[Type[T]], Iterable[str]] -FieldsMap = Optional[Dict[str, Any]] +FieldsMap = Optional[Dict[str, Union[Callable[[S], Any], Any]]] def _try_get_field_value( field_name: str, original_obj: Any, custom_mapping: FieldsMap ) -> Tuple[bool, Any]: - if field_name in (custom_mapping or {}): + if field_name in (custom_mapping or {}): # type: ignore [index] + if callable(custom_mapping[field_name]): # type: ignore [index] + return True, custom_mapping[field_name](original_obj) # type: ignore [index] return True, custom_mapping[field_name] # type: ignore [index] if hasattr(original_obj, field_name): return True, getattr(original_obj, field_name) @@ -184,7 +186,8 @@ def map( obj (object): Source object to map to `target class`. skip_none_values (bool, optional): Skip None values when creating `target class` obj. Defaults to False. fields_mapping (FieldsMap, optional): Custom mapping. - Specify dictionary in format {"field_name": value_object}. Defaults to None. + Specify dictionary in format {"field_name": value_object | lambda soure_obj}. Can take a lamdba + funtion as argument, that will get the source_cls as argument. Defaults to None. use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object. Defaults to True. diff --git a/tests/test_predefined_mapping.py b/tests/test_predefined_mapping.py index 411422e..53d7029 100644 --- a/tests/test_predefined_mapping.py +++ b/tests/test_predefined_mapping.py @@ -30,6 +30,9 @@ def __init__(self, text: Optional[str], num: int) -> None: self.text = text self.num = num +class ComplexClass: + def __init__(self, text: Optional[str], num: int) -> None: + self.data = AnotherClass(text, num) class ClassWithoutInitAttrDef: def __init__(self, **kwargs: Any) -> None: @@ -138,3 +141,15 @@ def test_map__pass_none_values_from_source_object(self): assert "num" in obj.data assert obj.data.get("text") is None assert obj.data.get("num") == 11 + + def test_add__lambda_resolver_works_with_lambda_function(self): + self.mapper.add(ComplexClass, AnotherClass, fields_mapping={ + "text": lambda x: x.data.text.upper(), + "num": lambda x: x.data.num * 2 + }) + result: AnotherClass = self.mapper.map(ComplexClass("test_message", 10)) + + assert isinstance(result, AnotherClass) + assert result.text == "TEST_MESSAGE" + assert result.num == 20 +