Skip to content

Commit 4920e1f

Browse files
committed
Refactor the risk score calculation for vulnerabilities and packages.
Update the tests for exploits and the simple_risk_pipeline. Signed-off-by: ziad hany <[email protected]>
1 parent f29ef16 commit 4920e1f

File tree

4 files changed

+51
-52
lines changed

4 files changed

+51
-52
lines changed

vulnerabilities/pipelines/compute_package_risk.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
9-
109
from aboutcode.pipeline import LoopProgress
1110
from django.db.models import Prefetch
1211

@@ -35,12 +34,14 @@ def steps(cls):
3534
)
3635

3736
def compute_and_store_vulnerability_risk_score(self):
38-
affected_vulnerabilities = Vulnerability.objects.filter(
39-
affectedbypackagerelatedvulnerability__isnull=False
40-
).prefetch_related(
41-
"references",
42-
"severities",
43-
"exploits",
37+
affected_vulnerabilities = (
38+
Vulnerability.objects.filter(affecting_packages__isnull=False)
39+
.prefetch_related(
40+
"references",
41+
"severities",
42+
"exploits",
43+
)
44+
.distinct()
4445
)
4546

4647
self.log(
@@ -53,35 +54,43 @@ def compute_and_store_vulnerability_risk_score(self):
5354
updated_vulnerability_count = 0
5455
batch_size = 5000
5556

56-
for vulnerability in progress.iter(affected_vulnerabilities.paginated()):
57+
for vulnerability in progress.iter(affected_vulnerabilities.paginated(per_page=batch_size)):
5758
severities = vulnerability.severities.all()
5859
references = vulnerability.references.all()
60+
exploits = vulnerability.exploits.all()
5961

60-
(
61-
vulnerability.weighted_severity,
62-
vulnerability.exploitability,
63-
) = compute_vulnerability_risk_factors(references, severities, vulnerability.exploits)
62+
weighted_severity, exploitability = compute_vulnerability_risk_factors(
63+
references=references,
64+
severities=severities,
65+
exploits=exploits,
66+
)
67+
vulnerability.weighted_severity = weighted_severity
68+
vulnerability.exploitability = exploitability
6469

6570
updatables.append(vulnerability)
6671

6772
if len(updatables) >= batch_size:
68-
updated_vulnerability_count += bulk_update_vulnerability_risk_score(
69-
vulnerabilities=updatables,
73+
updated_vulnerability_count += bulk_update(
74+
model=Vulnerability,
75+
items=updatables,
76+
fields=["weighted_severity", "exploitability"],
7077
logger=self.log,
7178
)
72-
updated_vulnerability_count += bulk_update_vulnerability_risk_score(
73-
vulnerabilities=updatables,
79+
80+
updated_vulnerability_count += bulk_update(
81+
model=Vulnerability,
82+
items=updatables,
83+
fields=["weighted_severity", "exploitability"],
7484
logger=self.log,
7585
)
86+
7687
self.log(
7788
f"Successfully added risk score for {updated_vulnerability_count:,d} vulnerability"
7889
)
7990

8091
def compute_and_store_package_risk_score(self):
8192
affected_packages = (
82-
Package.objects.filter(affected_by_vulnerabilities__isnull=False)
83-
.only("id")
84-
.prefetch_related(
93+
Package.objects.filter(affected_by_vulnerabilities__isnull=False).prefetch_related(
8594
Prefetch(
8695
"affectedbypackagerelatedvulnerability_set__vulnerability",
8796
queryset=Vulnerability.objects.only("weighted_severity", "exploitability"),
@@ -111,38 +120,28 @@ def compute_and_store_package_risk_score(self):
111120
updatables.append(package)
112121

113122
if len(updatables) >= batch_size:
114-
updated_package_count += bulk_update_package_risk_score(
115-
packages=updatables,
123+
updated_package_count += bulk_update(
124+
model=Package,
125+
items=updatables,
126+
fields=["risk_score"],
116127
logger=self.log,
117128
)
118-
updated_package_count += bulk_update_package_risk_score(
119-
packages=updatables,
129+
updated_package_count += bulk_update(
130+
model=Package,
131+
items=updatables,
132+
fields=["risk_score"],
120133
logger=self.log,
121134
)
122135
self.log(f"Successfully added risk score for {updated_package_count:,d} package")
123136

124137

125-
def bulk_update_package_risk_score(packages, logger):
126-
package_count = 0
127-
if packages:
128-
try:
129-
Package.objects.bulk_update(objs=packages, fields=["risk_score"])
130-
package_count += len(packages)
131-
except Exception as e:
132-
logger(f"Error updating packages: {e}")
133-
packages.clear()
134-
return package_count
135-
136-
137-
def bulk_update_vulnerability_risk_score(vulnerabilities, logger):
138-
vulnerabilities_count = 0
139-
if vulnerabilities:
138+
def bulk_update(model, items, fields, logger):
139+
item_count = 0
140+
if items:
140141
try:
141-
Vulnerability.objects.bulk_update(
142-
objs=vulnerabilities, fields=["weighted_severity", "exploitability"]
143-
)
144-
vulnerabilities_count += len(vulnerabilities)
142+
model.objects.bulk_update(objs=items, fields=fields)
143+
item_count += len(items)
145144
except Exception as e:
146-
logger(f"Error updating vulnerability: {e}")
147-
vulnerabilities.clear()
148-
return vulnerabilities_count
145+
logger(f"Error updating {model.__name__}: {e}")
146+
items.clear()
147+
return item_count

vulnerabilities/risk.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,8 @@ def compute_package_risk(package):
104104
and determining the associated risk.
105105
"""
106106
result = []
107-
vulnerabilities = package.vulnerabilities.all()
108-
for vulnerability in vulnerabilities:
109-
if risk := vulnerability.risk_score:
107+
for vulnerability in package.affectedbypackagerelatedvulnerability_set.all():
108+
if risk := vulnerability.vulnerability.risk_score:
110109
result.append(float(risk))
111110

112111
if not result:

vulnerabilities/tests/pipelines/test_compute_package_risk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ def test_simple_risk_pipeline(vulnerability):
3131
improver.execute()
3232

3333
pkg = Package.objects.get(type="pypi", name="foo", version="2.3.0")
34-
assert pkg.risk_score == Decimal("10")
34+
assert pkg.risk_score == Decimal("3.1") # max( 6.9 * 9/10 , 6.5 * 9/10 ) * .5 = 3.105

vulnerabilities/tests/test_risk.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,18 @@ def test_get_weighted_severity(vulnerability):
145145

146146

147147
@pytest.mark.django_db
148-
def test_compute_vulnerability_risk_factors(vulnerability):
148+
def test_compute_vulnerability_risk_factors(vulnerability, exploit):
149149
severities = vulnerability.severities.all()
150150
references = vulnerability.references.all()
151151

152-
assert compute_vulnerability_risk_factors(references, severities, vulnerability.exploits) == (
152+
assert compute_vulnerability_risk_factors(references, severities, exploit) == (
153153
6.2,
154154
2,
155155
)
156156

157157
assert compute_vulnerability_risk_factors(references, severities, None) == (6.2, 0.5)
158-
assert compute_vulnerability_risk_factors(references, None, vulnerability.exploits) == (0, 2)
158+
159+
assert compute_vulnerability_risk_factors(references, None, exploit) == (0, 2)
159160

160161
assert compute_vulnerability_risk_factors(None, None, None) == (0, 0.5)
161162

0 commit comments

Comments
 (0)