19
19
from rich .progress import Progress , track
20
20
21
21
from cve_bin_tool .async_utils import RateLimiter
22
- from cve_bin_tool .error_handler import ErrorMode , NVDServiceError
22
+ from cve_bin_tool .error_handler import ErrorMode , NVDKeyError , NVDServiceError
23
23
from cve_bin_tool .log import LOGGER
24
24
25
25
FEED = "https://services.nvd.nist.gov/rest/json/cves/1.0"
@@ -102,6 +102,9 @@ async def get_nvd_params(
102
102
self .logger .debug ("Fetching metadata from NVD..." )
103
103
cve_count = await self .nvd_count_metadata (self .session )
104
104
105
+ if "apiKey" in self .params :
106
+ await self .validate_nvd_api ()
107
+
105
108
if time_of_last_update :
106
109
# Fetch all the updated CVE entries from the modified date. Subtracting 2-minute offset for updating cve entries
107
110
self .params ["modStartDate" ] = self .convert_date_to_nvd_date (
@@ -125,6 +128,28 @@ async def get_nvd_params(
125
128
self .total_results = cve_count ["Total" ] - cve_count ["Rejected" ]
126
129
self .logger .info (f"Adding { self .total_results } CVE entries" )
127
130
131
+ async def validate_nvd_api (self ):
132
+ """
133
+ Validate NVD API
134
+ """
135
+ param_dict = self .params .copy ()
136
+ param_dict ["startIndex" ] = 0
137
+ param_dict ["resultsPerPage" ] = 1
138
+ try :
139
+ self .logger .debug ("Validating NVD API..." )
140
+ async with await self .session .get (
141
+ self .feed , params = param_dict , raise_for_status = True
142
+ ) as response :
143
+ data = await response .json ()
144
+ if data .get ("error" , False ):
145
+ self .logger .error (f"NVD API error: { data ['error' ]} " )
146
+ raise NVDKeyError (self .params ["apiKey" ])
147
+ except NVDKeyError :
148
+ # If the API key provided is invalid, delete from params
149
+ # list and try the request again.
150
+ self .logger .error ("unset api key, retrying" )
151
+ del self .params ["apiKey" ]
152
+
128
153
async def load_nvd_request (self , start_index ):
129
154
"""Get single NVD request and update year_wise_data list which contains list of all CVEs"""
130
155
@@ -141,6 +166,7 @@ async def load_nvd_request(self, start_index):
141
166
) as response :
142
167
if response .status == 200 :
143
168
fetched_data = await response .json ()
169
+
144
170
if start_index == 0 :
145
171
# Update total results in case there is discrepancy between NVD dashboard and API
146
172
self .total_results = fetched_data ["totalResults" ]
0 commit comments