Skip to content

Commit c7f1e0c

Browse files
committed
feat: Allow lambda functions in mappings
1 parent 0e91f49 commit c7f1e0c

File tree

3 files changed

+61
-7
lines changed

3 files changed

+61
-7
lines changed

README.md

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class Address:
149149
number: int
150150
zip_code: int
151151
city: str
152-
152+
153153
class PersonInfo:
154154
def __init__(self, name: str, age: int, address: Address):
155155
self.name = name
@@ -181,6 +181,42 @@ print("Target public_info.address is same as source address: ", address is publi
181181
* [TortoiseORM](https://github.com/tortoise/tortoise-orm)
182182
* [SQLAlchemy](https://www.sqlalchemy.org/)
183183

184+
## Complexer mapping registration
185+
186+
Support for defining mappings using `lambda`s.
187+
188+
```python
189+
class AgeGroup(Enum):
190+
CHILD = "child"
191+
TEENAGER = "teenager"
192+
ADULT = "adult"
193+
SENIOR = "senior"
194+
195+
class UserInfo:
196+
def __init__(self, name: str, profession: str, age: int):
197+
self.name = name
198+
self.profession = profession
199+
self.age = age
200+
201+
class PublicUserInfo:
202+
def __init__(self, name: str, profession: str, age_group: AgeGroup):
203+
self.name = name
204+
self.profession = profession
205+
self.age_group
206+
207+
mapper.add(UserInfo, PublicUserInfo, fields_mapping={)
208+
"age_group": lambda user: (
209+
AgeGroup.CHILD if user.age < 13 else
210+
AgeGroup.TEENAGER if user.age < 20 else
211+
AgeGroup.ADULT if user.age < 65 else
212+
AgeGroup.SENIOR
213+
)
214+
})
215+
216+
mapper.map(UserInfo("John Malkovich", "engineer", 35))
217+
# {'name': 'John Malkovich', 'profession': 'engineer', 'age_group': <AgeGroup.ADULT: 'adult'>}
218+
```
219+
184220
## Pydantic/FastAPI Support
185221
Out of the box Pydantic models support:
186222
```python
@@ -273,7 +309,7 @@ class PublicUserInfo(Base):
273309
id = Column(Integer, primary_key=True)
274310
public_name = Column(String)
275311
hobbies = Column(String)
276-
312+
277313
obj = UserInfo(
278314
id=2,
279315
full_name="Danny DeVito",
@@ -304,7 +340,7 @@ class TargetClass:
304340
def __init__(self, **kwargs):
305341
self.name = kwargs["name"]
306342
self.age = kwargs["age"]
307-
343+
308344
@staticmethod
309345
def get_fields(cls):
310346
return ["name", "age"]
@@ -358,7 +394,7 @@ T = TypeVar("T")
358394

359395
def class_has_fields_property(target_cls: Type[T]) -> bool:
360396
return callable(getattr(target_cls, "fields", None))
361-
397+
362398
mapper.add_spec(class_has_fields_property, lambda t: getattr(t, "fields")())
363399

364400
target_obj = mapper.to(TargetClass).map(source_obj)

automapper/mapper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@
2828
T = TypeVar("T")
2929
ClassifierFunction = Callable[[Type[T]], bool]
3030
SpecFunction = Callable[[Type[T]], Iterable[str]]
31-
FieldsMap = Optional[Dict[str, Any]]
31+
FieldsMap = Optional[Dict[str, Union[Callable[[S], Any], Any]]]
3232

3333

3434
def _try_get_field_value(
3535
field_name: str, original_obj: Any, custom_mapping: FieldsMap
3636
) -> Tuple[bool, Any]:
37-
if field_name in (custom_mapping or {}):
37+
if field_name in (custom_mapping or {}): # type: ignore [index]
38+
if callable(custom_mapping[field_name]): # type: ignore [index]
39+
return True, custom_mapping[field_name](original_obj) # type: ignore [index]
3840
return True, custom_mapping[field_name] # type: ignore [index]
3941
if hasattr(original_obj, field_name):
4042
return True, getattr(original_obj, field_name)
@@ -184,7 +186,8 @@ def map(
184186
obj (object): Source object to map to `target class`.
185187
skip_none_values (bool, optional): Skip None values when creating `target class` obj. Defaults to False.
186188
fields_mapping (FieldsMap, optional): Custom mapping.
187-
Specify dictionary in format {"field_name": value_object}. Defaults to None.
189+
Specify dictionary in format {"field_name": value_object | lambda soure_obj}. Can take a lamdba
190+
funtion as argument, that will get the source_cls as argument. Defaults to None.
188191
use_deepcopy (bool, optional): Apply deepcopy to all child objects when copy from source to target object.
189192
Defaults to True.
190193

tests/test_predefined_mapping.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def __init__(self, text: Optional[str], num: int) -> None:
3030
self.text = text
3131
self.num = num
3232

33+
class ComplexClass:
34+
def __init__(self, text: Optional[str], num: int) -> None:
35+
self.data = AnotherClass(text, num)
3336

3437
class ClassWithoutInitAttrDef:
3538
def __init__(self, **kwargs: Any) -> None:
@@ -138,3 +141,15 @@ def test_map__pass_none_values_from_source_object(self):
138141
assert "num" in obj.data
139142
assert obj.data.get("text") is None
140143
assert obj.data.get("num") == 11
144+
145+
def test_add__lambda_resolver_works_with_lambda_function(self):
146+
self.mapper.add(ComplexClass, AnotherClass, fields_mapping={
147+
"text": lambda x: x.data.text.upper(),
148+
"num": lambda x: x.data.num * 2
149+
})
150+
result: AnotherClass = self.mapper.map(ComplexClass("test_message", 10))
151+
152+
assert isinstance(result, AnotherClass)
153+
assert result.text == "TEST_MESSAGE"
154+
assert result.num == 20
155+

0 commit comments

Comments
 (0)