|
4 | 4 | from django_filters import CharFilter |
5 | 5 | from django_filters.rest_framework import DjangoFilterBackend, FilterSet |
6 | 6 | from drf_spectacular.generators import SchemaGenerator |
| 7 | +from drf_spectacular.types import OpenApiTypes |
| 8 | +from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema |
7 | 9 | from rest_framework import serializers |
8 | 10 | from rest_framework.authentication import BasicAuthentication |
9 | 11 | from rest_framework.generics import GenericAPIView, ListAPIView |
@@ -118,6 +120,102 @@ def test_no_validation_error_for_unsafe_method(): |
118 | 120 | assert "400" not in responses |
119 | 121 |
|
120 | 122 |
|
| 123 | +class OpenAPITypesView(GenericAPIView): |
| 124 | + # ensure that 400 is not added due to the parser classes by using a parser |
| 125 | + # that does not raise a ParseError which results in adding a 400 error response |
| 126 | + parser_classes = [CustomParser] |
| 127 | + |
| 128 | + @extend_schema(request=OpenApiTypes.OBJECT, responses={204: None}) |
| 129 | + def post(self, request, *args, **kwargs): |
| 130 | + return Response(status=204) |
| 131 | + |
| 132 | + |
| 133 | +def test_no_error_raised_when_request_serializer_is_set_as_openapi_type(): |
| 134 | + route = "validate/" |
| 135 | + view = OpenAPITypesView.as_view() |
| 136 | + try: |
| 137 | + generate_view_schema(route, view) |
| 138 | + except Exception: |
| 139 | + pytest.fail( |
| 140 | + "Schema generation failed when using `@extend_schema(request.OpenApiTypes.OBJECT)`" |
| 141 | + ) |
| 142 | + |
| 143 | + |
| 144 | +class Object1Serializer(serializers.Serializer): |
| 145 | + type = serializers.CharField() |
| 146 | + field1 = serializers.IntegerField() |
| 147 | + |
| 148 | + def __init__(self, *args, **kwargs): |
| 149 | + super().__init__(*args, **kwargs) |
| 150 | + self.error_messages.update(object1_code="first error") |
| 151 | + |
| 152 | + |
| 153 | +class Object2Serializer(serializers.Serializer): |
| 154 | + type = serializers.CharField() |
| 155 | + field2 = serializers.DateField() |
| 156 | + |
| 157 | + def __init__(self, *args, **kwargs): |
| 158 | + super().__init__(*args, **kwargs) |
| 159 | + self.error_messages.update(object2_code="second error") |
| 160 | + |
| 161 | + |
| 162 | +class PolymorphicView(GenericAPIView): |
| 163 | + # ensure that 400 is not added due to the parser classes by using a parser |
| 164 | + # that does not raise a ParseError which results in adding a 400 error response |
| 165 | + parser_classes = [CustomParser] |
| 166 | + |
| 167 | + @extend_schema( |
| 168 | + request=PolymorphicProxySerializer( |
| 169 | + component_name="Object", |
| 170 | + serializers={"object1": Object1Serializer, "object2": Object2Serializer}, |
| 171 | + resource_type_field_name="type", |
| 172 | + ), |
| 173 | + responses={204: None}, |
| 174 | + ) |
| 175 | + def post(self, request, *args, **kwargs): |
| 176 | + return Response(status=204) |
| 177 | + |
| 178 | + @extend_schema( |
| 179 | + request=PolymorphicProxySerializer( |
| 180 | + component_name="AnotherObject", |
| 181 | + serializers=[Object1Serializer, Object2Serializer], |
| 182 | + resource_type_field_name="type", |
| 183 | + ), |
| 184 | + responses={204: None}, |
| 185 | + ) |
| 186 | + def patch(self, request, *args, **kwargs): |
| 187 | + return Response(status=204) |
| 188 | + |
| 189 | + |
| 190 | +def test_error_codes_for_polymorphic_serializer(): |
| 191 | + """ |
| 192 | + For polymorphic serializers, the fields from the actual serializers are combined. |
| 193 | + Also, when the same field exists in multiple serializers in a polymorphic serializer, |
| 194 | + their error codes should be combined. |
| 195 | +
|
| 196 | + This test checks that fields from both serializers are present. It also checks that |
| 197 | + the error codes of non_field_errors from both serializers are combined. |
| 198 | + """ |
| 199 | + route = "validate/" |
| 200 | + view = PolymorphicView.as_view() |
| 201 | + schema = generate_view_schema(route, view) |
| 202 | + |
| 203 | + mapping = schema["components"]["schemas"]["ValidateCreateError"]["discriminator"][ |
| 204 | + "mapping" |
| 205 | + ] |
| 206 | + assert set(mapping) == {"non_field_errors", "type", "field1", "field2"} |
| 207 | + |
| 208 | + create_error_codes = schema["components"]["schemas"][ |
| 209 | + "ValidateCreateNonFieldErrorsErrorComponent" |
| 210 | + ]["properties"]["code"]["enum"] |
| 211 | + assert set(create_error_codes) == {"invalid", "object1_code", "object2_code"} |
| 212 | + |
| 213 | + patch_error_codes = schema["components"]["schemas"][ |
| 214 | + "ValidatePartialUpdateNonFieldErrorsErrorComponent" |
| 215 | + ]["properties"]["code"]["enum"] |
| 216 | + assert set(patch_error_codes) == {"invalid", "object1_code", "object2_code"} |
| 217 | + |
| 218 | + |
121 | 219 | class CustomFilterSet(FilterSet): |
122 | 220 | first_name = CharFilter() |
123 | 221 |
|
|
0 commit comments