Skip to content

Commit e9c280d

Browse files
authored
Fix swagger API docs generation (#1366)
* Fix swagger API doc for `api/packages/bulk_search/` - Use `extend_schema` to override the request body, which was not being properly discovered. - drf-spectacular relies on `get_serializer_class()` and `get_serializer()`, and it works well for view functions that purely deal with ModelSerializer. For anything else, it gets a bit murky, and it is advised to provide proper overrides in the `extend_schema` decorator. - Override erroneous pagination and filter backend caused due to response containing multiple serializer object https://drf-spectacular.readthedocs.io/en/latest/faq.html#my-action-is-erroneously-paginated-or-has-filter-parameters-that-i-do-not-want Signed-off-by: Keshav Priyadarshi <[email protected]> * Add tests for `bulk_search` Signed-off-by: Keshav Priyadarshi <[email protected]> * Test `bulk_search` with empty request body Signed-off-by: Keshav Priyadarshi <[email protected]> * Fix API doc generation for `api/packages/lookup` Signed-off-by: Keshav Priyadarshi <[email protected]> * Fix API doc generation for `api/packages/bulk_lookup` Signed-off-by: Keshav Priyadarshi <[email protected]> --------- Signed-off-by: Keshav Priyadarshi <[email protected]>
1 parent 939055a commit e9c280d

File tree

2 files changed

+158
-20
lines changed

2 files changed

+158
-20
lines changed

vulnerabilities/api.py

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111

1212
from django.db.models import Prefetch
1313
from django_filters import rest_framework as filters
14+
from drf_spectacular.utils import extend_schema
15+
from drf_spectacular.utils import inline_serializer
1416
from packageurl import PackageURL
1517
from rest_framework import serializers
18+
from rest_framework import status
1619
from rest_framework import viewsets
1720
from rest_framework.decorators import action
1821
from rest_framework.response import Response
@@ -272,6 +275,26 @@ def filter_purl(self, queryset, name, value):
272275
return self.queryset.filter(**lookups)
273276

274277

278+
class PackageurlListSerializer(serializers.Serializer):
279+
purls = serializers.ListField(
280+
child=serializers.CharField(),
281+
allow_empty=False,
282+
help_text="List of PackageURL strings in canonical form.",
283+
)
284+
285+
286+
class PackageBulkSearchRequestSerializer(PackageurlListSerializer):
287+
purl_only = serializers.BooleanField(required=False, default=False)
288+
plain_purl = serializers.BooleanField(required=False, default=False)
289+
290+
291+
class LookupRequestSerializer(serializers.Serializer):
292+
purl = serializers.CharField(
293+
required=True,
294+
help_text="PackageURL strings in canonical form.",
295+
)
296+
297+
275298
class PackageViewSet(viewsets.ReadOnlyModelViewSet):
276299
"""
277300
Lookup for vulnerable packages by Package URL.
@@ -283,21 +306,34 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
283306
filterset_class = PackageFilterSet
284307
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
285308

286-
# TODO: Fix the swagger documentation for this endpoint
287-
@action(detail=False, methods=["post"])
309+
@extend_schema(
310+
request=PackageBulkSearchRequestSerializer,
311+
responses={200: PackageSerializer(many=True)},
312+
)
313+
@action(
314+
detail=False,
315+
methods=["post"],
316+
serializer_class=PackageBulkSearchRequestSerializer,
317+
filter_backends=[],
318+
pagination_class=None,
319+
)
288320
def bulk_search(self, request):
289321
"""
290322
Lookup for vulnerable packages using many Package URLs at once.
291323
"""
292-
293-
purls = request.data.get("purls", []) or []
294-
purl_only = request.data.get("purl_only", False)
295-
plain_purl = request.data.get("plain_purl", False)
296-
if not purls or not isinstance(purls, list):
324+
serializer = self.serializer_class(data=request.data)
325+
if not serializer.is_valid():
297326
return Response(
298-
status=400,
299-
data={"Error": "A non-empty 'purls' list of PURLs is required."},
327+
status=status.HTTP_400_BAD_REQUEST,
328+
data={
329+
"error": serializer.errors,
330+
"message": "A non-empty 'purls' list of PURLs is required.",
331+
},
300332
)
333+
validated_data = serializer.validated_data
334+
purls = validated_data.get("purls")
335+
purl_only = validated_data.get("purl_only", False)
336+
plain_purl = validated_data.get("plain_purl", False)
301337

302338
if plain_purl:
303339
purl_objects = [PackageURL.from_string(purl) for purl in purls]
@@ -347,34 +383,66 @@ def all(self, request):
347383
vulnerable_purls = [str(package.package_url) for package in vulnerable_packages]
348384
return Response(vulnerable_purls)
349385

350-
@action(detail=False, methods=["post"])
386+
@extend_schema(
387+
request=LookupRequestSerializer,
388+
responses={200: PackageSerializer(many=True)},
389+
)
390+
@action(
391+
detail=False,
392+
methods=["post"],
393+
serializer_class=LookupRequestSerializer,
394+
filter_backends=[],
395+
pagination_class=None,
396+
)
351397
def lookup(self, request):
352398
"""
353399
Return the response for exact PackageURL requested for.
354400
"""
355-
purl = request.data.get("purl")
356-
if not purl:
401+
serializer = self.serializer_class(data=request.data)
402+
if not serializer.is_valid():
357403
return Response(
358-
status=400,
359-
data={"Error": "A 'purl' is required."},
404+
status=status.HTTP_400_BAD_REQUEST,
405+
data={
406+
"error": serializer.errors,
407+
"message": "A 'purl' is required.",
408+
},
360409
)
410+
validated_data = serializer.validated_data
411+
purl = validated_data.get("purl")
412+
361413
return Response(
362414
PackageSerializer(
363415
Package.objects.for_purls([purl]), many=True, context={"request": request}
364416
).data
365417
)
366418

367-
@action(detail=False, methods=["post"])
419+
@extend_schema(
420+
request=PackageurlListSerializer,
421+
responses={200: PackageSerializer(many=True)},
422+
)
423+
@action(
424+
detail=False,
425+
methods=["post"],
426+
serializer_class=PackageurlListSerializer,
427+
filter_backends=[],
428+
pagination_class=None,
429+
)
368430
def bulk_lookup(self, request):
369431
"""
370432
Return the response for exact PackageURLs requested for.
371433
"""
372-
purls = request.data.get("purls") or []
373-
if not purls:
434+
serializer = self.serializer_class(data=request.data)
435+
if not serializer.is_valid():
374436
return Response(
375-
status=400,
376-
data={"Error": "A non-empty 'purls' list of PURLs is required."},
437+
status=status.HTTP_400_BAD_REQUEST,
438+
data={
439+
"error": serializer.errors,
440+
"message": "A non-empty 'purls' list of PURLs is required.",
441+
},
377442
)
443+
validated_data = serializer.validated_data
444+
purls = validated_data.get("purls")
445+
378446
return Response(
379447
PackageSerializer(
380448
Package.objects.for_purls(purls),

vulnerabilities/tests/test_api.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,55 @@ def test_bulk_api_with_purl_only_option(self):
647647
assert len(response) == 1
648648
assert response[0] == "pkg:nginx/[email protected]"
649649

650+
def test_bulk_api_without_purls_list(self):
651+
request_body = {
652+
"purls": None,
653+
}
654+
response = self.csrf_client.post(
655+
"/api/packages/bulk_search",
656+
data=json.dumps(request_body),
657+
content_type="application/json",
658+
).json()
659+
660+
expected = {
661+
"error": {"purls": ["This field may not be null."]},
662+
"message": "A non-empty 'purls' list of PURLs is required.",
663+
}
664+
665+
self.assertEqual(response, expected)
666+
667+
def test_bulk_api_without_purls_empty_list(self):
668+
request_body = {
669+
"purls": [],
670+
}
671+
response = self.csrf_client.post(
672+
"/api/packages/bulk_search",
673+
data=json.dumps(request_body),
674+
content_type="application/json",
675+
).json()
676+
677+
expected = {
678+
"error": {"purls": ["This list may not be empty."]},
679+
"message": "A non-empty 'purls' list of PURLs is required.",
680+
}
681+
682+
self.assertEqual(response, expected)
683+
684+
def test_bulk_api_with_empty_request_body(self):
685+
request_body = {}
686+
response = self.csrf_client.post(
687+
"/api/packages/bulk_search",
688+
data=json.dumps(request_body),
689+
content_type="application/json",
690+
).json()
691+
692+
expected = {
693+
"error": {"purls": ["This field is required."]},
694+
"message": "A non-empty 'purls' list of PURLs is required.",
695+
}
696+
697+
self.assertEqual(response, expected)
698+
650699

651700
class BulkSearchAPICPE(TestCase):
652701
def setUp(self):
@@ -768,7 +817,13 @@ def test_lookup_endpoint_failure(self):
768817
data=json.dumps(request_body),
769818
content_type="application/json",
770819
).json()
771-
assert response == {"Error": "A 'purl' is required."}
820+
821+
expected = {
822+
"error": {"purl": ["This field may not be null."]},
823+
"message": "A 'purl' is required.",
824+
}
825+
826+
self.assertEqual(response, expected)
772827

773828
def test_lookup_endpoint(self):
774829
request_body = {"purl": "pkg:pypi/microweber/[email protected]"}
@@ -844,3 +899,18 @@ def test_bulk_lookup_endpoint(self):
844899
content_type="application/json",
845900
).json()
846901
assert len(response) == 1
902+
903+
def test_bulk_lookup_endpoint_failure(self):
904+
request_body = {"purls": None}
905+
response = self.csrf_client.post(
906+
"/api/packages/bulk_lookup",
907+
data=json.dumps(request_body),
908+
content_type="application/json",
909+
).json()
910+
911+
expected = {
912+
"error": {"purls": ["This field may not be null."]},
913+
"message": "A non-empty 'purls' list of PURLs is required.",
914+
}
915+
916+
self.assertEqual(response, expected)

0 commit comments

Comments
 (0)