diff --git a/app/tests/test_mapping_serializer.py b/app/tests/test_mapping_serializer.py index f6aacfd..59fb91e 100644 --- a/app/tests/test_mapping_serializer.py +++ b/app/tests/test_mapping_serializer.py @@ -252,3 +252,50 @@ def test_mapping_serializer_update(self) -> None: self.assertEqual(2, instance.addresses.count()) self.assertEqual(person, instance.addresses.first().target) self.assertEqual(person, instance.addresses.last().target) + + @override_config(MODEL_MAPPING_FIELD=MODEL_MAPPING_FIELD) + def test_list_mapping_serializer_create(self) -> None: + koeniz = ElectionDistrictFactory(title="Koeniz") + muri = ElectionDistrictFactory(title="Muri") + + data = [ + { + "external_firstname": "Hugo", + "external_lastname": "Boss", + "external_election_district_title": "Koeniz", + "external_addresses": [ + self.address_1.external_uid, + self.address_2.external_uid, + ], + }, + { + "external_firstname": "Stefanie", + "external_lastname": "Muster", + "external_election_district_title": "Muri", + "external_addresses": [ + self.address_3.external_uid, + ], + }, + ] + + serializer = PersonMappingSerializer(data=data, many=True) + self.assertTrue(serializer.is_valid(raise_exception=True)) + serializer.save() + + hugo = Person.objects.get(firstname="Hugo") + stefanie = Person.objects.get(firstname="Stefanie") + + self.assertEqual("Hugo", hugo.firstname) + self.assertEqual("Boss", hugo.lastname) + + self.assertEqual("Stefanie", stefanie.firstname) + self.assertEqual("Muster", stefanie.lastname) + + self.assertEqual(2, ElectionDistrict.objects.count()) + self.assertEqual(koeniz, hugo.election_district) + self.assertEqual(muri, stefanie.election_district) + + self.assertEqual(2, hugo.addresses.count()) + self.assertEqual(1, stefanie.addresses.count()) + self.assertEqual(hugo, hugo.addresses.first().target) + self.assertEqual(stefanie, stefanie.addresses.last().target) diff --git a/app/tests/test_mapping_serializer_data.py b/app/tests/test_mapping_serializer_data.py index 9312407..24556d4 100644 --- a/app/tests/test_mapping_serializer_data.py +++ b/app/tests/test_mapping_serializer_data.py @@ -70,3 +70,86 @@ def test_mapping_serializer_map_initial_data(self) -> None: mapped_data = TestMappingSerializer().map_data(data) self.assertEqual(mapped_data, expected_data) + + def test_list_mapping_serializer_map_initial_data(self) -> None: + data = [ + { + "external_base_field": "base_value", + "external_single_field_1": "nested_value_1", + "external_single_field_2": "nested_value_2", + "external_dict_field": {"nested_field": "single_value"}, + "external_object_field": { + "nested_external_field_1": "nested_value_1", + "nested_external_field_2": "nested_value_2", + }, + "external_object_field_with_object": { + "external_object_field_1": { + "external_field_1": "value_1", + "external_field_2": "value_2", + }, + "external_object_field_2": { + "external_field_1": "value_1", + "external_field_2": "value_2", + }, + }, + }, + { + "external_base_field": "other_value", + "external_single_field_1": "nested_value_3", + "external_single_field_2": "nested_value_4", + "external_dict_field": {"nested_field": "other_value"}, + "external_object_field": { + "nested_external_field_1": "nested_value_3", + "nested_external_field_2": "nested_value_4", + }, + "external_object_field_with_object": { + "external_object_field_1": { + "external_field_1": "value_3", + "external_field_2": "value_4", + }, + "external_object_field_2": { + "external_field_1": "value_3", + "external_field_2": "value_4", + }, + }, + }, + ] + + expected_data = [ + { + "base_field": "base_value", + "dict_field": { + "nested_field_1": "nested_value_1", + "nested_field_2": "nested_value_2", + }, + "single_field": "single_value", + "object_field": { + "nested_field_1": "nested_value_1", + "nested_field_2": "nested_value_2", + }, + "object_field_with_object": { + "object_field_1": {"field_1": "value_1", "field_2": "value_2"}, + "object_field_2": {"field_1": "value_1", "field_2": "value_2"}, + }, + }, + { + "base_field": "other_value", + "dict_field": { + "nested_field_1": "nested_value_3", + "nested_field_2": "nested_value_4", + }, + "single_field": "other_value", + "object_field": { + "nested_field_1": "nested_value_3", + "nested_field_2": "nested_value_4", + }, + "object_field_with_object": { + "object_field_1": {"field_1": "value_3", "field_2": "value_4"}, + "object_field_2": {"field_1": "value_3", "field_2": "value_4"}, + }, + }, + ] + + serializer = TestMappingSerializer(many=True) + mapped_data = serializer.map_list_data(data) + self.assertEqual(mapped_data, expected_data) diff --git a/changes/TI-2893.bugfix b/changes/TI-2893.bugfix new file mode 100644 index 0000000..beb5dc8 --- /dev/null +++ b/changes/TI-2893.bugfix @@ -0,0 +1 @@ +Fix MappingSerializer if many is true. [TI-2893](https://4teamwork.atlassian.net/browse/TI-2893>) diff --git a/django_features/serializers.py b/django_features/serializers.py index 75b4aab..60310f4 100644 --- a/django_features/serializers.py +++ b/django_features/serializers.py @@ -5,6 +5,7 @@ from django.core.exceptions import ValidationError from django.db import models from django.db.models import NOT_PROVIDED +from rest_framework import serializers from rest_framework.fields import empty from rest_framework.relations import ManyRelatedField @@ -12,8 +13,66 @@ from django_features.fields import UUIDRelatedField -class BaseMappingSerializer(CustomFieldBaseModelSerializer): +class PropertySerializer(serializers.Serializer): relation_separator: str = "." + + class Meta: + abstract = True + fields = "__all__" + model = None + + @property + def mapping(self) -> dict[str, dict[str, Any]]: + if getattr(self, "_mapping") is None: + raise ValueError( + "Property 'mapping' on instance must be set and can't be 'None'" + ) + return self._mapping + + @mapping.setter + def mapping(self, value: dict[str, dict[str, Any]]) -> None: + self._mapping = value + + @property + def mapping_fields(self) -> list[str]: + mapping_fields = getattr( + self, "_mapping_fields", list(self.model_mapping.values()) + ) + if mapping_fields is None: + raise ValueError("Property 'mapping_fields' must be set and can't be 'None") + return mapping_fields + + @mapping_fields.setter + def mapping_fields(self, value: list[str]) -> None: + self._mapping_fields = value + + @property + def model_mapping(self) -> dict[str, Any]: + for key_path in self.mapping.keys(): + key = key_path.split(self.relation_separator)[-1] + if key.lower() == self.model.__name__.lower(): + return self.mapping.get(key_path, {}) + return {} + + @model_mapping.setter + def model_mapping(self, value: dict[str, Any]) -> None: + self._model_mapping = value + + @property + def model(self) -> models.Model: + model = getattr(self, "_model", self.Meta.model) + if model is None: + raise ValueError( + "Property 'model' must be set and can't be 'None. Default is 'Meta.model" + ) + return model + + @model.setter + def model(self, value: models.Model) -> None: + self._model = value + + +class BaseMappingSerializer(CustomFieldBaseModelSerializer, PropertySerializer): serializer_related_field = UUIDRelatedField serializer_related_fields: dict[str, Any] = {} @@ -32,26 +91,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.exclude: list[str] = [] self.related_fields: set[str] = set() - @property - def mapping(self) -> dict[str, dict[str, Any]]: - raise NotImplementedError("Mapping must be set") - - @property - def mapping_fields(self) -> list[str]: - raise NotImplementedError("Mapping fields must be set") - - @property - def model(self) -> models.Model: - if self.Meta.model is None: - raise ValueError("Meta.model must be set") - return self.Meta.model - def get_fields(self) -> dict[str, Any]: initial_fields = super().get_fields() fields: dict[str, Any] = dict() nested_fields: dict[str, Any] = dict() nested_field_fields: dict[str, list[str]] = dict() - self.related_fields: set[str] = set() for internal_name in self.mapping_fields: if internal_name in self.exclude: continue @@ -160,40 +204,16 @@ def __init__( **kwargs: Any, ) -> None: self.exclude = exclude - self.nested_fields = nested_fields - self.parent_mapping = parent_mapping + self.mapping_fields = nested_fields + self.mapping = parent_mapping self.Meta.model = field.related_model super().__init__(*args, **kwargs) - @property - def mapping(self) -> dict[str, dict[str, Any]]: - return self.parent_mapping - @property - def mapping_fields(self) -> list[str]: - return self.nested_fields - - -class MappingSerializer(BaseMappingSerializer): +class DataMappingSerializer(PropertySerializer): _default_prefix = "default" _format_prefix = "format" - class Meta: - abstract = True - fields = "__all__" - model = None - - def __init__( - self, - instance: Any = None, - data: Any = empty, - **kwargs: Any, - ) -> None: - self.instance = instance - self.unmapped_data = data - mapped_data = self.map_data(data) - super().__init__(instance, data=mapped_data, **kwargs) - def _get_nested_data(self, field_path: list[str], data: Any) -> tuple[Any, bool]: field_name = field_path[0] if not isinstance(data, dict): @@ -248,17 +268,69 @@ def map_data(self, initial_data: Any) -> Any: ) return data - @property - def mapping_fields(self) -> list[str]: - return list(self.model_mapping.values()) - @property - def model_mapping(self) -> dict[str, Any]: - mapping = getattr(self, "mapping", None) - if mapping is None: - raise ValueError("Mapping must be set") - for key_path in mapping.keys(): - key = key_path.split(self.relation_separator)[-1] - if key.lower() == self.model.__name__.lower(): - return mapping.get(key_path, {}) - return {} +class ListDataMappingSerializer(serializers.ListSerializer, DataMappingSerializer): + def __init__(self, data: Any = empty, *args: Any, **kwargs: Any) -> None: + self.instance = None + self.mapping = kwargs.pop("mapping", {}) + self.model = kwargs.pop("model") + self.unmapped_data = data if data is not empty else [] + mapped_data = self.map_list_data(self.unmapped_data) + super().__init__(data=mapped_data, *args, **kwargs) + + def map_list_data(self, initial_data: Any) -> list[Any]: + list_data: list[dict[str, Any]] = [] + for item in initial_data: + list_data.append(self.map_data(item)) + return list_data + + +class MappingSerializer(BaseMappingSerializer, DataMappingSerializer): + list_serializer_class = ListDataMappingSerializer + + class Meta: + abstract = True + fields = "__all__" + model = None + + def __init__( + self, + instance: Any = None, + data: Any = empty, + **kwargs: Any, + ) -> None: + self.instance = instance + self.unmapped_data = data + mapped_data = self.map_data(data) + super().__init__(instance, data=mapped_data, **kwargs) + + @classmethod + def many_init(cls, *args: Any, **kwargs: Any) -> ListDataMappingSerializer: + """ + Overwrite the many_init function from the ModelSerializer to change the default listing serializer to the given + list_serializer_class attribute instead of the default ListSerializer. Therefore, the list serializer class can + be set with the attribute list_serializer_class on the serializer class instead of the Meta class. + """ + + list_kwargs = {} + for key in serializers.LIST_SERIALIZER_KWARGS_REMOVE: + value = kwargs.pop(key, None) + if value is not None: + list_kwargs[key] = value + child = cls(*args, **kwargs) + list_kwargs["child"] = child + list_kwargs["mapping"] = getattr(child, "mapping", {}) + list_kwargs.update( + { + key: value + for key, value in kwargs.items() + if key in serializers.LIST_SERIALIZER_KWARGS + } + ) + meta = getattr(cls, "Meta", None) + list_serializer_class = getattr( + meta, "list_serializer_class", cls.list_serializer_class + ) + model = getattr(meta, "model", None) + list_kwargs["model"] = model + return list_serializer_class(*args, **list_kwargs)