Skip to content

Commit 1920a81

Browse files
committed
Subclass SchemaGenerator instead of defining a new one
While the SchemaGenerator provided by DRF-YASG is not API-compatible with the one used by DRF, it does use a lot of the methods and even generates a DRF one internally. Given that this required quite a bit of extra code to be written to do the wrapping, it makes sense to instead just subclass it and get those additional benefits. This should make it eaiser to migrate from the CoreAPISchemaGenerator, which is what is currently being used, to the DRF-provided OpenAPISchemaGenerator for the internal serialization. That was not done here because there is still a lot of work that needs to be done in order to enable that.
1 parent 6a1166d commit 1920a81

File tree

3 files changed

+10
-67
lines changed

3 files changed

+10
-67
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@
189189
('py:class', 'rest_framework.serializers.Serializer'),
190190
('py:class', 'rest_framework.renderers.BaseRenderer'),
191191
('py:class', 'rest_framework.parsers.BaseParser'),
192+
('py:class', 'rest_framework.schemas.coreapi.SchemaGenerator'),
192193
('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'),
193194
('py:class', 'rest_framework.views.APIView'),
194195

docs/openapi.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ This section describes where information is sourced from when using the default
146146
The Django `FORCE_SCRIPT_NAME`_ setting can be used to override the `SCRIPT_NAME`_ or set it when it's
147147
missing from the environment.
148148

149-
#. the longest common path prefix of all the urls in your API - see :meth:`.determine_path_prefix`
149+
#. the longest common path prefix of all the urls in your API
150150

151151
* When using API versioning with ``NamespaceVersioning`` or ``URLPathVersioning``, versioned endpoints that do not
152152
match the version used to access the ``SchemaView`` will be excluded from the endpoint list - for example,

src/drf_yasg/generators.py

Lines changed: 8 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def unescape_path(self, path):
163163
return clean_path
164164

165165

166-
class OpenAPISchemaGenerator(object):
166+
class OpenAPISchemaGenerator(SchemaGenerator):
167167
"""
168168
This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
169169
Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
@@ -188,7 +188,8 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
188188
:param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
189189
if not given, the default urlconf is used
190190
"""
191-
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
191+
super(OpenAPISchemaGenerator, self).__init__(info.title, url, info.get('description', ''), patterns, urlconf)
192+
192193
self.info = info
193194
self.version = version
194195
self.consumes = []
@@ -204,10 +205,6 @@ def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
204205
if parsed_url.path:
205206
logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead" % url)
206207

207-
@property
208-
def url(self):
209-
return self._gen.url
210-
211208
def get_security_definitions(self):
212209
"""Get the security schemes for this API. This determines what is usable in security requirements,
213210
and helps clients configure their authorization credentials.
@@ -278,7 +275,7 @@ def create_view(self, callback, method, request=None):
278275
:type request: rest_framework.request.Request or None
279276
:return: the view instance
280277
"""
281-
view = self._gen.create_view(callback, method, request)
278+
view = super(OpenAPISchemaGenerator, self).create_view(callback, method, request)
282279
overrides = getattr(callback, '_swagger_auto_schema', None)
283280
if overrides is not None:
284281
# decorated function based view must have its decorator information passed on to the re-instantiated view
@@ -290,24 +287,6 @@ def create_view(self, callback, method, request=None):
290287
setattr(view, 'swagger_fake_view', True)
291288
return view
292289

293-
def coerce_path(self, path, view):
294-
"""Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
295-
external representation (i.e. "this is an identifier", not "this is a database primary key").
296-
297-
:param str path: the path
298-
:param rest_framework.views.APIView view: associated view
299-
:rtype: str
300-
"""
301-
if '{pk}' not in path:
302-
return path
303-
304-
model = getattr(get_queryset_from_view(view), 'model', None)
305-
if model:
306-
field_name = get_pk_name(model)
307-
else:
308-
field_name = 'id'
309-
return path.replace('{pk}', '{%s}' % field_name)
310-
311290
def get_endpoints(self, request):
312291
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
313292
@@ -316,55 +295,18 @@ def get_endpoints(self, request):
316295
:return: {path: (view_class, list[(http_method, view_instance)])
317296
:rtype: dict[str,(type,list[(str,rest_framework.views.APIView)])]
318297
"""
319-
enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
298+
enumerator = self.endpoint_enumerator_class(self.patterns, self.urlconf, request=request)
320299
endpoints = enumerator.get_api_endpoints()
321300

322301
view_paths = defaultdict(list)
323302
view_cls = {}
324303
for path, method, callback in endpoints:
325304
view = self.create_view(callback, method, request)
326-
path = self.coerce_path(path, view)
305+
path = self.coerce_path(path, method, view)
327306
view_paths[path].append((method, view))
328307
view_cls[path] = callback.cls
329308
return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
330309

331-
def get_operation_keys(self, subpath, method, view):
332-
"""Return a list of keys that should be used to group an operation within the specification. ::
333-
334-
/users/ ("users", "list"), ("users", "create")
335-
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
336-
/users/enabled/ ("users", "enabled") # custom viewset list action
337-
/users/{pk}/star/ ("users", "star") # custom viewset detail action
338-
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
339-
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update")
340-
341-
:param str subpath: path to the operation with any common prefix/base path removed
342-
:param str method: HTTP method
343-
:param view: the view associated with the operation
344-
:rtype: list[str]
345-
"""
346-
return self._gen.get_keys(subpath, method, view)
347-
348-
def determine_path_prefix(self, paths):
349-
"""
350-
Given a list of all paths, return the common prefix which should be
351-
discounted when generating a schema structure.
352-
353-
This will be the longest common string that does not include that last
354-
component of the URL, or the last component before a path parameter.
355-
356-
For example: ::
357-
358-
/api/v1/users/
359-
/api/v1/users/{pk}/
360-
361-
The path prefix is ``/api/v1/``.
362-
363-
:param list[str] paths: list of paths
364-
:rtype: str
365-
"""
366-
return self._gen.determine_path_prefix(paths)
367-
368310
def should_include_endpoint(self, path, method, view, public):
369311
"""Check if a given endpoint should be included in the resulting schema.
370312
@@ -375,7 +317,7 @@ def should_include_endpoint(self, path, method, view, public):
375317
:returns: true if the view should be excluded
376318
:rtype: bool
377319
"""
378-
return public or self._gen.has_view_permissions(path, method, view)
320+
return public or self.has_view_permissions(path, method, view)
379321

380322
def get_paths_object(self, paths):
381323
"""Construct the Swagger Paths object.
@@ -436,7 +378,7 @@ def get_operation(self, view, path, prefix, method, components, request):
436378
:param Request request: the request made against the schema view; can be None
437379
:rtype: openapi.Operation
438380
"""
439-
operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
381+
operation_keys = self.get_keys(path[len(prefix):], method, view)
440382
overrides = self.get_overrides(view, method)
441383

442384
# the inspector class can be specified, in decreasing order of priorty,

0 commit comments

Comments
 (0)