26
26
from typing import Optional
27
27
from pathlib import Path
28
28
29
+ import json
30
+
29
31
30
32
class Vulnerability :
31
- def __init__ (self , id : str , url : str ):
33
+ def __init__ (self , id : str , url : str , dependency : str , version : str ):
32
34
self .id = id
33
35
self .url = url
36
+ self .dependency = dependency
37
+ self .version = version
38
+
39
+
40
+ class VulnerabilityEncoder (json .JSONEncoder ):
41
+ def default (self , obj ):
42
+ if isinstance (obj , Vulnerability ):
43
+ return {"id" : obj .id , "url" : obj .url , "dependency" : obj .dependency , "version" : obj .version }
44
+ # Let the base class default method raise the TypeError
45
+ return json .JSONEncoder .default (self , obj )
34
46
35
47
36
48
vulnerability_found_message = """For each dependency and vulnerability, check the following:
@@ -64,7 +76,7 @@ def __init__(self, id: str, url: str):
64
76
65
77
def query_ghad (
66
78
dependencies : dict [str , Dependency ], gh_token : str , repo_path : Path
67
- ) -> dict [ str , list [Vulnerability ] ]:
79
+ ) -> list [Vulnerability ]:
68
80
"""Queries the GitHub Advisory Database for vulnerabilities reported for Node's dependencies.
69
81
70
82
The database supports querying by package name in the NPM ecosystem, so we only send queries for the dependencies
@@ -86,7 +98,7 @@ def query_ghad(
86
98
parse_results = True ,
87
99
)
88
100
89
- found_vulnerabilities : dict [ str , list [Vulnerability ]] = defaultdict ( list )
101
+ found_vulnerabilities : list [Vulnerability ] = list ( )
90
102
for name , dep in deps_in_npm .items ():
91
103
variables_package = {
92
104
"package_name" : dep .npm_name ,
@@ -103,10 +115,10 @@ def query_ghad(
103
115
and v ["advisory" ]["ghsaId" ] not in ignore_list
104
116
]
105
117
if matching_vulns :
106
- found_vulnerabilities [ name ] .extend (
118
+ found_vulnerabilities .extend (
107
119
[
108
120
Vulnerability (
109
- id = vuln ["advisory" ]["ghsaId" ], url = vuln ["advisory" ]["permalink" ]
121
+ id = vuln ["advisory" ]["ghsaId" ], url = vuln ["advisory" ]["permalink" ], dependency = name , version = dep_version
110
122
)
111
123
for vuln in matching_vulns
112
124
]
@@ -117,7 +129,7 @@ def query_ghad(
117
129
118
130
def query_nvd (
119
131
dependencies : dict [str , Dependency ], api_key : Optional [str ], repo_path : Path
120
- ) -> dict [ str , list [Vulnerability ] ]:
132
+ ) -> list [Vulnerability ]:
121
133
"""Queries the National Vulnerability Database for vulnerabilities reported for Node's dependencies.
122
134
123
135
The database supports querying by CPE (Common Platform Enumeration) or by a keyword present in the CVE's
@@ -129,7 +141,7 @@ def query_nvd(
129
141
for name , dep in dependencies .items ()
130
142
if dep .cpe is not None or dep .keyword is not None
131
143
}
132
- found_vulnerabilities : dict [ str , list [Vulnerability ]] = defaultdict ( list )
144
+ found_vulnerabilities : list [Vulnerability ] = list ( )
133
145
for name , dep in deps_in_nvd .items ():
134
146
query_results = [
135
147
cve
@@ -139,8 +151,9 @@ def query_nvd(
139
151
if cve .id not in ignore_list
140
152
]
141
153
if query_results :
142
- found_vulnerabilities [name ].extend (
143
- [Vulnerability (id = cve .id , url = cve .url ) for cve in query_results ]
154
+ version = dep .version_parser (repo_path )
155
+ found_vulnerabilities .extend (
156
+ [Vulnerability (id = cve .id , url = cve .url , dependency = name , version = version ) for cve in query_results ]
144
157
)
145
158
146
159
return found_vulnerabilities
@@ -170,10 +183,16 @@ def main() -> int:
170
183
"--nvd-key" ,
171
184
help = "the NVD API key for querying the National Vulnerability Database" ,
172
185
)
186
+ parser .add_argument (
187
+ "--json-output" ,
188
+ action = 'store_true' ,
189
+ help = "the NVD API key for querying the National Vulnerability Database" ,
190
+ )
173
191
repo_path : Path = parser .parse_args ().node_repo_path
174
192
repo_branch : str = parser .parse_args ().node_repo_branch
175
193
gh_token = parser .parse_args ().gh_token
176
194
nvd_key = parser .parse_args ().nvd_key
195
+ json_output : bool = parser .parse_args ().json_output
177
196
if not repo_path .exists () or not (repo_path / ".git" ).exists ():
178
197
raise RuntimeError (
179
198
"Invalid argument: '{repo_path}' is not a valid Node git repository"
@@ -196,25 +215,27 @@ def main() -> int:
196
215
for name , dep in dependencies_info .items ()
197
216
if name in dependencies_per_branch [repo_branch ]
198
217
}
199
- ghad_vulnerabilities : dict [ str , list [Vulnerability ] ] = (
218
+ ghad_vulnerabilities : list [Vulnerability ] = (
200
219
{} if gh_token is None else query_ghad (dependencies , gh_token , repo_path )
201
220
)
202
- nvd_vulnerabilities : dict [ str , list [Vulnerability ] ] = query_nvd (
221
+ nvd_vulnerabilities : list [Vulnerability ] = query_nvd (
203
222
dependencies , nvd_key , repo_path
204
223
)
205
224
206
- if not ghad_vulnerabilities and not nvd_vulnerabilities :
225
+ all_vulnerabilities = {"vulnerabilities" : ghad_vulnerabilities + nvd_vulnerabilities }
226
+ no_vulnerabilities_found = not ghad_vulnerabilities and not nvd_vulnerabilities
227
+ if json_output :
228
+ print (json .dumps (all_vulnerabilities , cls = VulnerabilityEncoder ))
229
+ return 0 if no_vulnerabilities_found else 1
230
+ elif no_vulnerabilities_found :
207
231
print (f"No new vulnerabilities found ({ len (ignore_list )} ignored)" )
208
232
return 0
209
233
else :
210
234
print ("WARNING: New vulnerabilities found" )
211
- for source in (ghad_vulnerabilities , nvd_vulnerabilities ):
212
- for name , vulns in source .items ():
213
- print (
214
- f"- { name } (version { dependencies [name ].version_parser (repo_path )} ) :"
215
- )
216
- for v in vulns :
217
- print (f"\t - { v .id } : { v .url } " )
235
+ for vuln in all_vulnerabilities ["vulnerabilities" ]:
236
+ print (
237
+ f"- { vuln .dependency } (version { vuln .version } ) : { vuln .id } ({ vuln .url } )"
238
+ )
218
239
print (f"\n { vulnerability_found_message } " )
219
240
return 1
220
241
0 commit comments