Skip to content

Commit b2497fc

Browse files
committed
Convert openapi.AutoSchema methods to public API.
1 parent d45e000 commit b2497fc

File tree

2 files changed

+148
-51
lines changed

2 files changed

+148
-51
lines changed

rest_framework/schemas/openapi.py

Lines changed: 133 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from django.db import models
1313
from django.utils.encoding import force_str
1414

15-
from rest_framework import exceptions, renderers, serializers
15+
from rest_framework import (
16+
RemovedInDRF314Warning, exceptions, renderers, serializers
17+
)
1618
from rest_framework.compat import uritemplate
1719
from rest_framework.fields import _UnvalidatedField, empty
1820
from rest_framework.settings import api_settings
@@ -146,15 +148,15 @@ def get_operation(self, path, method):
146148
operation['description'] = self.get_description(path, method)
147149

148150
parameters = []
149-
parameters += self._get_path_parameters(path, method)
150-
parameters += self._get_pagination_parameters(path, method)
151-
parameters += self._get_filter_parameters(path, method)
151+
parameters += self.get_path_parameters(path, method)
152+
parameters += self.get_pagination_parameters(path, method)
153+
parameters += self.get_filter_parameters(path, method)
152154
operation['parameters'] = parameters
153155

154-
request_body = self._get_request_body(path, method)
156+
request_body = self.get_request_body(path, method)
155157
if request_body:
156158
operation['requestBody'] = request_body
157-
operation['responses'] = self._get_responses(path, method)
159+
operation['responses'] = self.get_responses(path, method)
158160
operation['tags'] = self.get_tags(path, method)
159161

160162
return operation
@@ -190,14 +192,14 @@ def get_components(self, path, method):
190192
if method.lower() == 'delete':
191193
return {}
192194

193-
serializer = self._get_serializer(path, method)
195+
serializer = self.get_serializer(path, method)
194196

195197
if not isinstance(serializer, serializers.Serializer):
196198
return {}
197199

198200
component_name = self.get_component_name(serializer)
199201

200-
content = self._map_serializer(serializer)
202+
content = self.map_serializer(serializer)
201203
return {component_name: content}
202204

203205
def _to_camel_case(self, snake_str):
@@ -220,8 +222,8 @@ def get_operation_id_base(self, path, method, action):
220222
name = model.__name__
221223

222224
# Try with the serializer class name
223-
elif self._get_serializer(path, method) is not None:
224-
name = self._get_serializer(path, method).__class__.__name__
225+
elif self.get_serializer(path, method) is not None:
226+
name = self.get_serializer(path, method).__class__.__name__
225227
if name.endswith('Serializer'):
226228
name = name[:-10]
227229

@@ -259,7 +261,7 @@ def get_operation_id(self, path, method):
259261

260262
return action + name
261263

262-
def _get_path_parameters(self, path, method):
264+
def get_path_parameters(self, path, method):
263265
"""
264266
Return a list of parameters from templated path variables.
265267
"""
@@ -295,15 +297,15 @@ def _get_path_parameters(self, path, method):
295297

296298
return parameters
297299

298-
def _get_filter_parameters(self, path, method):
299-
if not self._allows_filters(path, method):
300+
def get_filter_parameters(self, path, method):
301+
if not self.allows_filters(path, method):
300302
return []
301303
parameters = []
302304
for filter_backend in self.view.filter_backends:
303305
parameters += filter_backend().get_schema_operation_parameters(self.view)
304306
return parameters
305307

306-
def _allows_filters(self, path, method):
308+
def allows_filters(self, path, method):
307309
"""
308310
Determine whether to include filter Fields in schema.
309311
@@ -316,19 +318,19 @@ def _allows_filters(self, path, method):
316318
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
317319
return method.lower() in ["get", "put", "patch", "delete"]
318320

319-
def _get_pagination_parameters(self, path, method):
321+
def get_pagination_parameters(self, path, method):
320322
view = self.view
321323

322324
if not is_list_view(path, method, view):
323325
return []
324326

325-
paginator = self._get_paginator()
327+
paginator = self.get_paginator()
326328
if not paginator:
327329
return []
328330

329331
return paginator.get_schema_operation_parameters(view)
330332

331-
def _map_choicefield(self, field):
333+
def map_choicefield(self, field):
332334
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates
333335
if all(isinstance(choice, bool) for choice in choices):
334336
type = 'boolean'
@@ -356,24 +358,24 @@ def _map_choicefield(self, field):
356358
mapping['type'] = type
357359
return mapping
358360

359-
def _map_field(self, field):
361+
def map_field(self, field):
360362

361363
# Nested Serializers, `many` or not.
362364
if isinstance(field, serializers.ListSerializer):
363365
return {
364366
'type': 'array',
365-
'items': self._map_serializer(field.child)
367+
'items': self.map_serializer(field.child)
366368
}
367369
if isinstance(field, serializers.Serializer):
368-
data = self._map_serializer(field)
370+
data = self.map_serializer(field)
369371
data['type'] = 'object'
370372
return data
371373

372374
# Related fields.
373375
if isinstance(field, serializers.ManyRelatedField):
374376
return {
375377
'type': 'array',
376-
'items': self._map_field(field.child_relation)
378+
'items': self.map_field(field.child_relation)
377379
}
378380
if isinstance(field, serializers.PrimaryKeyRelatedField):
379381
model = getattr(field.queryset, 'model', None)
@@ -389,11 +391,11 @@ def _map_field(self, field):
389391
if isinstance(field, serializers.MultipleChoiceField):
390392
return {
391393
'type': 'array',
392-
'items': self._map_choicefield(field)
394+
'items': self.map_choicefield(field)
393395
}
394396

395397
if isinstance(field, serializers.ChoiceField):
396-
return self._map_choicefield(field)
398+
return self.map_choicefield(field)
397399

398400
# ListField.
399401
if isinstance(field, serializers.ListField):
@@ -402,7 +404,7 @@ def _map_field(self, field):
402404
'items': {},
403405
}
404406
if not isinstance(field.child, _UnvalidatedField):
405-
mapping['items'] = self._map_field(field.child)
407+
mapping['items'] = self.map_field(field.child)
406408
return mapping
407409

408410
# DateField and DateTimeField type is string
@@ -504,7 +506,7 @@ def _map_min_max(self, field, content):
504506
if field.min_value:
505507
content['minimum'] = field.min_value
506508

507-
def _map_serializer(self, serializer):
509+
def map_serializer(self, serializer):
508510
# Assuming we have a valid serializer instance.
509511
required = []
510512
properties = {}
@@ -516,7 +518,7 @@ def _map_serializer(self, serializer):
516518
if field.required:
517519
required.append(field.field_name)
518520

519-
schema = self._map_field(field)
521+
schema = self.map_field(field)
520522
if field.read_only:
521523
schema['readOnly'] = True
522524
if field.write_only:
@@ -527,7 +529,7 @@ def _map_serializer(self, serializer):
527529
schema['default'] = field.default
528530
if field.help_text:
529531
schema['description'] = str(field.help_text)
530-
self._map_field_validators(field, schema)
532+
self.map_field_validators(field, schema)
531533

532534
properties[field.field_name] = schema
533535

@@ -540,7 +542,7 @@ def _map_serializer(self, serializer):
540542

541543
return result
542544

543-
def _map_field_validators(self, field, schema):
545+
def map_field_validators(self, field, schema):
544546
"""
545547
map field validators
546548
"""
@@ -578,7 +580,7 @@ def _map_field_validators(self, field, schema):
578580
schema['maximum'] = int(digits * '9') + 1
579581
schema['minimum'] = -schema['maximum']
580582

581-
def _get_paginator(self):
583+
def get_paginator(self):
582584
pagination_class = getattr(self.view, 'pagination_class', None)
583585
if pagination_class:
584586
return pagination_class()
@@ -596,7 +598,7 @@ def map_renderers(self, path, method):
596598
media_types.append(renderer.media_type)
597599
return media_types
598600

599-
def _get_serializer(self, path, method):
601+
def get_serializer(self, path, method):
600602
view = self.view
601603

602604
if not hasattr(view, 'get_serializer'):
@@ -614,13 +616,13 @@ def _get_serializer(self, path, method):
614616
def _get_reference(self, serializer):
615617
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
616618

617-
def _get_request_body(self, path, method):
619+
def get_request_body(self, path, method):
618620
if method not in ('PUT', 'PATCH', 'POST'):
619621
return {}
620622

621623
self.request_media_types = self.map_parsers(path, method)
622624

623-
serializer = self._get_serializer(path, method)
625+
serializer = self.get_serializer(path, method)
624626

625627
if not isinstance(serializer, serializers.Serializer):
626628
item_schema = {}
@@ -634,8 +636,7 @@ def _get_request_body(self, path, method):
634636
}
635637
}
636638

637-
def _get_responses(self, path, method):
638-
# TODO: Handle multiple codes and pagination classes.
639+
def get_responses(self, path, method):
639640
if method == 'DELETE':
640641
return {
641642
'204': {
@@ -645,7 +646,7 @@ def _get_responses(self, path, method):
645646

646647
self.response_media_types = self.map_renderers(path, method)
647648

648-
serializer = self._get_serializer(path, method)
649+
serializer = self.get_serializer(path, method)
649650

650651
if not isinstance(serializer, serializers.Serializer):
651652
item_schema = {}
@@ -657,7 +658,7 @@ def _get_responses(self, path, method):
657658
'type': 'array',
658659
'items': item_schema,
659660
}
660-
paginator = self._get_paginator()
661+
paginator = self.get_paginator()
661662
if paginator:
662663
response_schema = paginator.get_paginated_response_schema(response_schema)
663664
else:
@@ -688,3 +689,99 @@ def get_tags(self, path, method):
688689
path = path[1:]
689690

690691
return [path.split('/')[0].replace('_', '-')]
692+
693+
def _get_path_parameters(self, path, method):
694+
warnings.warn(
695+
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
696+
"The old name will be removed in DRF v3.14.",
697+
RemovedInDRF314Warning, stacklevel=2
698+
)
699+
return self.get_path_parameters(path, method)
700+
701+
def _get_filter_parameters(self, path, method):
702+
warnings.warn(
703+
"Method `_get_filter_parameters()` has been renamed to `get_filter_parameters()`. "
704+
"The old name will be removed in DRF v3.14.",
705+
RemovedInDRF314Warning, stacklevel=2
706+
)
707+
return self.get_filter_parameters(path, method)
708+
709+
def _get_responses(self, path, method):
710+
warnings.warn(
711+
"Method `_get_responses()` has been renamed to `get_responses()`. "
712+
"The old name will be removed in DRF v3.14.",
713+
RemovedInDRF314Warning, stacklevel=2
714+
)
715+
return self.get_responses(path, method)
716+
717+
def _get_request_body(self, path, method):
718+
warnings.warn(
719+
"Method `_get_request_body()` has been renamed to `get_request_body()`. "
720+
"The old name will be removed in DRF v3.14.",
721+
RemovedInDRF314Warning, stacklevel=2
722+
)
723+
return self.get_request_body(path, method)
724+
725+
def _get_serializer(self, path, method):
726+
warnings.warn(
727+
"Method `_get_serializer()` has been renamed to `get_serializer()`. "
728+
"The old name will be removed in DRF v3.14.",
729+
RemovedInDRF314Warning, stacklevel=2
730+
)
731+
return self.get_serializer(path, method)
732+
733+
def _get_paginator(self):
734+
warnings.warn(
735+
"Method `_get_paginator()` has been renamed to `get_paginator()`. "
736+
"The old name will be removed in DRF v3.14.",
737+
RemovedInDRF314Warning, stacklevel=2
738+
)
739+
return self.get_paginator()
740+
741+
def _map_field_validators(self, field, schema):
742+
warnings.warn(
743+
"Method `_map_field_validators()` has been renamed to `map_field_validators()`. "
744+
"The old name will be removed in DRF v3.14.",
745+
RemovedInDRF314Warning, stacklevel=2
746+
)
747+
return self.map_field_validators(field, schema)
748+
749+
def _map_serializer(self, serializer):
750+
warnings.warn(
751+
"Method `_map_serializer()` has been renamed to `map_serializer()`. "
752+
"The old name will be removed in DRF v3.14.",
753+
RemovedInDRF314Warning, stacklevel=2
754+
)
755+
return self.map_serializer(serializer)
756+
757+
def _map_field(self, field):
758+
warnings.warn(
759+
"Method `_map_field()` has been renamed to `map_field()`. "
760+
"The old name will be removed in DRF v3.14.",
761+
RemovedInDRF314Warning, stacklevel=2
762+
)
763+
return self.map_field(field)
764+
765+
def _map_choicefield(self, field):
766+
warnings.warn(
767+
"Method `_map_choicefield()` has been renamed to `map_choicefield()`. "
768+
"The old name will be removed in DRF v3.14.",
769+
RemovedInDRF314Warning, stacklevel=2
770+
)
771+
return self.map_choicefield(field)
772+
773+
def _get_pagination_parameters(self, path, method):
774+
warnings.warn(
775+
"Method `_get_pagination_parameters()` has been renamed to `get_pagination_parameters()`. "
776+
"The old name will be removed in DRF v3.14.",
777+
RemovedInDRF314Warning, stacklevel=2
778+
)
779+
return self.get_pagination_parameters(path, method)
780+
781+
def _allows_filters(self, path, method):
782+
warnings.warn(
783+
"Method `_allows_filters()` has been renamed to `allows_filters()`. "
784+
"The old name will be removed in DRF v3.14.",
785+
RemovedInDRF314Warning, stacklevel=2
786+
)
787+
return self.allows_filters(path, method)

0 commit comments

Comments
 (0)