Skip to content

Commit 2d0a10b

Browse files
committed
Add tests for v2 API
Signed-off-by: Tushar Goel <[email protected]>
1 parent e389ce5 commit 2d0a10b

File tree

3 files changed

+379
-53
lines changed

3 files changed

+379
-53
lines changed

vulnerabilities/api.py

Lines changed: 81 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,7 @@ class AliasViewSet(VulnerabilityViewSet):
692692

693693
filterset_class = AliasFilterSet
694694

695+
695696
class WeaknessV2Serializer(serializers.ModelSerializer):
696697
cwe_id = serializers.CharField()
697698
name = serializers.CharField()
@@ -701,16 +702,6 @@ class Meta:
701702
model = Weakness
702703
fields = ["cwe_id", "name", "description"]
703704

704-
class VulnerabilityFilter(filters.FilterSet):
705-
vulnerability_id = filters.CharFilter(field_name='vulnerability_id', lookup_expr='exact')
706-
vulnerability_id__in = filters.BaseInFilter(field_name='vulnerability_id', lookup_expr='in')
707-
alias = filters.CharFilter(field_name='aliases__alias', lookup_expr='exact')
708-
alias__in = filters.BaseInFilter(field_name='aliases__alias', lookup_expr='in')
709-
710-
class Meta:
711-
model = Vulnerability
712-
fields = ['vulnerability_id', 'vulnerability_id__in', 'alias', 'alias__in']
713-
714705

715706
class VulnerabilityReferenceV2Serializer(serializers.ModelSerializer):
716707
url = serializers.CharField()
@@ -721,10 +712,11 @@ class Meta:
721712
model = VulnerabilityReference
722713
fields = ["url", "reference_type", "reference_id"]
723714

715+
724716
class VulnerabilityV2Serializer(BaseResourceSerializer):
725717
aliases = serializers.SerializerMethodField()
726718
weaknesses = WeaknessV2Serializer(many=True)
727-
references = VulnerabilityReferenceV2Serializer(many=True, source='vulnerabilityreference_set')
719+
references = VulnerabilityReferenceV2Serializer(many=True, source="vulnerabilityreference_set")
728720
severities = VulnerabilitySeveritySerializer(many=True)
729721

730722
class Meta:
@@ -745,52 +737,74 @@ def get_severities(self, obj):
745737
return obj.severities
746738

747739

740+
class VulnerabilityListSerializer(serializers.ModelSerializer):
741+
url = serializers.SerializerMethodField()
742+
743+
class Meta:
744+
model = Vulnerability
745+
fields = ["vulnerability_id", "url"]
746+
747+
def get_url(self, obj):
748+
request = self.context.get("request")
749+
return reverse(
750+
"vulnerability-v2-detail",
751+
kwargs={"vulnerability_id": obj.vulnerability_id},
752+
request=request,
753+
)
754+
755+
748756
class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet):
749757
queryset = Vulnerability.objects.all()
750758
serializer_class = VulnerabilityV2Serializer
759+
lookup_field = "vulnerability_id"
760+
761+
def get_queryset(self):
762+
queryset = super().get_queryset()
763+
vulnerability_ids = self.request.query_params.getlist("vulnerability_id")
764+
aliases = self.request.query_params.getlist("alias")
765+
766+
if vulnerability_ids:
767+
queryset = queryset.filter(vulnerability_id__in=vulnerability_ids)
768+
769+
if aliases:
770+
queryset = queryset.filter(aliases__alias__in=aliases).distinct()
771+
772+
return queryset
773+
774+
def get_serializer_class(self):
775+
if self.action == "list":
776+
return VulnerabilityListSerializer
777+
return super().get_serializer_class()
751778

752779
def list(self, request, *args, **kwargs):
753780
queryset = self.get_queryset()
754-
# Apply pagination
781+
vulnerability_ids = request.query_params.getlist("vulnerability_id")
782+
783+
# If exactly one vulnerability_id is provided, return the serialized data
784+
if len(vulnerability_ids) == 1:
785+
try:
786+
vulnerability = queryset.get(vulnerability_id=vulnerability_ids[0])
787+
serializer = self.get_serializer(vulnerability)
788+
return Response(serializer.data)
789+
except Vulnerability.DoesNotExist:
790+
return Response({"detail": "Not found."}, status=404)
791+
792+
# Otherwise, return a dictionary of vulnerabilities keyed by vulnerability_id
755793
page = self.paginate_queryset(queryset)
756794
if page is not None:
757795
serializer = self.get_serializer(page, many=True)
758796
data = serializer.data
759-
vulnerabilities = {item['vulnerability_id']: item for item in data}
760-
# Use 'self.get_paginated_response' to include pagination data
761-
return self.get_paginated_response({'vulnerabilities': vulnerabilities})
797+
vulnerabilities = {item["vulnerability_id"]: item for item in data}
798+
return self.get_paginated_response({"vulnerabilities": vulnerabilities})
762799

763-
# If pagination is not applied
764800
serializer = self.get_serializer(queryset, many=True)
765801
data = serializer.data
766-
vulnerabilities = {item['vulnerability_id']: item for item in data}
767-
return Response({'vulnerabilities': vulnerabilities})
768-
769-
770-
class PackageFilter(filters.FilterSet):
771-
purl = filters.CharFilter(field_name='package_url', lookup_expr='exact')
772-
purl__in = filters.BaseInFilter(field_name='package_url', lookup_expr='in')
773-
affected_by_vulnerability = filters.CharFilter(
774-
field_name='affected_by_vulnerabilities__vulnerability_id',
775-
lookup_expr='exact'
776-
)
777-
fixing_vulnerability = filters.CharFilter(
778-
field_name='fixing_vulnerabilities__vulnerability_id',
779-
lookup_expr='exact'
780-
)
781-
782-
class Meta:
783-
model = Package
784-
fields = [
785-
'purl',
786-
'purl__in',
787-
'affected_by_vulnerability',
788-
'fixing_vulnerability',
789-
]
802+
vulnerabilities = {item["vulnerability_id"]: item for item in data}
803+
return Response({"vulnerabilities": vulnerabilities})
790804

791805

792806
class PackageV2Serializer(serializers.ModelSerializer):
793-
purl = serializers.CharField(source='package_url')
807+
purl = serializers.CharField(source="package_url")
794808
affected_by_vulnerabilities = serializers.SerializerMethodField()
795809
fixing_vulnerabilities = serializers.SerializerMethodField()
796810
next_non_vulnerable_version = serializers.CharField(read_only=True)
@@ -799,11 +813,11 @@ class PackageV2Serializer(serializers.ModelSerializer):
799813
class Meta:
800814
model = Package
801815
fields = [
802-
'purl',
803-
'affected_by_vulnerabilities',
804-
'fixing_vulnerabilities',
805-
'next_non_vulnerable_version',
806-
'latest_non_vulnerable_version',
816+
"purl",
817+
"affected_by_vulnerabilities",
818+
"fixing_vulnerabilities",
819+
"next_non_vulnerable_version",
820+
"latest_non_vulnerable_version",
807821
]
808822

809823
def get_affected_by_vulnerabilities(self, obj):
@@ -816,19 +830,36 @@ def get_fixing_vulnerabilities(self, obj):
816830
class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
817831
queryset = Package.objects.all()
818832
serializer_class = PackageV2Serializer
819-
filterset_class = PackageFilter
833+
834+
def get_queryset(self):
835+
queryset = super().get_queryset()
836+
package_purls = self.request.query_params.getlist("purl")
837+
affected_by_vulnerability = self.request.query_params.get("affected_by_vulnerability")
838+
fixing_vulnerability = self.request.query_params.get("fixing_vulnerability")
839+
840+
if package_purls:
841+
queryset = queryset.filter(package_url__in=package_purls)
842+
if affected_by_vulnerability:
843+
queryset = queryset.filter(
844+
affected_by_vulnerabilities__vulnerability_id=affected_by_vulnerability
845+
)
846+
if fixing_vulnerability:
847+
queryset = queryset.filter(
848+
fixing_vulnerabilities__vulnerability_id=fixing_vulnerability
849+
)
850+
return queryset.with_is_vulnerable()
820851

821852
def list(self, request, *args, **kwargs):
822-
queryset = self.get_queryset().with_is_vulnerable()
853+
queryset = self.get_queryset()
823854
# Apply pagination
824855
page = self.paginate_queryset(queryset)
825856
if page is not None:
826857
serializer = self.get_serializer(page, many=True)
827858
data = serializer.data
828859
# Use 'self.get_paginated_response' to include pagination data
829-
return self.get_paginated_response({'purls': data})
860+
return self.get_paginated_response({"purls": data})
830861

831862
# If pagination is not applied
832863
serializer = self.get_serializer(queryset, many=True)
833864
data = serializer.data
834-
return Response({'purls': data})
865+
return Response({"purls": data})

0 commit comments

Comments
 (0)