Skip to content

Commit 98df505

Browse files
authored
Merge branch 'main' into ngi-acknowledgements
2 parents de4372d + c42945c commit 98df505

File tree

5 files changed

+60
-183
lines changed

5 files changed

+60
-183
lines changed

vulnerabilities/api.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from rest_framework import viewsets
1717
from rest_framework.decorators import action
1818
from rest_framework.response import Response
19+
from rest_framework.throttling import AnonRateThrottle
20+
from rest_framework.throttling import UserRateThrottle
1921

2022
from vulnerabilities.models import Alias
2123
from vulnerabilities.models import Package
@@ -231,11 +233,10 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
231233
serializer_class = PackageSerializer
232234
filter_backends = (filters.DjangoFilterBackend,)
233235
filterset_class = PackageFilterSet
234-
throttle_classes = [StaffUserRateThrottle]
235-
throttle_scope = "packages"
236+
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
236237

237238
# TODO: Fix the swagger documentation for this endpoint
238-
@action(detail=False, methods=["post"], throttle_scope="bulk_search_packages")
239+
@action(detail=False, methods=["post"])
239240
def bulk_search(self, request):
240241
"""
241242
Lookup for vulnerable packages using many Package URLs at once.
@@ -289,7 +290,7 @@ def bulk_search(self, request):
289290
vulnerable_purls = [str(package.package_url) for package in vulnerable_purls]
290291
return Response(data=vulnerable_purls)
291292

292-
@action(detail=False, methods=["get"], throttle_scope="vulnerable_packages")
293+
@action(detail=False, methods=["get"])
293294
def all(self, request):
294295
"""
295296
Return the Package URLs of all packages known to be vulnerable.
@@ -341,8 +342,7 @@ def get_queryset(self):
341342
serializer_class = VulnerabilitySerializer
342343
filter_backends = (filters.DjangoFilterBackend,)
343344
filterset_class = VulnerabilityFilterSet
344-
throttle_classes = [StaffUserRateThrottle]
345-
throttle_scope = "vulnerabilities"
345+
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
346346

347347

348348
class CPEFilterSet(filters.FilterSet):
@@ -363,11 +363,10 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
363363
).distinct()
364364
serializer_class = VulnerabilitySerializer
365365
filter_backends = (filters.DjangoFilterBackend,)
366-
throttle_classes = [StaffUserRateThrottle]
366+
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
367367
filterset_class = CPEFilterSet
368-
throttle_scope = "cpes"
369368

370-
@action(detail=False, methods=["post"], throttle_scope="bulk_search_cpes")
369+
@action(detail=False, methods=["post"])
371370
def bulk_search(self, request):
372371
"""
373372
Lookup for vulnerabilities using many CPEs at once.
@@ -409,5 +408,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
409408
serializer_class = VulnerabilitySerializer
410409
filter_backends = (filters.DjangoFilterBackend,)
411410
filterset_class = AliasFilterSet
412-
throttle_classes = [StaffUserRateThrottle]
413-
throttle_scope = "aliases"
411+
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]

vulnerabilities/tests/test_auth.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

vulnerabilities/tests/test_throttling.py

Lines changed: 25 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ def setUp(self):
2929
self.staff_csrf_client = APIClient(enforce_csrf_checks=True)
3030
self.staff_csrf_client.credentials(HTTP_AUTHORIZATION=self.staff_auth)
3131

32-
def test_packages_endpoint_throttling(self):
32+
self.csrf_client_anon = APIClient(enforce_csrf_checks=True)
33+
self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True)
3334

34-
# A basic user can only access /packages endpoint 10 times a day
35-
for i in range(0, 10):
35+
def test_package_endpoint_throttling(self):
36+
for i in range(0, 20):
3637
response = self.csrf_client.get("/api/packages")
3738
self.assertEqual(response.status_code, 200)
3839
response = self.staff_csrf_client.get("/api/packages")
@@ -46,122 +47,36 @@ def test_packages_endpoint_throttling(self):
4647
# 200 - staff user can access API unlimited times
4748
self.assertEqual(response.status_code, 200)
4849

49-
def test_cpes_endpoint_throttling(self):
50-
51-
# A basic user can only access /cpes endpoint 4 times a day
52-
for i in range(0, 4):
53-
response = self.csrf_client.get("/api/cpes")
54-
self.assertEqual(response.status_code, 200)
55-
response = self.staff_csrf_client.get("/api/cpes")
56-
self.assertEqual(response.status_code, 200)
57-
58-
response = self.csrf_client.get("/api/cpes")
59-
# 429 - too many requests for basic user
60-
self.assertEqual(response.status_code, 429)
61-
62-
response = self.staff_csrf_client.get("/api/cpes", format="json")
63-
# 200 - staff user can access API unlimited times
64-
self.assertEqual(response.status_code, 200)
65-
66-
def test_all_vulnerable_packages_endpoint_throttling(self):
67-
68-
# A basic user can only access /packages/all 1 time a day
69-
for i in range(0, 1):
70-
response = self.csrf_client.get("/api/packages/all")
71-
self.assertEqual(response.status_code, 200)
72-
response = self.staff_csrf_client.get("/api/packages/all")
73-
self.assertEqual(response.status_code, 200)
74-
75-
response = self.csrf_client.get("/api/packages/all")
76-
# 429 - too many requests for basic user
77-
self.assertEqual(response.status_code, 429)
78-
79-
response = self.staff_csrf_client.get("/api/packages/all", format="json")
80-
# 200 - staff user can access API unlimited times
81-
self.assertEqual(response.status_code, 200)
82-
83-
def test_vulnerabilities_endpoint_throttling(self):
84-
85-
# A basic user can only access /vulnerabilities 8 times a day
86-
for i in range(0, 8):
87-
response = self.csrf_client.get("/api/vulnerabilities")
88-
self.assertEqual(response.status_code, 200)
89-
response = self.staff_csrf_client.get("/api/vulnerabilities")
50+
# A anonymous user can only access /packages endpoint 10 times a day
51+
for i in range(0, 10):
52+
print(i)
53+
response = self.csrf_client_anon.get("/api/packages")
9054
self.assertEqual(response.status_code, 200)
9155

92-
response = self.csrf_client.get("/api/vulnerabilities")
93-
# 429 - too many requests for basic user
56+
response = self.csrf_client_anon.get("/api/packages")
57+
# 429 - too many requests for anon user
9458
self.assertEqual(response.status_code, 429)
59+
self.assertEqual(
60+
response.data.get("message"),
61+
"Your request has been throttled. Please contact [email protected]",
62+
)
9563

96-
response = self.staff_csrf_client.get("/api/vulnerabilities", format="json")
97-
# 200 - staff user can access API unlimited times
98-
self.assertEqual(response.status_code, 200)
99-
100-
def test_aliases_endpoint_throttling(self):
101-
102-
# A basic user can only access /alias 2 times a day
103-
for i in range(0, 2):
104-
response = self.csrf_client.get("/api/aliases")
105-
self.assertEqual(response.status_code, 200)
106-
response = self.staff_csrf_client.get("/api/aliases")
107-
self.assertEqual(response.status_code, 200)
108-
109-
response = self.csrf_client.get("/api/aliases")
110-
# 429 - too many requests for basic user
64+
response = self.csrf_client_anon.get("/api/vulnerabilities")
65+
# 429 - too many requests for anon user
11166
self.assertEqual(response.status_code, 429)
67+
self.assertEqual(
68+
response.data.get("message"),
69+
"Your request has been throttled. Please contact [email protected]",
70+
)
11271

113-
response = self.staff_csrf_client.get("/api/aliases", format="json")
114-
# 200 - staff user can access API unlimited times
115-
self.assertEqual(response.status_code, 200)
116-
117-
def test_bulk_search_packages_endpoint_throttling(self):
11872
data = json.dumps({"purls": ["pkg:foo/bar"]})
11973

120-
# A basic user can only access /packages/bulk_search 6 times a day
121-
for i in range(0, 6):
122-
response = self.csrf_client.post(
123-
"/api/packages/bulk_search", data=data, content_type="application/json"
124-
)
125-
self.assertEqual(response.status_code, 200)
126-
response = self.staff_csrf_client.post(
127-
"/api/packages/bulk_search", data=data, content_type="application/json"
128-
)
129-
self.assertEqual(response.status_code, 200)
130-
131-
response = self.csrf_client.post(
74+
response = self.csrf_client_anon.post(
13275
"/api/packages/bulk_search", data=data, content_type="application/json"
13376
)
134-
# 429 - too many requests for basic user
77+
# 429 - too many requests for anon user
13578
self.assertEqual(response.status_code, 429)
136-
137-
response = self.staff_csrf_client.post(
138-
"/api/packages/bulk_search", data=data, content_type="application/json"
79+
self.assertEqual(
80+
response.data.get("message"),
81+
"Your request has been throttled. Please contact [email protected]",
13982
)
140-
# 200 - staff user can access API unlimited times
141-
self.assertEqual(response.status_code, 200)
142-
143-
def test_bulk_search_cpes_endpoint_throttling(self):
144-
data = json.dumps({"cpes": ["cpe:foo/bar"]})
145-
146-
# A basic user can only access /cpes/bulk_search 5 times a day
147-
for i in range(0, 5):
148-
response = self.csrf_client.post(
149-
"/api/cpes/bulk_search", data=data, content_type="application/json"
150-
)
151-
self.assertEqual(response.status_code, 200)
152-
response = self.staff_csrf_client.post(
153-
"/api/cpes/bulk_search", data=data, content_type="application/json"
154-
)
155-
self.assertEqual(response.status_code, 200)
156-
157-
response = self.csrf_client.post(
158-
"/api/cpes/bulk_search", data=data, content_type="application/json"
159-
)
160-
# 429 - too many requests for basic user
161-
self.assertEqual(response.status_code, 429)
162-
163-
response = self.staff_csrf_client.post(
164-
"/api/cpes/bulk_search", data=data, content_type="application/json"
165-
)
166-
# 200 - staff user can access API unlimited times
167-
self.assertEqual(response.status_code, 200)

vulnerabilities/throttling.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
# See https://github.com/nexB/vulnerablecode for support or download.
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
9-
from rest_framework.throttling import ScopedRateThrottle
9+
from rest_framework.exceptions import Throttled
10+
from rest_framework.throttling import UserRateThrottle
11+
from rest_framework.views import exception_handler
1012

1113

12-
class StaffUserRateThrottle(ScopedRateThrottle):
14+
class StaffUserRateThrottle(UserRateThrottle):
1315
def allow_request(self, request, view):
1416
"""
1517
Do not apply throttling for superusers and admins.
@@ -18,3 +20,19 @@ def allow_request(self, request, view):
1820
return True
1921

2022
return super().allow_request(request, view)
23+
24+
25+
def throttled_exception_handler(exception, context):
26+
"""
27+
Return this response whenever a request has been throttled
28+
"""
29+
30+
response = exception_handler(exception, context)
31+
32+
if isinstance(exception, Throttled):
33+
response_data = {
34+
"message": "Your request has been throttled. Please contact [email protected]"
35+
}
36+
response.data = response_data
37+
38+
return response

vulnerablecode/settings.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -171,35 +171,11 @@
171171
LOGIN_REDIRECT_URL = "/"
172172
LOGOUT_REDIRECT_URL = "/"
173173

174-
REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {
175-
"vulnerable_packages": env.str(
176-
"VULNERABLECODE_ALL_VULNERABLE_PACKAGES_THROTTLING_RATE", default="1/hour"
177-
),
178-
"bulk_search_packages": env.str(
179-
"VULNERABLECODE_BULK_SEARCH_PACKAGE_THROTTLING_RATE", default="5/hour"
180-
),
181-
"packages": env.str("VULNERABLECODE_PACKAGES_SEARCH_THROTTLING_RATE", default="10/minute"),
182-
"vulnerabilities": env.str(
183-
"VULNERABLECODE_VULNERABILITIES_SEARCH_THROTTLING_RATE", default="10/minute"
184-
),
185-
"aliases": env.str("VULNERABLECODE_ALIASES_SEARCH_THROTTLING_RATE", default="5/minute"),
186-
"cpes": env.str("VULNERABLECODE_CPE_SEARCH_THROTTLING_RATE", default="5/minute"),
187-
"bulk_search_cpes": env.str(
188-
"VULNERABLECODE_BULK_SEARCH_CPE_THROTTLING_RATE", default="5/minute"
189-
),
190-
}
174+
REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "3600/hour", "user": "10800/hour"}
191175

192176
if IS_TESTS:
193-
VULNERABLECODEIO_REQUIRE_AUTHENTICATION = True
194-
REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {
195-
"vulnerable_packages": "1/day",
196-
"bulk_search_packages": "6/day",
197-
"packages": "10/day",
198-
"vulnerabilities": "8/day",
199-
"aliases": "2/day",
200-
"cpes": "4/day",
201-
"bulk_search_cpes": "5/day",
202-
}
177+
VULNERABLECODEIO_REQUIRE_AUTHENTICATION = False
178+
REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "10/day", "user": "20/day"}
203179

204180

205181
USE_L10N = True
@@ -237,8 +213,11 @@
237213
),
238214
"DEFAULT_THROTTLE_CLASSES": [
239215
"vulnerabilities.throttling.StaffUserRateThrottle",
216+
"rest_framework.throttling.AnonRateThrottle",
217+
"rest_framework.throttling.UserRateThrottle",
240218
],
241219
"DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES,
220+
"EXCEPTION_HANDLER": "vulnerabilities.throttling.throttled_exception_handler",
242221
"DEFAULT_PAGINATION_CLASS": "vulnerabilities.pagination.SmallResultSetPagination",
243222
# Limit the load on the Database returning a small number of records by default. https://github.com/nexB/vulnerablecode/issues/819
244223
"PAGE_SIZE": 10,

0 commit comments

Comments
 (0)