Skip to content

Commit facf2a7

Browse files
Hritik14TG1999
authored andcommitted
Add tests for improve_runner
Signed-off-by: Hritik Vijay <[email protected]>
1 parent c58cf36 commit facf2a7

File tree

3 files changed

+183
-8
lines changed

3 files changed

+183
-8
lines changed

vulnerabilities/improve_runner.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def run(self) -> None:
5757
@transaction.atomic
5858
def process_inferences(inferences: List[Inference], advisory: Advisory, improver_name: str):
5959
"""
60+
Return number of inferences processed.
6061
An atomic transaction that updates both the Advisory (e.g. date_improved)
6162
and processes the given inferences to create or update corresponding
6263
database fields.
@@ -65,10 +66,11 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
6566
erroneous. Also, the atomic transaction for every advisory and its
6667
inferences makes sure that date_improved of advisory is consistent.
6768
"""
69+
inferences_processed_count = 0
6870

6971
if not inferences:
70-
logger.warn(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}")
71-
return
72+
logger.warning(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}")
73+
return inferences_processed_count
7274

7375
logger.info(f"Improving advisory id: {advisory.id}")
7476

@@ -80,7 +82,7 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
8082
)
8183

8284
if not vulnerability:
83-
logger.warn(f"Unable to get vulnerability for inference: {inference!r}")
85+
logger.warning(f"Unable to get vulnerability for inference: {inference!r}")
8486
continue
8587

8688
for ref in inference.references:
@@ -143,8 +145,12 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
143145
cwe_obj, created = Weakness.objects.get_or_create(cwe_id=cwe_id)
144146
cwe_obj.vulnerabilities.add(vulnerability)
145147
cwe_obj.save()
148+
149+
inferences_processed_count += 1
150+
146151
advisory.date_improved = datetime.now(timezone.utc)
147152
advisory.save()
153+
return inferences_processed_count
148154

149155

150156
def create_valid_vulnerability_reference(url, reference_id=None):
@@ -168,7 +174,7 @@ def create_valid_vulnerability_reference(url, reference_id=None):
168174
return reference
169175

170176

171-
def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summary):
177+
def get_or_create_vulnerability_and_aliases(alias_names, vulnerability_id=None, summary=None):
172178
"""
173179
Get or create vulnerabilitiy and aliases such that all existing and new
174180
aliases point to the same vulnerability
@@ -188,7 +194,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
188194
# TODO: It is possible that all those vulnerabilities are actually
189195
# the same at data level, figure out a way to merge them
190196
if len(existing_vulns) > 1:
191-
logger.warn(
197+
logger.warning(
192198
f"Given aliases {alias_names} already exist and do not point "
193199
f"to a single vulnerability. Cannot improve. Skipped."
194200
)
@@ -201,7 +207,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
201207
and vulnerability_id
202208
and existing_alias_vuln.vulnerability_id != vulnerability_id
203209
):
204-
logger.warn(
210+
logger.warning(
205211
f"Given aliases {alias_names!r} already exist and point to existing"
206212
f"vulnerability {existing_alias_vuln}. Unable to create Vulnerability "
207213
f"with vulnerability_id {vulnerability_id}. Skipped"
@@ -214,7 +220,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
214220
try:
215221
vulnerability = Vulnerability.objects.get(vulnerability_id=vulnerability_id)
216222
except Vulnerability.DoesNotExist:
217-
logger.warn(
223+
logger.warning(
218224
f"Given vulnerability_id: {vulnerability_id} does not exist in the database"
219225
)
220226
return
@@ -223,7 +229,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
223229
vulnerability.save()
224230

225231
if summary and summary != vulnerability.summary:
226-
logger.warn(
232+
logger.warning(
227233
f"Inconsistent summary for {vulnerability!r}. "
228234
f"Existing: {vulnerability.summary}, provided: {summary}"
229235
)

vulnerabilities/tests/test_improve_runner.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,27 @@
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
99

10+
from collections import Counter
11+
1012
import pytest
13+
from django.utils import timezone
14+
from packageurl import PackageURL
15+
from pytest_django.asserts import assertQuerysetEqual
1116

17+
from vulnerabilities.importer import Reference
1218
from vulnerabilities.improve_runner import create_valid_vulnerability_reference
19+
from vulnerabilities.improve_runner import get_or_create_vulnerability_and_aliases
20+
from vulnerabilities.improve_runner import process_inferences
21+
from vulnerabilities.improver import Improver
22+
from vulnerabilities.improver import Inference
23+
from vulnerabilities.models import Advisory
24+
from vulnerabilities.models import Alias
25+
from vulnerabilities.models import Package
26+
from vulnerabilities.models import PackageRelatedVulnerability
27+
from vulnerabilities.models import Vulnerability
28+
from vulnerabilities.models import VulnerabilityReference
29+
from vulnerabilities.models import VulnerabilityRelatedReference
30+
from vulnerabilities.models import VulnerabilitySeverity
1331

1432

1533
@pytest.mark.django_db
@@ -37,3 +55,152 @@ def test_create_valid_vulnerability_reference_accepts_long_references():
3755
url="https://foo.bar",
3856
)
3957
assert result
58+
59+
60+
@pytest.mark.django_db
61+
def test_get_or_create_vulnerability_and_aliases_with_new_vulnerability_and_new_aliases():
62+
alias_names = ["TAYLOR-1337", "SWIFT-1337"]
63+
summary = "Melodious vulnerability"
64+
vulnerability = get_or_create_vulnerability_and_aliases(
65+
alias_names=alias_names, summary=summary
66+
)
67+
assert vulnerability
68+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
69+
assert Counter(alias_names_in_db) == Counter(alias_names)
70+
71+
72+
@pytest.mark.django_db
73+
def test_get_or_create_vulnerability_and_aliases_with_different_vulnerability_and_existing_aliases():
74+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
75+
existing_vulnerability.save()
76+
existing_aliases = []
77+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
78+
for alias in existing_alias_names:
79+
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
80+
Alias.objects.bulk_create(existing_aliases)
81+
82+
different_vulnerability = Vulnerability(vulnerability_id="VCID-New")
83+
different_vulnerability.save()
84+
assert not get_or_create_vulnerability_and_aliases(
85+
alias_names=existing_alias_names, vulnerability_id=different_vulnerability.vulnerability_id
86+
)
87+
88+
89+
@pytest.mark.django_db
90+
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_new_aliases():
91+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
92+
existing_vulnerability.save()
93+
94+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
95+
vulnerability = get_or_create_vulnerability_and_aliases(
96+
vulnerability_id="VCID-Existing", alias_names=existing_alias_names
97+
)
98+
assert existing_vulnerability == vulnerability
99+
100+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
101+
assert Counter(alias_names_in_db) == Counter(existing_alias_names)
102+
103+
104+
@pytest.mark.django_db
105+
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_aliases():
106+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
107+
existing_vulnerability.save()
108+
109+
existing_aliases = []
110+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
111+
for alias in existing_alias_names:
112+
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
113+
Alias.objects.bulk_create(existing_aliases)
114+
115+
vulnerability = get_or_create_vulnerability_and_aliases(
116+
vulnerability_id="VCID-Existing", alias_names=existing_alias_names
117+
)
118+
assert existing_vulnerability == vulnerability
119+
120+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
121+
assert Counter(alias_names_in_db) == Counter(existing_alias_names)
122+
123+
124+
@pytest.mark.django_db
125+
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_and_new_aliases():
126+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
127+
existing_vulnerability.save()
128+
129+
existing_aliases = []
130+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
131+
for alias in existing_alias_names:
132+
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
133+
Alias.objects.bulk_create(existing_aliases)
134+
135+
new_alias_names = ["ALIAS-3", "ALIAS-4"]
136+
alias_names = existing_alias_names + new_alias_names
137+
vulnerability = get_or_create_vulnerability_and_aliases(
138+
vulnerability_id="VCID-Existing", alias_names=alias_names
139+
)
140+
assert existing_vulnerability == vulnerability
141+
142+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
143+
assert Counter(alias_names_in_db) == Counter(alias_names)
144+
145+
146+
DUMMY_ADVISORY = Advisory(summary="dummy", created_by="tests", date_collected=timezone.now())
147+
148+
149+
@pytest.mark.django_db
150+
def test_process_inferences_with_no_inference():
151+
assert not process_inferences(
152+
inferences=[], advisory=DUMMY_ADVISORY, improver_name="test_improver"
153+
)
154+
155+
156+
@pytest.mark.django_db
157+
def test_process_inferences_with_unknown_but_specified_vulnerability():
158+
inference = Inference(vulnerability_id="VCID-Does-Not-Exist-In-DB", aliases=["MATRIX-Neo"])
159+
assert not process_inferences(
160+
inferences=[inference], advisory=DUMMY_ADVISORY, improver_name="test_improver"
161+
)
162+
163+
164+
INFERENCES = [
165+
Inference(
166+
aliases=["CVE-1", "CVE-2"],
167+
summary="One upon a time, in a package far far away",
168+
affected_purls=[
169+
PackageURL(type="character", namespace="star-wars", name="anakin", version="1")
170+
],
171+
fixed_purl=PackageURL(
172+
type="character", namespace="star-wars", name="darth-vader", version="1"
173+
),
174+
references=[Reference(reference_id="imperial-vessel-1", url="https://m47r1x.github.io")],
175+
)
176+
]
177+
178+
179+
def get_objects_in_all_tables_used_by_process_inferences():
180+
return {
181+
"vulnerabilities": list(Vulnerability.objects.all()),
182+
"aliases": list(Alias.objects.all()),
183+
"references": list(VulnerabilityReference.objects.all()),
184+
"advisories": list(Advisory.objects.all()),
185+
"packages": list(Package.objects.all()),
186+
"references": list(VulnerabilityReference.objects.all()),
187+
"severity": list(VulnerabilitySeverity.objects.all()),
188+
}
189+
190+
191+
@pytest.mark.django_db
192+
def test_process_inferences_idempotency():
193+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
194+
all_objects = get_objects_in_all_tables_used_by_process_inferences()
195+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
196+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
197+
assert all_objects == get_objects_in_all_tables_used_by_process_inferences()
198+
199+
200+
@pytest.mark.django_db
201+
def test_process_inference_idempotency_with_different_improver_names():
202+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_one")
203+
all_objects = get_objects_in_all_tables_used_by_process_inferences()
204+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_two")
205+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_three")
206+
assert all_objects == get_objects_in_all_tables_used_by_process_inferences()

vulnerabilities/tests/test_improver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_inference_to_dict_method_with_vulnerability_id():
3131
"affected_purls": [],
3232
"fixed_purl": None,
3333
"references": [],
34+
"weaknesses": [],
3435
}
3536
assert expected == inference.to_dict()
3637

@@ -46,6 +47,7 @@ def test_inference_to_dict_method_with_purls():
4647
"affected_purls": [purl.to_dict()],
4748
"fixed_purl": purl.to_dict(),
4849
"references": [],
50+
"weaknesses": [],
4951
}
5052
assert expected == inference.to_dict()
5153

0 commit comments

Comments
 (0)