Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def get_components(self, path, method):
Return components with their properties from the serializer.
"""

components = {}

if method.lower() == 'delete':
return {}

Expand All @@ -197,10 +199,28 @@ def get_components(self, path, method):
if not isinstance(serializer, serializers.Serializer):
return {}

component_name = self.get_component_name(serializer)
item_component_name = self.get_component_name(serializer)
item_schema = self.map_serializer(serializer)
components[item_component_name] = item_schema

response_component_name = self._get_response_component_name(
self.get_operation_id(path, method)
)

if is_list_view(path, method, self.view):
response_component_schema = {
'type': 'array',
'items': self._get_serializer_reference(serializer),
}
paginator = self.get_paginator()
if paginator:
response_component_schema = paginator.get_paginated_response_schema(response_component_schema)
else:
response_component_schema = self._get_serializer_reference(serializer)

components[response_component_name] = response_component_schema

content = self.map_serializer(serializer)
return {component_name: content}
return components

def _to_camel_case(self, snake_str):
components = snake_str.split('_')
Expand Down Expand Up @@ -613,9 +633,17 @@ def get_serializer(self, path, method):
.format(view.__class__.__name__, method, path))
return None

def _get_reference(self, serializer):
def _get_serializer_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}

@staticmethod
def _get_response_component_name(operation_id):
operation_id = operation_id[0].upper() + operation_id[1:]
return operation_id + 'Response'

def _get_response_reference(self, operation_id):
return {'$ref': '#/components/schemas/{0}'.format(self._get_response_component_name(operation_id))}

def get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
return {}
Expand All @@ -627,7 +655,7 @@ def get_request_body(self, path, method):
if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self._get_reference(serializer)
item_schema = self._get_serializer_reference(serializer)

return {
'content': {
Expand All @@ -649,20 +677,17 @@ def get_responses(self, path, method):
serializer = self.get_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
item_schema = {}
else:
item_schema = self._get_reference(serializer)

if is_list_view(path, method, self.view):
response_schema = {
'type': 'array',
'items': item_schema,
}
paginator = self.get_paginator()
if paginator:
response_schema = paginator.get_paginated_response_schema(response_schema)
if is_list_view(path, method, self.view):
response_schema = {
'type': 'array',
'items': {}
}
else:
response_schema = {}
else:
response_schema = item_schema
response_schema = self._get_response_reference(
self.get_operation_id(path, method)
)
status_code = '201' if method == 'POST' else '200'
return {
status_code: {
Expand Down
38 changes: 24 additions & 14 deletions tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ class View(generics.GenericAPIView):
inspector.view = view

responses = inspector.get_responses(path, method)
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
assert responses['201']['content']['application/json']['schema']['$ref'] == \
'#/components/schemas/CreateItemResponse'

components = inspector.get_components(path, method)
assert sorted(components['Item']['required']) == ['text', 'write_only']
Expand Down Expand Up @@ -338,7 +339,7 @@ class View(generics.GenericAPIView):
inspector.view = view

responses = inspector.get_responses(path, method)
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item'
assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/CreateItemResponse'
components = inspector.get_components(path, method)
assert components['Item']

Expand Down Expand Up @@ -375,17 +376,20 @@ class View(generics.GenericAPIView):
'content': {
'application/json': {
'schema': {
'type': 'array',
'items': {
'$ref': '#/components/schemas/Item'
},
'$ref': '#/components/schemas/ListItemsResponse'
},
},
},
},
}
components = inspector.get_components(path, method)
assert components == {
'ListItemsResponse': {
'type': 'array',
'items': {
'$ref': '#/components/schemas/Item',
},
},
'Item': {
'type': 'object',
'properties': {
Expand Down Expand Up @@ -431,20 +435,23 @@ class View(generics.GenericAPIView):
'content': {
'application/json': {
'schema': {
'type': 'object',
'item': {
'type': 'array',
'items': {
'$ref': '#/components/schemas/Item'
},
},
'$ref': '#/components/schemas/ListItemsResponse'
},
},
},
},
}
components = inspector.get_components(path, method)
assert components == {
'ListItemsResponse': {
'type': 'object',
'item': {
'type': 'array',
'items': {
'$ref': '#/components/schemas/Item',
},
},
},
'Item': {
'type': 'object',
'properties': {
Expand Down Expand Up @@ -601,7 +608,7 @@ class View(generics.GenericAPIView):
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/Item'
'$ref': '#/components/schemas/RetrieveItemResponse'
},
},
},
Expand All @@ -610,6 +617,9 @@ class View(generics.GenericAPIView):

components = inspector.get_components(path, method)
assert components == {
'RetrieveItemResponse': {
'$ref': '#/components/schemas/Item'
},
'Item': {
'type': 'object',
'properties': {
Expand Down