2323from vulnerabilities .models import VulnerabilityReference
2424from vulnerabilities .models import VulnerabilitySeverity
2525from vulnerabilities .models import get_purl_query_lookups
26- from vulnerabilities .throttling import AliasesAPIThrottle
27- from vulnerabilities .throttling import BulkSearchCPEAPIThrottle
28- from vulnerabilities .throttling import BulkSearchPackagesAPIThrottle
29- from vulnerabilities .throttling import CPEAPIThrottle
30- from vulnerabilities .throttling import PackagesAPIThrottle
31- from vulnerabilities .throttling import VulnerabilitiesAPIThrottle
32- from vulnerabilities .throttling import VulnerablePackagesAPIThrottle
26+ from vulnerabilities .throttling import StaffUserRateThrottle
3327
3428
3529class VulnerabilitySeveritySerializer (serializers .ModelSerializer ):
@@ -227,18 +221,11 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
227221 serializer_class = PackageSerializer
228222 filter_backends = (filters .DjangoFilterBackend ,)
229223 filterset_class = PackageFilterSet
230-
231- def get_throttles (self ):
232- if self .action == "bulk_search" :
233- throttle_classes = [BulkSearchPackagesAPIThrottle ]
234- elif self .action == "all" :
235- throttle_classes = [VulnerablePackagesAPIThrottle ]
236- else :
237- throttle_classes = [PackagesAPIThrottle ]
238- return [throttle () for throttle in throttle_classes ]
224+ throttle_classes = [StaffUserRateThrottle ]
225+ throttle_scope = "packages"
239226
240227 # TODO: Fix the swagger documentation for this endpoint
241- @action (detail = False , methods = ["post" ])
228+ @action (detail = False , methods = ["post" ], throttle_scope = "bulk_search_packages" )
242229 def bulk_search (self , request ):
243230 """
244231 See https://github.com/nexB/vulnerablecode/pull/369#issuecomment-796877606 for docs
@@ -270,7 +257,7 @@ def bulk_search(self, request):
270257
271258 return Response (response )
272259
273- @action (detail = False , methods = ["get" ])
260+ @action (detail = False , methods = ["get" ], throttle_scope = "vulnerable_packages" )
274261 def all (self , request ):
275262 """
276263 Return all the vulnerable Package URLs.
@@ -318,7 +305,8 @@ def get_queryset(self):
318305 serializer_class = VulnerabilitySerializer
319306 filter_backends = (filters .DjangoFilterBackend ,)
320307 filterset_class = VulnerabilityFilterSet
321- throttle_classes = [VulnerabilitiesAPIThrottle ]
308+ throttle_classes = [StaffUserRateThrottle ]
309+ throttle_scope = "vulnerabilities"
322310
323311
324312class CPEFilterSet (filters .FilterSet ):
@@ -335,16 +323,11 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
335323 ).distinct ()
336324 serializer_class = VulnerabilitySerializer
337325 filter_backends = (filters .DjangoFilterBackend ,)
326+ throttle_classes = [StaffUserRateThrottle ]
338327 filterset_class = CPEFilterSet
328+ throttle_scope = "cpes"
339329
340- def get_throttles (self ):
341- if self .action == "bulk_search" :
342- throttle_classes = [BulkSearchCPEAPIThrottle ]
343- else :
344- throttle_classes = [CPEAPIThrottle ]
345- return [throttle () for throttle in throttle_classes ]
346-
347- @action (detail = False , methods = ["post" ])
330+ @action (detail = False , methods = ["post" ], throttle_scope = "bulk_search_cpes" )
348331 def bulk_search (self , request ):
349332 """
350333 This endpoint is used to search for vulnerabilities by more than one CPE.
@@ -381,4 +364,5 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
381364 serializer_class = VulnerabilitySerializer
382365 filter_backends = (filters .DjangoFilterBackend ,)
383366 filterset_class = AliasFilterSet
384- throttle_classes = [AliasesAPIThrottle ]
367+ throttle_classes = [StaffUserRateThrottle ]
368+ throttle_scope = "aliases"
0 commit comments