diff --git a/rest_framework_swagger/docgenerator.py b/rest_framework_swagger/docgenerator.py index f83174b0..106b4215 100644 --- a/rest_framework_swagger/docgenerator.py +++ b/rest_framework_swagger/docgenerator.py @@ -28,6 +28,9 @@ class DocumentationGenerator(object): # Response classes defined in docstrings explicit_response_types = dict() + # Serializers referenced with $ref + ref_serializers = set() + def __init__(self, for_user=None): # unauthenticated user is expected to be in the form 'module.submodule.Class' if a value is present @@ -90,6 +93,9 @@ def get_operations(self, api, apis=None): response_type = self._get_method_response_type( doc_parser, serializer, introspector, method_introspector) + if response_type != 'object': + self.get_ref(response_type) + operation = { 'method': method_introspector.get_http_method(), 'summary': method_introspector.get_summary(), @@ -107,6 +113,9 @@ def get_operations(self, api, apis=None): inspector=method_introspector) operation['parameters'] = parameters or [] + for param in operation['parameters']: + if param['type'] not in BaseMethodIntrospector.PRIMITIVES: + self.get_ref(param['type']) if response_messages: operation['responseMessages'] = response_messages @@ -123,7 +132,7 @@ def get_operations(self, api, apis=None): # array response if method_introspector.is_array_response: operation['items'] = { - '$ref': operation['type'] + '$ref': self.get_ref(operation['type']) } operation['type'] = 'array' @@ -186,10 +195,21 @@ def get_models(self, apis): # 'properties': data['fields'], # } - models.update(self.explicit_response_types) models.update(self.fields_serializers) + + # Remove unused serializers + for name in list(models): + if name not in self.ref_serializers: + del models[name] + + models.update(self.explicit_response_types) + return models + def get_ref(self, serializer): + self.ref_serializers.add(serializer) + return serializer + def _get_method_serializer(self, method_inspector): """ Returns serializer used in method. @@ -395,7 +415,9 @@ def _get_serializer_fields(self, serializer): if getattr(field, 'write_only', False): field_serializer = "Write{}".format(field_serializer) - f['type'] = field_serializer + if not has_many: + f['$ref'] = self.get_ref(field_serializer) + del f['type'] else: field_serializer = None data_type = 'string' @@ -403,7 +425,7 @@ def _get_serializer_fields(self, serializer): if has_many: f['type'] = 'array' if field_serializer: - f['items'] = {'$ref': field_serializer} + f['items'] = {'$ref': self.get_ref(field_serializer)} elif data_type in BaseMethodIntrospector.PRIMITIVES: f['items'] = {'type': data_type} diff --git a/rest_framework_swagger/tests.py b/rest_framework_swagger/tests.py index d58bbe76..5fd6b3d3 100644 --- a/rest_framework_swagger/tests.py +++ b/rest_framework_swagger/tests.py @@ -531,6 +531,7 @@ class SerializedAPI(ListCreateAPIView): apis = urlparser.get_apis(url_patterns) docgen = self.get_documentation_generator() + docgen.generate(apis) models = docgen.get_models(apis) self.assertIn('CommentSerializer', models) @@ -651,7 +652,9 @@ class OtherSerializer(serializers.Serializer): fields = docgen._get_serializer_fields(OtherSerializer) self.assertEqual(1, len(fields['fields'])) - self.assertEqual("SomeSerializer", fields['fields']['thing2']['type']) + self.assertIn("$ref", fields['fields']['thing2']) + self.assertNotIn("type", fields['fields']['thing2']) + self.assertEqual("SomeSerializer", fields['fields']['thing2']['$ref']) def test_get_serializer_fields_api_with_nested_many(self): class SomeSerializer(serializers.Serializer): @@ -1357,6 +1360,14 @@ class HiddenSerializer(serializers.Serializer): hidden = serializers.HiddenField(default=42) class SerializedAPI(ListCreateAPIView): + """ + --- + POST: + parameters: + - name: HiddenSerializer + type: WriteHiddenSerializer + paramType: body + """ serializer_class = HiddenSerializer class_introspector = self.make_introspector2(SerializedAPI) @@ -1369,6 +1380,7 @@ class SerializedAPI(ListCreateAPIView): urlparser = UrlParser() generator = self.get_documentation_generator() apis = urlparser.get_apis(url_patterns) + generator.generate(apis) models = generator.get_models(apis) self.assertIn("HiddenSerializer", models) properties = models["HiddenSerializer"]['properties'] @@ -1398,6 +1410,7 @@ class SerializedAPI(ListCreateAPIView): urlparser = UrlParser() generator = self.get_documentation_generator() apis = urlparser.get_apis(url_patterns) + generator.generate(apis) models = generator.get_models(apis) self.assertIn("KitchenSinkSerializer", models) properties = models["KitchenSinkSerializer"]['properties'] @@ -2042,9 +2055,10 @@ def post(self, request, *args, **kwargs): url_patterns = patterns('', url(r'my-api/', SerializedAPI.as_view())) urlparser = UrlParser() apis = urlparser.get_apis(url_patterns) + generator.generate(apis) models = generator.get_models(apis) self.assertIn('SerializedAPIPostResponse', models) - self.assertIn('WriteCommentSerializer', models) + self.assertNotIn('WriteCommentSerializer', models) self.assertIn('CommentSerializer', models) self.assertNotIn('QuerySerializer', models) self.assertNotIn('WriteQuerySerializer', models)