Skip to content

Commit 7675e00

Browse files
authored
Merge pull request #12 from ghazi-git/fix-popymorphic-serializer-mapping
fix-popymorphic-serializer-mapping
2 parents fa57f56 + b27baeb commit 7675e00

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
66

77
## [UNRELEASED]
8+
### Fixed
9+
- generate the mapping for discriminator fields properly instead of showing a "null" value in the generated schema (#12).
810

911
## [0.12.0] - 2022-08-27
1012
### Added

drf_standardized_errors/openapi.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from .handler import exception_handler as standardized_errors_handler
2121
from .openapi_serializers import (
22+
ClientErrorEnum,
2223
ErrorResponse401Serializer,
2324
ErrorResponse403Serializer,
2425
ErrorResponse404Serializer,
@@ -28,6 +29,7 @@
2829
ErrorResponse429Serializer,
2930
ErrorResponse500Serializer,
3031
ParseErrorResponseSerializer,
32+
ValidationErrorEnum,
3133
)
3234
from .openapi_utils import (
3335
InputDataField,
@@ -253,12 +255,13 @@ def _get_http400_serializer(self):
253255
operation_id = self.get_operation_id()
254256
component_name = f"{camelize(operation_id)}ErrorResponse400"
255257

256-
http400_serializers = []
258+
http400_serializers = {}
257259
if self._should_add_validation_error_response():
258260
serializer = self._get_serializer_for_validation_error_response()
259-
http400_serializers.append(serializer)
261+
http400_serializers[ValidationErrorEnum.VALIDATION_ERROR.value] = serializer
260262
if self._should_add_parse_error_response():
261-
http400_serializers.append(ParseErrorResponseSerializer)
263+
serializer = ParseErrorResponseSerializer
264+
http400_serializers[ClientErrorEnum.CLIENT_ERROR.value] = serializer
262265

263266
return PolymorphicProxySerializer(
264267
component_name=component_name,

drf_standardized_errors/openapi_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,10 @@ def get_validation_error_serializer(
410410
):
411411
validation_error_component_name = f"{camelize(operation_id)}ValidationError"
412412
errors_component_name = f"{camelize(operation_id)}Error"
413-
sub_serializers = [
414-
get_error_serializer(operation_id, sfield.name, sfield.error_codes)
413+
sub_serializers = {
414+
sfield.name: get_error_serializer(operation_id, sfield.name, sfield.error_codes)
415415
for sfield in data_fields
416-
]
416+
}
417417

418418
class ValidationErrorSerializer(serializers.Serializer):
419419
type = serializers.ChoiceField(choices=ValidationErrorEnum.choices)

tests/test_openapi.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,32 @@ def test_validation_error_for_unsafe_method():
8484
assert "400" in responses
8585

8686

87+
def test_discriminator_mapping_for_validation_serializer():
88+
route = "validate/"
89+
view = ValidationView.as_view()
90+
schema = generate_view_schema(route, view)
91+
92+
discriminator = schema["components"]["schemas"]["ValidateCreateError"][
93+
"discriminator"
94+
]
95+
assert discriminator["propertyName"] == "attr"
96+
mapping_fields = set(discriminator["mapping"])
97+
assert mapping_fields == {"non_field_errors", "first_name"}
98+
99+
100+
def test_discriminator_mapping_for_http400_serializer():
101+
route = "validate/"
102+
view = ValidationView.as_view(parser_classes=[JSONParser])
103+
schema = generate_view_schema(route, view)
104+
105+
discriminator = schema["components"]["schemas"]["ValidateCreateErrorResponse400"][
106+
"discriminator"
107+
]
108+
assert discriminator["propertyName"] == "type"
109+
mapping_fields = set(discriminator["mapping"])
110+
assert mapping_fields == {"validation_error", "client_error"}
111+
112+
87113
def test_no_validation_error_for_unsafe_method():
88114
route = "validate/"
89115
view = ValidationView.as_view(serializer_class=None)

0 commit comments

Comments
 (0)