1919from rich .progress import Progress , track
2020
2121from cve_bin_tool .async_utils import RateLimiter
22- from cve_bin_tool .error_handler import ErrorMode , NVDServiceError
22+ from cve_bin_tool .error_handler import ErrorMode , NVDKeyError , NVDServiceError
2323from cve_bin_tool .log import LOGGER
2424
2525FEED = "https://services.nvd.nist.gov/rest/json/cves/1.0"
@@ -102,6 +102,9 @@ async def get_nvd_params(
102102 self .logger .debug ("Fetching metadata from NVD..." )
103103 cve_count = await self .nvd_count_metadata (self .session )
104104
105+ if "apiKey" in self .params :
106+ await self .validate_nvd_api ()
107+
105108 if time_of_last_update :
106109 # Fetch all the updated CVE entries from the modified date. Subtracting 2-minute offset for updating cve entries
107110 self .params ["modStartDate" ] = self .convert_date_to_nvd_date (
@@ -125,6 +128,28 @@ async def get_nvd_params(
125128 self .total_results = cve_count ["Total" ] - cve_count ["Rejected" ]
126129 self .logger .info (f"Adding { self .total_results } CVE entries" )
127130
131+ async def validate_nvd_api (self ):
132+ """
133+ Validate NVD API
134+ """
135+ param_dict = self .params .copy ()
136+ param_dict ["startIndex" ] = 0
137+ param_dict ["resultsPerPage" ] = 1
138+ try :
139+ self .logger .debug ("Validating NVD API..." )
140+ async with await self .session .get (
141+ self .feed , params = param_dict , raise_for_status = True
142+ ) as response :
143+ data = await response .json ()
144+ if data .get ("error" , False ):
145+ self .logger .error (f"NVD API error: { data ['error' ]} " )
146+ raise NVDKeyError (self .params ["apiKey" ])
147+ except NVDKeyError :
148+ # If the API key provided is invalid, delete from params
149+ # list and try the request again.
150+ self .logger .error ("unset api key, retrying" )
151+ del self .params ["apiKey" ]
152+
128153 async def load_nvd_request (self , start_index ):
129154 """Get single NVD request and update year_wise_data list which contains list of all CVEs"""
130155
@@ -141,6 +166,7 @@ async def load_nvd_request(self, start_index):
141166 ) as response :
142167 if response .status == 200 :
143168 fetched_data = await response .json ()
169+
144170 if start_index == 0 :
145171 # Update total results in case there is discrepancy between NVD dashboard and API
146172 self .total_results = fetched_data ["totalResults" ]
0 commit comments