Skip to content

Commit e8a6916

Browse files
committed
combine fields from sub-serializers in a polymorphic serializer, and if the same field exists in more than one sub-serializer, combine their error codes so all error codes are visible in the resulting schema
1 parent 43b881d commit e8a6916

File tree

3 files changed

+93
-5
lines changed

3 files changed

+93
-5
lines changed

docs/changelog.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

77
## [UNRELEASED]
88
### Fixed
9-
- account for specifying the request serializer as a basic type using `@extend_schema(request=OpenApiTypes.STR)`
10-
when determining error codes for validation errors.
9+
- account for specifying the request serializer as a basic type (like `OpenApiTypes.STR`) or as a
10+
`PolymorphicProxySerializer` using `@extend_schema(request=...)` when determining error codes for validation errors.
1111

1212
## [0.12.3] - 2022-11-13
1313
### Added

drf_standardized_errors/openapi_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from dataclasses import dataclass
23
from dataclasses import field as dataclass_field
34
from typing import List, Optional, Set, Type, Union
@@ -54,6 +55,11 @@ def get_flat_serializer_fields(
5455
f = InputDataField(non_field_errors_name, field)
5556
prefix = get_prefix(prefix, package_settings.LIST_INDEX_IN_API_SCHEMA)
5657
return [f] + get_flat_serializer_fields(field.child, prefix)
58+
elif isinstance(field, PolymorphicProxySerializer):
59+
if isinstance(field.serializers, dict):
60+
return get_flat_serializer_fields(list(field.serializers.values()), prefix)
61+
else:
62+
return get_flat_serializer_fields(field.serializers, prefix)
5763
elif is_serializer(field):
5864
prefix = get_prefix(prefix, field.field_name)
5965
non_field_errors_name = get_prefix(prefix, drf_settings.NON_FIELD_ERRORS_KEY)
@@ -384,9 +390,16 @@ def get_validation_error_serializer(
384390
):
385391
validation_error_component_name = f"{camelize(operation_id)}ValidationError"
386392
errors_component_name = f"{camelize(operation_id)}Error"
393+
394+
# When there are multiple fields with the same name in the list of data_fields,
395+
# their error codes are combined. This can happen when using a PolymorphicProxySerializer
396+
error_codes_by_field = defaultdict(set)
397+
for field in data_fields:
398+
error_codes_by_field[field.name].update(field.error_codes)
399+
387400
sub_serializers = {
388-
sfield.name: get_error_serializer(operation_id, sfield.name, sfield.error_codes)
389-
for sfield in data_fields
401+
field_name: get_error_serializer(operation_id, field_name, error_codes)
402+
for field_name, error_codes in error_codes_by_field.items()
390403
}
391404

392405
class ValidationErrorSerializer(serializers.Serializer):

tests/test_openapi.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from django_filters.rest_framework import DjangoFilterBackend, FilterSet
66
from drf_spectacular.generators import SchemaGenerator
77
from drf_spectacular.types import OpenApiTypes
8-
from drf_spectacular.utils import extend_schema
8+
from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema
99
from rest_framework import serializers
1010
from rest_framework.authentication import BasicAuthentication
1111
from rest_framework.generics import GenericAPIView, ListAPIView
@@ -141,6 +141,81 @@ def test_no_error_raised_when_request_serializer_is_set_as_openapi_type():
141141
)
142142

143143

144+
class Object1Serializer(serializers.Serializer):
145+
type = serializers.CharField()
146+
field1 = serializers.IntegerField()
147+
148+
def __init__(self, *args, **kwargs):
149+
super().__init__(*args, **kwargs)
150+
self.error_messages.update(object1_code="first error")
151+
152+
153+
class Object2Serializer(serializers.Serializer):
154+
type = serializers.CharField()
155+
field2 = serializers.DateField()
156+
157+
def __init__(self, *args, **kwargs):
158+
super().__init__(*args, **kwargs)
159+
self.error_messages.update(object2_code="second error")
160+
161+
162+
class PolymorphicView(GenericAPIView):
163+
# ensure that 400 is not added due to the parser classes by using a parser
164+
# that does not raise a ParseError which results in adding a 400 error response
165+
parser_classes = [CustomParser]
166+
167+
@extend_schema(
168+
request=PolymorphicProxySerializer(
169+
component_name="Object",
170+
serializers={"object1": Object1Serializer, "object2": Object2Serializer},
171+
resource_type_field_name="type",
172+
),
173+
responses={204: None},
174+
)
175+
def post(self, request, *args, **kwargs):
176+
return Response(status=204)
177+
178+
@extend_schema(
179+
request=PolymorphicProxySerializer(
180+
component_name="AnotherObject",
181+
serializers=[Object1Serializer, Object2Serializer],
182+
resource_type_field_name="type",
183+
),
184+
responses={204: None},
185+
)
186+
def patch(self, request, *args, **kwargs):
187+
return Response(status=204)
188+
189+
190+
def test_error_codes_for_polymorphic_serializer():
191+
"""
192+
For polymorphic serializers, the fields from the actual serializers are combined.
193+
Also, when the same field exists in multiple serializers in a polymorphic serializer,
194+
their error codes should be combined.
195+
196+
This test checks that fields from both serializers are present. It also checks that
197+
the error codes of non_field_errors from both serializers are combined.
198+
"""
199+
route = "validate/"
200+
view = PolymorphicView.as_view()
201+
schema = generate_view_schema(route, view)
202+
203+
mapping = schema["components"]["schemas"]["ValidateCreateError"]["discriminator"][
204+
"mapping"
205+
]
206+
assert set(mapping) == {"non_field_errors", "type", "field1", "field2"}
207+
208+
create_error_codes = schema["components"]["schemas"][
209+
"ValidateCreateNonFieldErrorsErrorComponent"
210+
]["properties"]["code"]["enum"]
211+
assert set(create_error_codes) == {"invalid", "object1_code", "object2_code"}
212+
213+
patch_error_codes = schema["components"]["schemas"][
214+
"ValidatePartialUpdateNonFieldErrorsErrorComponent"
215+
]["properties"]["code"]["enum"]
216+
assert set(patch_error_codes) == {"invalid", "object1_code", "object2_code"}
217+
218+
144219
class CustomFilterSet(FilterSet):
145220
first_name = CharFilter()
146221

0 commit comments

Comments
 (0)