Skip to content

Commit 4c7ab6e

Browse files
Add SerializedRelatedField for write-by-slug, read-as-nested representation
- Introduce SerializedRelatedField in rest_framework/relations.py - Supports writing by slug/PK and reading as nested serializer - Add corresponding tests in tests/test_relations.py
1 parent 2001878 commit 4c7ab6e

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

rest_framework/relations.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from rest_framework.reverse import reverse
1717
from rest_framework.settings import api_settings
1818
from rest_framework.utils import html
19-
19+
from django.db import models
2020

2121
def method_overridden(method_name, klass, instance):
2222
"""
@@ -583,3 +583,40 @@ def iter_options(self):
583583
cutoff=self.html_cutoff,
584584
cutoff_text=self.html_cutoff_text
585585
)
586+
587+
588+
class SerializedRelatedField(SlugRelatedField):
589+
"""
590+
A relational field that accepts a simple slug for writes
591+
(like SlugRelatedField), but expands to a nested serializer
592+
for reads if `serializer_class` is provided.
593+
594+
Example:
595+
class OrderSerializer(serializers.ModelSerializer):
596+
address = SerializedRelatedField(
597+
serializer_class=AddressSerializer,
598+
queryset=Address.objects.all(),
599+
lookup_field="pk",
600+
)
601+
"""
602+
603+
def __init__(self, serializer_class=None, lookup_field="pk", **kwargs):
604+
self.serializer_class = serializer_class
605+
kwargs["slug_field"] = lookup_field
606+
super().__init__(**kwargs)
607+
608+
if self.serializer_class is not None and self.queryset is None:
609+
raise AssertionError(
610+
"SerializedRelatedField with serializer_class requires a queryset"
611+
)
612+
613+
def to_representation(self, value):
614+
# Ensure PKOnlyObject (used in select_related/prefetch) is resolved
615+
if hasattr(value, "pk") and not isinstance(value, models.Model):
616+
value = self.get_queryset().get(pk=value.pk)
617+
618+
if self.serializer_class is not None:
619+
serializer = self.serializer_class(value, context=self.context)
620+
return serializer.data
621+
622+
return super().to_representation(value)

rest_framework/serializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from rest_framework.relations import ( # NOQA # isort:skip
6363
HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField,
6464
PrimaryKeyRelatedField, RelatedField, SlugRelatedField, StringRelatedField,
65+
SerializedRelatedField
6566
)
6667

6768
# Non-field imports, but public API

tests/test_relations.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import pytest
44
from _pytest.monkeypatch import MonkeyPatch
55
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
6-
from django.test import override_settings
6+
from django.test import TestCase, override_settings
77
from django.urls import re_path
88
from django.utils.datastructures import MultiValueDict
9-
9+
1010
from rest_framework import relations, serializers
1111
from rest_framework.fields import empty
1212
from rest_framework.test import APISimpleTestCase
13-
13+
from django.db import models
1414
from .utils import (
1515
BadType, MockObject, MockQueryset, fail_reverse, mock_reverse
1616
)
17+
from tests.models import RESTFrameworkModel
18+
from django.utils.translation import gettext_lazy as _
1719

1820

1921
class TestStringRelatedField(APISimpleTestCase):
@@ -518,3 +520,101 @@ def test_can_be_pickled(self):
518520
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
519521
assert upkled == self.default_hyperlink
520522
assert upkled.name == self.default_hyperlink.name
523+
524+
525+
class Address(RESTFrameworkModel):
526+
postal_code = models.CharField(
527+
max_length=20, unique=True, verbose_name=_("Postal Code")
528+
)
529+
province = models.CharField(max_length=100, verbose_name=_("Province"))
530+
city = models.CharField(max_length=100, verbose_name=_("City"))
531+
street = models.CharField(
532+
max_length=255, blank=True, null=True, verbose_name=_("Street")
533+
)
534+
additional_info = models.TextField(
535+
verbose_name=_("Additional Info"), blank=True, null=True
536+
)
537+
538+
539+
class AddressSerializer(serializers.ModelSerializer):
540+
class Meta:
541+
model = Address
542+
fields = "__all__"
543+
544+
545+
class SerializedRelatedFieldTests(TestCase):
546+
def setUp(self):
547+
self.address = Address.objects.create(
548+
postal_code="12345",
549+
province="Tehran",
550+
city="Tehran",
551+
street="Valiasr",
552+
additional_info="Test info"
553+
)
554+
Address.objects.create(
555+
postal_code="123456",
556+
province="Tehran",
557+
city="Tehran",
558+
street="Valiasr",
559+
additional_info="Test info"
560+
)
561+
class OrderSerializerPostalCode(serializers.Serializer):
562+
address = relations.SerializedRelatedField(
563+
serializer_class=AddressSerializer,
564+
queryset=Address.objects.all(),
565+
lookup_field='postal_code',
566+
)
567+
class OrderSerializerCity(serializers.Serializer):
568+
address = relations.SerializedRelatedField(
569+
serializer_class=AddressSerializer,
570+
queryset=Address.objects.all(),
571+
lookup_field='city',
572+
)
573+
class OrderSerializerPK(serializers.Serializer):
574+
address = relations.SerializedRelatedField(
575+
serializer_class=AddressSerializer,
576+
queryset=Address.objects.all(),
577+
)
578+
self.serializer_postal_code = OrderSerializerPostalCode
579+
self.serializer_pk = OrderSerializerPK
580+
self.serializer_city = OrderSerializerCity
581+
582+
def test_write_slug(self):
583+
data = {"address": self.address.postal_code}
584+
serializer = self.serializer_postal_code(data=data)
585+
assert serializer.is_valid(), serializer.errors
586+
assert serializer.validated_data["address"] == self.address
587+
588+
def test_read_nested(self):
589+
data = {"address": self.address.postal_code}
590+
serializer = self.serializer_postal_code(data=data)
591+
assert serializer.is_valid(), serializer.errors
592+
expected = AddressSerializer(self.address).data
593+
assert serializer.data["address"] == expected
594+
595+
def test_write_default(self):
596+
data = {"address": self.address.pk}
597+
serializer = self.serializer_pk(data=data)
598+
assert serializer.is_valid(), serializer.errors
599+
expected = AddressSerializer(self.address).data
600+
assert serializer.data["address"] == expected
601+
602+
def test_read_default(self):
603+
data = {"address": self.address.pk}
604+
serializer = self.serializer_pk(data=data)
605+
assert serializer.is_valid(), serializer.errors
606+
expected = AddressSerializer(self.address).data
607+
assert serializer.data["address"] == expected
608+
609+
def test_duplicated(self):
610+
data = {"address": "Tehran"}
611+
serializer = self.serializer_city(data=data)
612+
with pytest.raises(Address.MultipleObjectsReturned) as exc_info:
613+
serializer.is_valid(raise_exception=True)
614+
assert "returned more than one Address -- it returned 2!" in str(exc_info.value)
615+
616+
def test_not_fount(self):
617+
data = {"address": "Isfahan"}
618+
serializer = self.serializer_city(data=data)
619+
serializer.is_valid()
620+
assert "Object with city=Isfahan does not exist." in serializer.errors["address"]

0 commit comments

Comments
 (0)