Skip to content

Commit c923364

Browse files
authored
Add pipeline to sort packages (#1686)
* Add pipeline to sort packages
  Signed-off-by: Tushar Goel <[email protected]>
* Add tests
  Signed-off-by: Tushar Goel <[email protected]>
* Add calculate_version_rank on Package
  Signed-off-by: Tushar Goel <[email protected]>
* Start enumerating from 1
  Signed-off-by: Tushar Goel <[email protected]>
* Fix tests
  Signed-off-by: Tushar Goel <[email protected]>
* Return version rank anyhow
  Signed-off-by: Tushar Goel <[email protected]>
* Fix API tests
  Signed-off-by: Tushar Goel <[email protected]>
* Address review comments
  Signed-off-by: Tushar Goel <[email protected]>
---------
Signed-off-by: Tushar Goel <[email protected]>
1 parent cec5d9e commit c923364

File tree

7 files changed

+241
-21
lines changed

7 files changed

+241
-21
lines changed

vulnerabilities/improvers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from vulnerabilities.improvers import vulnerability_status
1212
from vulnerabilities.pipelines import VulnerableCodePipeline
1313
from vulnerabilities.pipelines import compute_package_risk
14+
from vulnerabilities.pipelines import compute_package_version_rank
1415
from vulnerabilities.pipelines import enhance_with_exploitdb
1516
from vulnerabilities.pipelines import enhance_with_kev
1617
from vulnerabilities.pipelines import enhance_with_metasploit
@@ -39,6 +40,7 @@
3940
enhance_with_metasploit.MetasploitImproverPipeline,
4041
enhance_with_exploitdb.ExploitDBImproverPipeline,
4142
compute_package_risk.ComputePackageRiskPipeline,
43+
compute_package_version_rank.ComputeVersionRankPipeline,
4244
]
4345

4446
IMPROVERS_REGISTRY = {
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Generated by Django 4.2.16 on 2024-12-04 11:50
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
    """
    Add the integer `version_rank` field to `Package` and include it in the
    model's default ordering, ahead of the raw `version` string.
    """

    dependencies = [("vulnerabilities", "0083_alter_packagechangelog_software_version_and_more")]

    operations = [
        # New default ordering: rank comes before the version string.
        migrations.AlterModelOptions(
            name="package",
            options=dict(
                ordering=[
                    "type",
                    "namespace",
                    "name",
                    "version_rank",
                    "version",
                    "qualifiers",
                    "subpath",
                ],
            ),
        ),
        # Rank column; zero is the "not ranked yet" default.
        migrations.AddField(
            model_name="package",
            name="version_rank",
            field=models.IntegerField(
                help_text="Rank of the version to support ordering by version. Rank zero means the rank has not been defined yet",
                default=0,
            ),
        ),
    ]

vulnerabilities/models.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,12 @@ class Package(PackageURLMixin):
705705
"indicate greater vulnerability risk for the package.",
706706
)
707707

708+
version_rank = models.IntegerField(
709+
help_text="Rank of the version to support ordering by version. Rank "
710+
"zero means the rank has not been defined yet",
711+
default=0,
712+
)
713+
708714
objects = PackageQuerySet.as_manager()
709715

710716
def save(self, *args, **kwargs):
@@ -738,11 +744,34 @@ def purl(self):
738744

739745
class Meta:
740746
unique_together = ["type", "namespace", "name", "version", "qualifiers", "subpath"]
741-
ordering = ["type", "namespace", "name", "version", "qualifiers", "subpath"]
747+
ordering = ["type", "namespace", "name", "version_rank", "version", "qualifiers", "subpath"]
742748

743749
def __str__(self):
744750
return self.package_url
745751

752+
@property
def calculate_version_rank(self):
    """
    Return the `version_rank` of this package, computing ranks first if needed.

    When any package sharing this package's type, namespace and name still
    has the default rank of zero, every package of that group is sorted with
    this package's `version_class`, ranked from 1 upward and saved with a
    single bulk update. Otherwise nothing is written and the current rank is
    returned as-is.

    The ranked objects are freshly loaded copies of the rows, so the new rank
    is also copied onto this instance; without that, `self.version_rank`
    would stay at zero until the package is reloaded from the database.
    """
    group_packages = Package.objects.filter(
        type=self.type,
        namespace=self.namespace,
        name=self.name,
    )

    if any(p.version_rank == 0 for p in group_packages):
        sorted_packages = sorted(group_packages, key=lambda p: self.version_class(p.version))
        for rank, package in enumerate(sorted_packages, start=1):
            package.version_rank = rank
            if package.pk == self.pk:
                # Keep the in-memory instance in sync with the row just ranked.
                self.version_rank = rank
        Package.objects.bulk_update(sorted_packages, fields=["version_rank"])
    return self.version_rank
774+
746775
@property
747776
def affected_by(self):
748777
"""
@@ -789,14 +818,6 @@ def get_details_url(self, request):
789818

790819
return reverse("package_details", kwargs={"purl": self.purl}, request=request)
791820

792-
def sort_by_version(self, packages):
793-
"""
794-
Return a sequence of `packages` sorted by version.
795-
"""
796-
if not packages:
797-
return []
798-
return sorted(packages, key=lambda x: self.version_class(x.version))
799-
800821
@cached_property
801822
def version_class(self):
802823
range_class = RANGE_CLASS_BY_SCHEMES.get(self.type)
@@ -831,19 +852,20 @@ def get_non_vulnerable_versions(self):
831852
Return a tuple of the next and latest non-vulnerable versions as Package instance.
832853
Return a tuple of (None, None) if there is no non-vulnerable version.
833854
"""
855+
if self.version_rank == 0:
856+
self.calculate_version_rank
834857
non_vulnerable_versions = Package.objects.get_fixed_by_package_versions(
835858
self, fix=False
836859
).only_non_vulnerable()
837-
sorted_versions = self.sort_by_version(non_vulnerable_versions)
838860

839-
later_non_vulnerable_versions = [
840-
non_vuln_ver
841-
for non_vuln_ver in sorted_versions
842-
if self.version_class(non_vuln_ver.version) > self.current_version
843-
]
861+
later_non_vulnerable_versions = non_vulnerable_versions.filter(
862+
version_rank__gt=self.version_rank
863+
)
864+
865+
later_non_vulnerable_versions = list(later_non_vulnerable_versions)
844866

845867
if later_non_vulnerable_versions:
846-
sorted_versions = self.sort_by_version(later_non_vulnerable_versions)
868+
sorted_versions = later_non_vulnerable_versions
847869
next_non_vulnerable = sorted_versions[0]
848870
latest_non_vulnerable = sorted_versions[-1]
849871
return next_non_vulnerable, latest_non_vulnerable
@@ -872,6 +894,8 @@ def get_affecting_vulnerabilities(self):
872894
Return a list of vulnerabilities that affect this package together with information regarding
873895
the versions that fix the vulnerabilities.
874896
"""
897+
if self.version_rank == 0:
898+
self.calculate_version_rank
875899
package_details_vulns = []
876900

877901
fixed_by_packages = Package.objects.get_fixed_by_package_versions(self, fix=True)
@@ -895,12 +919,13 @@ def get_affecting_vulnerabilities(self):
895919
if fixed_version > self.current_version:
896920
later_fixed_packages.append(fixed_pkg)
897921

898-
next_fixed_package = None
899922
next_fixed_package_vulns = []
900923

901924
sort_fixed_by_packages_by_version = []
902925
if later_fixed_packages:
903-
sort_fixed_by_packages_by_version = self.sort_by_version(later_fixed_packages)
926+
sort_fixed_by_packages_by_version = sorted(
927+
later_fixed_packages, key=lambda p: p.version_rank
928+
)
904929

905930
fixed_by_pkgs = []
906931

@@ -930,6 +955,7 @@ def fixing_vulnerabilities(self):
930955
"""
931956
Return a queryset of Vulnerabilities that are fixed by this package.
932957
"""
958+
print("A")
933959
return self.fixed_by_vulnerabilities.all()
934960

935961
@property
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#
2+
# Copyright (c) nexB Inc. and others. All rights reserved.
3+
# VulnerableCode is a trademark of nexB Inc.
4+
# SPDX-License-Identifier: Apache-2.0
5+
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
6+
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
7+
# See https://aboutcode.org for more information about nexB OSS projects.
8+
#
9+
10+
from itertools import groupby
11+
12+
from aboutcode.pipeline import LoopProgress
13+
from django.db import transaction
14+
from univers.version_range import RANGE_CLASS_BY_SCHEMES
15+
from univers.versions import Version
16+
17+
from vulnerabilities.models import Package
18+
from vulnerabilities.pipelines import VulnerableCodePipeline
19+
20+
21+
class ComputeVersionRankPipeline(VulnerableCodePipeline):
22+
"""
23+
A pipeline to compute and assign version ranks for all packages.
24+
"""
25+
26+
pipeline_id = "compute_version_rank"
27+
license_expression = None
28+
29+
@classmethod
30+
def steps(cls):
31+
return (cls.compute_and_store_version_rank,)
32+
33+
def compute_and_store_version_rank(self):
34+
"""
35+
Compute and assign version ranks to all packages.
36+
"""
37+
groups = Package.objects.only("type", "namespace", "name").order_by(
38+
"type", "namespace", "name"
39+
)
40+
41+
def key(package):
42+
return package.type, package.namespace, package.name
43+
44+
groups = groupby(groups, key=key)
45+
46+
groups = [(list(x), list(y)) for x, y in groups]
47+
48+
total_groups = len(groups)
49+
self.log(f"Calculating `version_rank` for {total_groups:,d} groups of packages.")
50+
51+
progress = LoopProgress(
52+
total_iterations=total_groups,
53+
logger=self.log,
54+
progress_step=5,
55+
)
56+
57+
for group, packages in progress.iter(groups):
58+
type, namespace, name = group
59+
if type not in RANGE_CLASS_BY_SCHEMES:
60+
continue
61+
self.update_version_rank_for_group(packages)
62+
63+
self.log("Successfully populated `version_rank` for all packages.")
64+
65+
@transaction.atomic
66+
def update_version_rank_for_group(self, packages):
67+
"""
68+
Update the `version_rank` for all packages in a specific group.
69+
"""
70+
71+
# Sort the packages by version
72+
sorted_packages = self.sort_packages_by_version(packages)
73+
74+
# Assign version ranks
75+
updates = []
76+
for rank, package in enumerate(sorted_packages, start=1):
77+
package.version_rank = rank
78+
updates.append(package)
79+
80+
# Bulk update to save the ranks
81+
Package.objects.bulk_update(updates, fields=["version_rank"])
82+
83+
def sort_packages_by_version(self, packages):
84+
"""
85+
Sort packages by version using `version_class`.
86+
"""
87+
88+
if not packages:
89+
return []
90+
version_class = RANGE_CLASS_BY_SCHEMES.get(packages[0].type).version_class
91+
if not version_class:
92+
version_class = Version
93+
return sorted(packages, key=lambda p: version_class(p.version))

vulnerabilities/tests/test_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ def setUp(self):
489489
self.pkg_2_14_0_rc1 = from_purl(
490490
"pkg:maven/com.fasterxml.jackson.core/[email protected]"
491491
)
492+
self.pkg_2_12_6.calculate_version_rank
492493

493494
set_as_fixing(package=self.pkg_2_12_6, vulnerability=self.vul3)
494495

@@ -608,6 +609,7 @@ def setUp(self):
608609
self.pkg_2_14_0_rc1 = from_purl(
609610
"pkg:maven/com.fasterxml.jackson.core/[email protected]"
610611
)
612+
self.pkg_2_12_6.calculate_version_rank
611613

612614
self.ref = VulnerabilityReference.objects.create(
613615
reference_type="advisory", reference_id="CVE-xxx-xxx", url="https://example.com"
@@ -806,7 +808,7 @@ def test_api_with_ghost_package_no_fixing_vulnerabilities(self):
806808
"qualifiers": {},
807809
"subpath": "",
808810
"is_vulnerable": True,
809-
"next_non_vulnerable_version": "2.14.0-rc1",
811+
"next_non_vulnerable_version": "2.12.6",
810812
"latest_non_vulnerable_version": "2.14.0-rc1",
811813
"affected_by_vulnerabilities": [
812814
{
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from unittest.mock import patch
2+
3+
import pytest
4+
from univers.versions import Version
5+
6+
from vulnerabilities.models import Package
7+
from vulnerabilities.pipelines.compute_package_version_rank import ComputeVersionRankPipeline
8+
9+
10+
@pytest.mark.django_db
class TestComputeVersionRankPipeline:
    """Tests for `ComputeVersionRankPipeline`."""

    @pytest.fixture
    def pipeline(self):
        return ComputeVersionRankPipeline()

    @pytest.fixture
    def packages(self, db):
        """Create three versions of one pypi package and return them as a queryset."""
        package_type = "pypi"
        namespace = "test_namespace"
        name = "test_package"
        for version in ("1.0.0", "1.1.0", "0.9.0"):
            Package.objects.create(
                type=package_type, namespace=namespace, name=name, version=version
            )
        return Package.objects.filter(type=package_type, namespace=namespace, name=name)

    def test_compute_and_store_version_rank(self, pipeline, packages):
        with patch.object(pipeline, "log") as mock_log:
            pipeline.compute_and_store_version_rank()
            assert mock_log.call_count > 0
            for package in packages:
                # Zero is the "not ranked yet" default, so a ranked package
                # must hold a positive rank.
                assert package.version_rank > 0

    def test_update_version_rank_for_group(self, pipeline, packages):
        with patch.object(Package.objects, "bulk_update") as mock_bulk_update:
            pipeline.update_version_rank_for_group(packages)
            mock_bulk_update.assert_called_once()
            updated_packages = mock_bulk_update.call_args[0][0]
            assert len(updated_packages) == len(packages)
            expected_order = sorted(packages, key=lambda p: Version(p.version))
            # The pipeline assigns ranks starting from 1.
            for rank, package in enumerate(expected_order, start=1):
                assert updated_packages[rank - 1].version == package.version
                assert updated_packages[rank - 1].version_rank == rank

    def test_sort_packages_by_version(self, pipeline, packages):
        sorted_packages = pipeline.sort_packages_by_version(packages)
        versions = [p.version for p in sorted_packages]
        assert versions == sorted(versions, key=Version)

    def test_sort_packages_by_version_empty(self, pipeline):
        assert pipeline.sort_packages_by_version([]) == []

    def test_sort_packages_by_version_invalid_scheme(self, pipeline, packages):
        for package in packages:
            package.type = "invalid"
        assert pipeline.sort_packages_by_version(packages) == []

    def test_compute_and_store_version_rank_invalid_scheme(self, pipeline):
        Package.objects.create(type="invalid", namespace="test", name="package", version="1.0.0")
        with patch.object(pipeline, "log") as mock_log:
            pipeline.compute_and_store_version_rank()
            mock_log.assert_any_call("Successfully populated `version_rank` for all packages.")

vulnerabilities/tests/test_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,11 @@ def test_sort_by_version(self):
423423
version="3.0.0",
424424
)
425425

426-
sorted_pkgs = requesting_package.sort_by_version(vuln_pkg_list)
427-
first_sorted_item = sorted_pkgs[0]
426+
requesting_package.calculate_version_rank
427+
428+
sorted_pkgs = Package.objects.filter(package_url__in=list_to_sort)
429+
430+
sorted_pkgs = list(sorted_pkgs)
428431

429432
assert sorted_pkgs[0].purl == "pkg:npm/[email protected]"
430433
assert sorted_pkgs[-1].purl == "pkg:npm/[email protected]"

0 commit comments

Comments
 (0)