Skip to content

Commit 366b4ab

Browse files
authored
Merge pull request #21 from ghazi-git/openapitypes_and_polymorphic_serializer_as_request_serializer
openapitypes_and_polymorphic_serializer_as_request_serializer
2 parents 26d45f5 + d51e9a7 commit 366b4ab

File tree

5 files changed

+126
-11
lines changed

5 files changed

+126
-11
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ DRF_STANDARDIZED_ERRORS = {"ENABLE_IN_DEBUG_FOR_UNHANDLED_EXCEPTIONS": True}
100100

101101
## Integration with DRF spectacular
102102
If you plan to use [drf-spectacular](https://github.com/tfranzel/drf-spectacular) to generate an OpenAPI 3 schema,
103-
install with `pip install drf-standardized-errors[openapi]`. After that, check the doc page for configuring the
104-
integration.
103+
install with `pip install drf-standardized-errors[openapi]`. After that, check the [doc page](https://drf-standardized-errors.readthedocs.io/en/latest/openapi.html)
104+
for configuring the integration.
105105

106106
## Links
107107

docs/changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

77
## [UNRELEASED]
8+
### Fixed
9+
- account for specifying the request serializer as a basic type (like `OpenApiTypes.STR`) or as a
10+
`PolymorphicProxySerializer` using `@extend_schema(request=...)` when determining error codes for validation errors.
811

912
## [0.12.3] - 2022-11-13
1013
### Added

drf_standardized_errors/openapi.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from typing import List, Optional, Type
23

34
from drf_spectacular.drainage import warn
@@ -125,14 +126,14 @@ def _should_add_validation_error_response(self) -> bool:
125126
"""
126127
add a validation error response when unsafe methods have a request body
127128
or when a list view implements filtering with django-filters.
128-
todo add a way to disable adding the 400 validation response on an
129-
operation-basis. This is because "serializer.is_valid" has the option
130-
of not raising an exception. At the very least, add docs to demo what
131-
to override to accomplish that (maybe sth like
132-
isinstance(self.view, SomeViewSet) and checking self.method)
133129
"""
134-
has_request_body = self.method in ("PUT", "PATCH", "POST") and bool(
135-
self.get_request_serializer()
130+
request_serializer = self.get_request_serializer()
131+
has_request_body = self.method in ("PUT", "PATCH", "POST") and (
132+
isinstance(request_serializer, serializers.Field)
133+
or (
134+
inspect.isclass(request_serializer)
135+
and issubclass(request_serializer, serializers.Field)
136+
)
136137
)
137138

138139
filter_backends = get_django_filter_backends(self.get_filter_backends())

drf_standardized_errors/openapi_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from dataclasses import dataclass
23
from dataclasses import field as dataclass_field
34
from typing import List, Optional, Set, Type, Union
@@ -54,6 +55,11 @@ def get_flat_serializer_fields(
5455
f = InputDataField(non_field_errors_name, field)
5556
prefix = get_prefix(prefix, package_settings.LIST_INDEX_IN_API_SCHEMA)
5657
return [f] + get_flat_serializer_fields(field.child, prefix)
58+
elif isinstance(field, PolymorphicProxySerializer):
59+
if isinstance(field.serializers, dict):
60+
return get_flat_serializer_fields(list(field.serializers.values()), prefix)
61+
else:
62+
return get_flat_serializer_fields(field.serializers, prefix)
5763
elif is_serializer(field):
5864
prefix = get_prefix(prefix, field.field_name)
5965
non_field_errors_name = get_prefix(prefix, drf_settings.NON_FIELD_ERRORS_KEY)
@@ -384,9 +390,16 @@ def get_validation_error_serializer(
384390
):
385391
validation_error_component_name = f"{camelize(operation_id)}ValidationError"
386392
errors_component_name = f"{camelize(operation_id)}Error"
393+
394+
# When there are multiple fields with the same name in the list of data_fields,
395+
# their error codes are combined. This can happen when using a PolymorphicProxySerializer
396+
error_codes_by_field = defaultdict(set)
397+
for field in data_fields:
398+
error_codes_by_field[field.name].update(field.error_codes)
399+
387400
sub_serializers = {
388-
sfield.name: get_error_serializer(operation_id, sfield.name, sfield.error_codes)
389-
for sfield in data_fields
401+
field_name: get_error_serializer(operation_id, field_name, error_codes)
402+
for field_name, error_codes in error_codes_by_field.items()
390403
}
391404

392405
class ValidationErrorSerializer(serializers.Serializer):

tests/test_openapi.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from django_filters import CharFilter
55
from django_filters.rest_framework import DjangoFilterBackend, FilterSet
66
from drf_spectacular.generators import SchemaGenerator
7+
from drf_spectacular.types import OpenApiTypes
8+
from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema
79
from rest_framework import serializers
810
from rest_framework.authentication import BasicAuthentication
911
from rest_framework.generics import GenericAPIView, ListAPIView
@@ -118,6 +120,102 @@ def test_no_validation_error_for_unsafe_method():
118120
assert "400" not in responses
119121

120122

123+
class OpenAPITypesView(GenericAPIView):
124+
# ensure that 400 is not added due to the parser classes by using a parser
125+
# that does not raise a ParseError which results in adding a 400 error response
126+
parser_classes = [CustomParser]
127+
128+
@extend_schema(request=OpenApiTypes.OBJECT, responses={204: None})
129+
def post(self, request, *args, **kwargs):
130+
return Response(status=204)
131+
132+
133+
def test_no_error_raised_when_request_serializer_is_set_as_openapi_type():
134+
route = "validate/"
135+
view = OpenAPITypesView.as_view()
136+
try:
137+
generate_view_schema(route, view)
138+
except Exception:
139+
pytest.fail(
140+
"Schema generation failed when using `@extend_schema(request.OpenApiTypes.OBJECT)`"
141+
)
142+
143+
144+
class Object1Serializer(serializers.Serializer):
145+
type = serializers.CharField()
146+
field1 = serializers.IntegerField()
147+
148+
def __init__(self, *args, **kwargs):
149+
super().__init__(*args, **kwargs)
150+
self.error_messages.update(object1_code="first error")
151+
152+
153+
class Object2Serializer(serializers.Serializer):
154+
type = serializers.CharField()
155+
field2 = serializers.DateField()
156+
157+
def __init__(self, *args, **kwargs):
158+
super().__init__(*args, **kwargs)
159+
self.error_messages.update(object2_code="second error")
160+
161+
162+
class PolymorphicView(GenericAPIView):
163+
# ensure that 400 is not added due to the parser classes by using a parser
164+
# that does not raise a ParseError which results in adding a 400 error response
165+
parser_classes = [CustomParser]
166+
167+
@extend_schema(
168+
request=PolymorphicProxySerializer(
169+
component_name="Object",
170+
serializers={"object1": Object1Serializer, "object2": Object2Serializer},
171+
resource_type_field_name="type",
172+
),
173+
responses={204: None},
174+
)
175+
def post(self, request, *args, **kwargs):
176+
return Response(status=204)
177+
178+
@extend_schema(
179+
request=PolymorphicProxySerializer(
180+
component_name="AnotherObject",
181+
serializers=[Object1Serializer, Object2Serializer],
182+
resource_type_field_name="type",
183+
),
184+
responses={204: None},
185+
)
186+
def patch(self, request, *args, **kwargs):
187+
return Response(status=204)
188+
189+
190+
def test_error_codes_for_polymorphic_serializer():
191+
"""
192+
For polymorphic serializers, the fields from the actual serializers are combined.
193+
Also, when the same field exists in multiple serializers in a polymorphic serializer,
194+
their error codes should be combined.
195+
196+
This test checks that fields from both serializers are present. It also checks that
197+
the error codes of non_field_errors from both serializers are combined.
198+
"""
199+
route = "validate/"
200+
view = PolymorphicView.as_view()
201+
schema = generate_view_schema(route, view)
202+
203+
mapping = schema["components"]["schemas"]["ValidateCreateError"]["discriminator"][
204+
"mapping"
205+
]
206+
assert set(mapping) == {"non_field_errors", "type", "field1", "field2"}
207+
208+
create_error_codes = schema["components"]["schemas"][
209+
"ValidateCreateNonFieldErrorsErrorComponent"
210+
]["properties"]["code"]["enum"]
211+
assert set(create_error_codes) == {"invalid", "object1_code", "object2_code"}
212+
213+
patch_error_codes = schema["components"]["schemas"][
214+
"ValidatePartialUpdateNonFieldErrorsErrorComponent"
215+
]["properties"]["code"]["enum"]
216+
assert set(patch_error_codes) == {"invalid", "object1_code", "object2_code"}
217+
218+
121219
class CustomFilterSet(FilterSet):
122220
first_name = CharFilter()
123221

0 commit comments

Comments
 (0)