Skip to content

Commit eb6e935

Browse files
authored
Merge pull request #993 from TG1999/override_throttle
Override throttle rate for each endpoint
2 parents 6f72ecf + 2f1cfc5 commit eb6e935

File tree

5 files changed

+163
-16
lines changed

5 files changed

+163
-16
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ Version v30.2.2
99
- We enabled API throttling for a basic user and for a staff user
1010
they can have unlimited access on API.
1111

12+
- We added throttle rate for each API endpoint and it can be
13+
configured from the settings #991 https://github.com/nexB/vulnerablecode/issues/991.
14+
1215

1316
Version v30.2.1
1417
----------------

vulnerabilities/api.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from vulnerabilities.models import VulnerabilityReference
2424
from vulnerabilities.models import VulnerabilitySeverity
2525
from vulnerabilities.models import get_purl_query_lookups
26+
from vulnerabilities.throttling import StaffUserRateThrottle
2627

2728

2829
class VulnerabilitySeveritySerializer(serializers.ModelSerializer):
@@ -220,9 +221,11 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
220221
serializer_class = PackageSerializer
221222
filter_backends = (filters.DjangoFilterBackend,)
222223
filterset_class = PackageFilterSet
224+
throttle_classes = [StaffUserRateThrottle]
225+
throttle_scope = "packages"
223226

224227
# TODO: Fix the swagger documentation for this endpoint
225-
@action(detail=False, methods=["post"])
228+
@action(detail=False, methods=["post"], throttle_scope="bulk_search_packages")
226229
def bulk_search(self, request):
227230
"""
228231
See https://github.com/nexB/vulnerablecode/pull/369#issuecomment-796877606 for docs
@@ -246,15 +249,15 @@ def bulk_search(self, request):
246249
if purl_data:
247250
purl_response = PackageSerializer(purl_data[0], context={"request": request}).data
248251
else:
249-
purl_response = purl
252+
purl_response = purl.to_dict()
250253
purl_response["unresolved_vulnerabilities"] = []
251254
purl_response["resolved_vulnerabilities"] = []
252255
purl_response["purl"] = purl_string
253256
response.append(purl_response)
254257

255258
return Response(response)
256259

257-
@action(detail=False, methods=["get"])
260+
@action(detail=False, methods=["get"], throttle_scope="vulnerable_packages")
258261
def all(self, request):
259262
"""
260263
Return all the vulnerable Package URLs.
@@ -302,6 +305,8 @@ def get_queryset(self):
302305
serializer_class = VulnerabilitySerializer
303306
filter_backends = (filters.DjangoFilterBackend,)
304307
filterset_class = VulnerabilityFilterSet
308+
throttle_classes = [StaffUserRateThrottle]
309+
throttle_scope = "vulnerabilities"
305310

306311

307312
class CPEFilterSet(filters.FilterSet):
@@ -318,9 +323,11 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
318323
).distinct()
319324
serializer_class = VulnerabilitySerializer
320325
filter_backends = (filters.DjangoFilterBackend,)
326+
throttle_classes = [StaffUserRateThrottle]
321327
filterset_class = CPEFilterSet
328+
throttle_scope = "cpes"
322329

323-
@action(detail=False, methods=["post"])
330+
@action(detail=False, methods=["post"], throttle_scope="bulk_search_cpes")
324331
def bulk_search(self, request):
325332
"""
326333
This endpoint is used to search for vulnerabilities by more than one CPE.
@@ -357,3 +364,5 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
357364
serializer_class = VulnerabilitySerializer
358365
filter_backends = (filters.DjangoFilterBackend,)
359366
filterset_class = AliasFilterSet
367+
throttle_classes = [StaffUserRateThrottle]
368+
throttle_scope = "aliases"

vulnerabilities/tests/test_throttling.py

Lines changed: 125 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
99

10+
import json
11+
1012
from django.contrib.auth import get_user_model
1113
from rest_framework.test import APIClient
1214
from rest_framework.test import APITestCase
@@ -30,10 +32,10 @@ def setUp(self):
3032
self.staff_csrf_client = APIClient(enforce_csrf_checks=True)
3133
self.staff_csrf_client.credentials(HTTP_AUTHORIZATION=self.staff_auth)
3234

33-
def test_api_throttling(self):
35+
def test_packages_endpoint_throttling(self):
3436

35-
# A basic user can only access API 5 times a day
36-
for i in range(0, 5):
37+
# A basic user can only access /packages endpoint 10 times a day
38+
for i in range(0, 10):
3739
response = self.csrf_client.get("/api/packages")
3840
self.assertEqual(response.status_code, 200)
3941
response = self.staff_csrf_client.get("/api/packages")
@@ -46,3 +48,123 @@ def test_api_throttling(self):
4648
response = self.staff_csrf_client.get("/api/packages", format="json")
4749
# 200 - staff user can access API unlimited times
4850
self.assertEqual(response.status_code, 200)
51+
52+
def test_cpes_endpoint_throttling(self):
53+
54+
# A basic user can only access /cpes endpoint 4 times a day
55+
for i in range(0, 4):
56+
response = self.csrf_client.get("/api/cpes")
57+
self.assertEqual(response.status_code, 200)
58+
response = self.staff_csrf_client.get("/api/cpes")
59+
self.assertEqual(response.status_code, 200)
60+
61+
response = self.csrf_client.get("/api/cpes")
62+
# 429 - too many requests for basic user
63+
self.assertEqual(response.status_code, 429)
64+
65+
response = self.staff_csrf_client.get("/api/cpes", format="json")
66+
# 200 - staff user can access API unlimited times
67+
self.assertEqual(response.status_code, 200)
68+
69+
def test_all_vulnerable_packages_endpoint_throttling(self):
70+
71+
# A basic user can only access /packages/all 1 time a day
72+
for i in range(0, 1):
73+
response = self.csrf_client.get("/api/packages/all")
74+
self.assertEqual(response.status_code, 200)
75+
response = self.staff_csrf_client.get("/api/packages/all")
76+
self.assertEqual(response.status_code, 200)
77+
78+
response = self.csrf_client.get("/api/packages/all")
79+
# 429 - too many requests for basic user
80+
self.assertEqual(response.status_code, 429)
81+
82+
response = self.staff_csrf_client.get("/api/packages/all", format="json")
83+
# 200 - staff user can access API unlimited times
84+
self.assertEqual(response.status_code, 200)
85+
86+
def test_vulnerabilities_endpoint_throttling(self):
87+
88+
# A basic user can only access /vulnerabilities 8 times a day
89+
for i in range(0, 8):
90+
response = self.csrf_client.get("/api/vulnerabilities")
91+
self.assertEqual(response.status_code, 200)
92+
response = self.staff_csrf_client.get("/api/vulnerabilities")
93+
self.assertEqual(response.status_code, 200)
94+
95+
response = self.csrf_client.get("/api/vulnerabilities")
96+
# 429 - too many requests for basic user
97+
self.assertEqual(response.status_code, 429)
98+
99+
response = self.staff_csrf_client.get("/api/vulnerabilities", format="json")
100+
# 200 - staff user can access API unlimited times
101+
self.assertEqual(response.status_code, 200)
102+
103+
def test_aliases_endpoint_throttling(self):
104+
105+
# A basic user can only access /alias 2 times a day
106+
for i in range(0, 2):
107+
response = self.csrf_client.get("/api/alias")
108+
self.assertEqual(response.status_code, 200)
109+
response = self.staff_csrf_client.get("/api/alias")
110+
self.assertEqual(response.status_code, 200)
111+
112+
response = self.csrf_client.get("/api/alias")
113+
# 429 - too many requests for basic user
114+
self.assertEqual(response.status_code, 429)
115+
116+
response = self.staff_csrf_client.get("/api/alias", format="json")
117+
# 200 - staff user can access API unlimited times
118+
self.assertEqual(response.status_code, 200)
119+
120+
def test_bulk_search_packages_endpoint_throttling(self):
121+
data = json.dumps({"purls": ["pkg:foo/bar"]})
122+
123+
# A basic user can only access /packages/bulk_search 6 times a day
124+
for i in range(0, 6):
125+
response = self.csrf_client.post(
126+
"/api/packages/bulk_search", data=data, content_type="application/json"
127+
)
128+
self.assertEqual(response.status_code, 200)
129+
response = self.staff_csrf_client.post(
130+
"/api/packages/bulk_search", data=data, content_type="application/json"
131+
)
132+
self.assertEqual(response.status_code, 200)
133+
134+
response = self.csrf_client.post(
135+
"/api/packages/bulk_search", data=data, content_type="application/json"
136+
)
137+
# 429 - too many requests for basic user
138+
self.assertEqual(response.status_code, 429)
139+
140+
response = self.staff_csrf_client.post(
141+
"/api/packages/bulk_search", data=data, content_type="application/json"
142+
)
143+
# 200 - staff user can access API unlimited times
144+
self.assertEqual(response.status_code, 200)
145+
146+
def test_bulk_search_cpes_endpoint_throttling(self):
147+
data = json.dumps({"cpes": ["cpe:foo/bar"]})
148+
149+
# A basic user can only access /cpes/bulk_search 5 times a day
150+
for i in range(0, 5):
151+
response = self.csrf_client.post(
152+
"/api/cpes/bulk_search", data=data, content_type="application/json"
153+
)
154+
self.assertEqual(response.status_code, 200)
155+
response = self.staff_csrf_client.post(
156+
"/api/cpes/bulk_search", data=data, content_type="application/json"
157+
)
158+
self.assertEqual(response.status_code, 200)
159+
160+
response = self.csrf_client.post(
161+
"/api/cpes/bulk_search", data=data, content_type="application/json"
162+
)
163+
# 429 - too many requests for basic user
164+
self.assertEqual(response.status_code, 429)
165+
166+
response = self.staff_csrf_client.post(
167+
"/api/cpes/bulk_search", data=data, content_type="application/json"
168+
)
169+
# 200 - staff user can access API unlimited times
170+
self.assertEqual(response.status_code, 200)

vulnerabilities/throttling.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,10 @@
66
# See https://github.com/nexB/vulnerablecode for support or download.
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
9+
from rest_framework.throttling import ScopedRateThrottle
910

10-
from django.contrib.auth import get_user_model
11-
from rest_framework.throttling import UserRateThrottle
1211

13-
User = get_user_model()
14-
15-
16-
class StaffUserRateThrottle(UserRateThrottle):
12+
class StaffUserRateThrottle(ScopedRateThrottle):
1713
def allow_request(self, request, view):
1814
"""
1915
Do not apply throttling for superusers and admins.

vulnerablecode/settings.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,28 @@
150150

151151
LOGIN_REDIRECT_URL = "/"
152152
LOGOUT_REDIRECT_URL = "/"
153-
THROTTLING_RATE = env.str("THROTTLING_RATE", default="1000/day")
153+
154+
REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {
155+
"vulnerable_packages": "1/hour",
156+
"bulk_search_packages": "5/hour",
157+
"packages": "10/minute",
158+
"vulnerabilities": "10/minute",
159+
"aliases": "5/minute",
160+
"cpes": "5/minute",
161+
"bulk_search_cpes": "5/hour",
162+
}
154163

155164
if IS_TESTS:
156165
VULNERABLECODEIO_REQUIRE_AUTHENTICATION = True
157-
THROTTLING_RATE = "5/day"
166+
REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {
167+
"vulnerable_packages": "1/day",
168+
"bulk_search_packages": "6/day",
169+
"packages": "10/day",
170+
"vulnerabilities": "8/day",
171+
"aliases": "2/day",
172+
"cpes": "4/day",
173+
"bulk_search_cpes": "5/day",
174+
}
158175

159176

160177
USE_L10N = True
@@ -190,7 +207,7 @@
190207
"DEFAULT_THROTTLE_CLASSES": [
191208
"vulnerabilities.throttling.StaffUserRateThrottle",
192209
],
193-
"DEFAULT_THROTTLE_RATES": {"user": THROTTLING_RATE},
210+
"DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES,
194211
"DEFAULT_PAGINATION_CLASS": "vulnerabilities.pagination.SmallResultSetPagination",
195212
# Limit the load on the Database returning a small number of records by default. https://github.com/nexB/vulnerablecode/issues/819
196213
"PAGE_SIZE": 10,

0 commit comments

Comments
 (0)