diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4409bce77c..62e8c40b01 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,6 +4,7 @@ from urllib import parse from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.db import models from django.db.models import Manager from django.db.models.query import QuerySet from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve @@ -583,3 +584,40 @@ def iter_options(self): cutoff=self.html_cutoff, cutoff_text=self.html_cutoff_text ) + + +class SerializedRelatedField(SlugRelatedField): + """ + A relational field that accepts a simple slug for writes + (like SlugRelatedField), but expands to a nested serializer + for reads if `serializer_class` is provided. + + Example: + class OrderSerializer(serializers.ModelSerializer): + address = SerializedRelatedField( + serializer_class=AddressSerializer, + queryset=Address.objects.all(), + lookup_field="pk", + ) + """ + + def __init__(self, serializer_class=None, lookup_field="pk", **kwargs): + self.serializer_class = serializer_class + kwargs["slug_field"] = lookup_field + super().__init__(**kwargs) + + if self.serializer_class is not None and self.queryset is None: + raise AssertionError( + "SerializedRelatedField with serializer_class requires a queryset" + ) + + def to_representation(self, value): + # Ensure PKOnlyObject (used in select_related/prefetch) is resolved + if hasattr(value, "pk") and not isinstance(value, models.Model): + value = self.get_queryset().get(pk=value.pk) + + if self.serializer_class is not None: + serializer = self.serializer_class(value, context=self.context) + return serializer.data + + return super().to_representation(value) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8fe284bc84..d807c785b1 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -62,6 +62,7 @@ from rest_framework.relations import ( # NOQA # isort:skip HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField, PrimaryKeyRelatedField, RelatedField, SlugRelatedField, StringRelatedField, + SerializedRelatedField ) # Non-field imports, but public API diff --git a/tests/test_relations.py b/tests/test_relations.py index b9ab157896..4cf0138118 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -3,13 +3,16 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.test import override_settings +from django.db import models +from django.test import TestCase, override_settings from django.urls import re_path from django.utils.datastructures import MultiValueDict +from django.utils.translation import gettext_lazy as _ from rest_framework import relations, serializers from rest_framework.fields import empty from rest_framework.test import APISimpleTestCase +from tests.models import RESTFrameworkModel from .utils import ( BadType, MockObject, MockQueryset, fail_reverse, mock_reverse @@ -518,3 +521,102 @@ def test_can_be_pickled(self): upkled = pickle.loads(pickle.dumps(self.default_hyperlink)) assert upkled == self.default_hyperlink assert upkled.name == self.default_hyperlink.name + + +class Address(RESTFrameworkModel): + postal_code = models.CharField( + max_length=20, unique=True, verbose_name=_("Postal Code") + ) + province = models.CharField(max_length=100, verbose_name=_("Province")) + city = models.CharField(max_length=100, verbose_name=_("City")) + street = models.CharField( + max_length=255, blank=True, null=True, verbose_name=_("Street") + ) + additional_info = models.TextField( + verbose_name=_("Additional Info"), blank=True, null=True + ) + + +class AddressSerializer(serializers.ModelSerializer): + class Meta: + model = Address + fields = "__all__" + + +class SerializedRelatedFieldTests(TestCase): + + class OrderSerializerPostalCode(serializers.Serializer): + address = relations.SerializedRelatedField( + serializer_class=AddressSerializer, + queryset=Address.objects.all(), + lookup_field='postal_code', + ) + + class OrderSerializerCity(serializers.Serializer): + address = relations.SerializedRelatedField( + serializer_class=AddressSerializer, + queryset=Address.objects.all(), + lookup_field='city', + ) + + class OrderSerializerPK(serializers.Serializer): + address = relations.SerializedRelatedField( + serializer_class=AddressSerializer, + queryset=Address.objects.all(), + ) + + def setUp(self): + self.address = Address.objects.create( + postal_code="12345", + province="Tehran", + city="Tehran", + street="Valiasr", + additional_info="Test info" + ) + Address.objects.create( + postal_code="123456", + province="Tehran", + city="Tehran", + street="Valiasr", + additional_info="Test info" + ) + + def test_write_slug(self): + data = {"address": self.address.postal_code} + serializer = self.OrderSerializerPostalCode(data=data) + assert serializer.is_valid(), serializer.errors + assert serializer.validated_data["address"] == self.address + + def test_read_nested(self): + data = {"address": self.address.postal_code} + serializer = self.OrderSerializerPostalCode(data=data) + assert serializer.is_valid(), serializer.errors + expected = AddressSerializer(self.address).data + assert serializer.data["address"] == expected + + def test_write_default(self): + data = {"address": self.address.pk} + serializer = self.OrderSerializerPK(data=data) + assert serializer.is_valid(), serializer.errors + expected = AddressSerializer(self.address).data + assert serializer.data["address"] == expected + + def test_read_default(self): + data = {"address": self.address.pk} + serializer = self.OrderSerializerPK(data=data) + assert serializer.is_valid(), serializer.errors + expected = AddressSerializer(self.address).data + assert serializer.data["address"] == expected + + def test_duplicated(self): + data = {"address": "Tehran"} + serializer = self.OrderSerializerCity(data=data) + with pytest.raises(Address.MultipleObjectsReturned) as exc_info: + serializer.is_valid(raise_exception=True) + assert "returned more than one Address -- it returned 2!" in str(exc_info.value) + + def test_not_fount(self): + data = {"address": "Isfahan"} + serializer = self.OrderSerializerCity(data=data) + serializer.is_valid() + assert "Object with city=Isfahan does not exist." in serializer.errors["address"]