Skip to content

Commit 9d02bed

Browse files
committed
Refine the vulnerability affectation
Signed-off-by: tdruez <[email protected]>
1 parent e863019 commit 9d02bed

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

product_portfolio/tests/test_models.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -601,21 +601,25 @@ def test_product_model_improve_packages_from_purldb(self, mock_update_from_purld
601601
self.assertEqual("apache-2.0", pp1.license_expression)
602602

603603
def test_product_model_affected_by_vulnerabilities(self):
604-
vulnerability1 = make_vulnerability(self.dataspace, risk_score=10.0)
605-
vulnerability2 = make_vulnerability(
606-
self.dataspace, affecting=[self.product1], risk_score=1.0
607-
)
604+
vulnerability1 = make_vulnerability(self.dataspace, risk_score=1.0)
605+
vulnerability2 = make_vulnerability(self.dataspace, risk_score=10.0)
606+
vulnerability3 = make_vulnerability(self.dataspace, risk_score=5.0)
608607

608+
vulnerability1.add_affected(self.product1)
609609
affected_by = self.product1.affected_by_vulnerabilities.all()
610-
self.assertQuerySetEqual([vulnerability2], affected_by)
610+
self.assertQuerySetEqual([vulnerability1], affected_by)
611611
self.product1.refresh_from_db()
612-
# self.assertEqual(1.0, self.product1.risk_score)
612+
self.assertEqual(1.0, self.product1.risk_score)
613613

614-
vulnerability1.add_affected(self.product1)
614+
vulnerability2.add_affected(self.product1)
615615
affected_by = self.product1.affected_by_vulnerabilities.order_by("id")
616616
self.assertQuerySetEqual([vulnerability1, vulnerability2], affected_by)
617617
self.product1.refresh_from_db()
618-
# self.assertEqual(10.0, self.product1.risk_score)
618+
self.assertEqual(10.0, self.product1.risk_score)
619+
620+
vulnerability3.add_affected(self.product1)
621+
self.product1.refresh_from_db()
622+
self.assertEqual(10.0, self.product1.risk_score)
619623

620624
def test_product_model_get_vulnerability_qs(self):
621625
package1 = make_package(self.dataspace)

vulnerabilities/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ def add_affected_products(self, products):
205205
"""Assign the ``products`` as affected by this vulnerability."""
206206
through_defaults = {"dataspace_id": self.dataspace_id}
207207
self.affected_products.add(*products, through_defaults=through_defaults)
208+
for product in products:
209+
product.update_risk_score()
208210

209211
@classmethod
210212
def create_from_data(cls, dataspace, data, validate=False, affecting=None):
@@ -437,6 +439,12 @@ def update_risk_score(self):
437439
self.save(update_fields=["risk_score"])
438440
return self.risk_score
439441

442+
def add_affected_by(self, vulnerability):
443+
"""Add ``vulnerability`` as affecting this instance."""
444+
through_defaults = {"dataspace_id": self.dataspace_id}
445+
self.affected_by_vulnerabilities.add(vulnerability, through_defaults=through_defaults)
446+
self.update_risk_score()
447+
440448
def get_entry_for_package(self, vulnerablecode):
441449
if not self.package_url:
442450
return

vulnerabilities/tests/test_models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,30 @@ def test_vulnerability_mixin_update_risk_score(self):
126126
package2.update_risk_score()
127127
self.assertIsNone(package2.risk_score)
128128

129+
def test_vulnerability_mixin_add_affected_by(self):
130+
package1 = make_package(self.dataspace)
131+
132+
vulnerability1 = make_vulnerability(self.dataspace, risk_score=1.0)
133+
vulnerability2 = make_vulnerability(self.dataspace, risk_score=10.0)
134+
vulnerability3 = make_vulnerability(self.dataspace, risk_score=5.0)
135+
136+
package1.add_affected_by(vulnerability1)
137+
package1.refresh_from_db()
138+
self.assertEqual("1.0", str(package1.risk_score))
139+
140+
package1.add_affected_by(vulnerability2)
141+
package1.refresh_from_db()
142+
self.assertEqual("10.0", str(package1.risk_score))
143+
144+
package1.add_affected_by(vulnerability3)
145+
package1.refresh_from_db()
146+
self.assertEqual("10.0", str(package1.risk_score))
147+
148+
self.assertEqual(package1, vulnerability1.affected_packages.get())
149+
self.assertEqual(package1, vulnerability2.affected_packages.get())
150+
self.assertEqual(package1, vulnerability3.affected_packages.get())
151+
self.assertEqual(3, package1.affected_by_vulnerabilities.count())
152+
129153
def test_vulnerability_model_affected_packages_m2m(self):
130154
package1 = make_package(self.dataspace)
131155
vulnerability1 = make_vulnerability(dataspace=self.dataspace, affecting=package1)

0 commit comments

Comments
 (0)