Skip to content

Commit 6a21168

Browse files
committed
perf: skip advisories that are not needed
1 parent 6d30712 commit 6a21168

File tree

9 files changed

+171
-33
lines changed

9 files changed

+171
-33
lines changed

main.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ func loadDatabases(
188188
listPackages bool,
189189
offline bool,
190190
batchSize int,
191+
pkgNames []string,
191192
) (OSVDatabases, bool) {
192193
dbs := make(OSVDatabases, 0, len(dbConfigs))
193194

@@ -204,7 +205,7 @@ func loadDatabases(
204205
for _, dbConfig := range dbConfigs {
205206
r.PrintTextf(" %s", dbConfig.Name)
206207

207-
db, err := database.Load(dbConfig, offline, batchSize)
208+
db, err := database.Load(dbConfig, offline, batchSize, pkgNames)
208209

209210
if err != nil {
210211
r.PrintDatabaseLoadErr(err)
@@ -591,12 +592,20 @@ This flag can be passed multiple times to ignore different vulnerabilities`)
591592

592593
files.adjustExtraDatabases(*noConfigDatabases, *useAPI, *useDatabases)
593594

595+
var allPackages []string
596+
for _, p := range files {
597+
for _, pkg := range p.lockf.Packages {
598+
allPackages = append(allPackages, pkg.Name)
599+
}
600+
}
601+
594602
dbs, errored := loadDatabases(
595603
r,
596604
uniqueDBConfigs(files.getConfigs()),
597605
*listPackages,
598606
*offline,
599607
*batchSize,
608+
allPackages,
600609
)
601610

602611
if errored {

pkg/database/config.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ func (dbc Config) Identifier() string {
2727
var ErrUnsupportedDatabaseType = errors.New("unsupported database source type")
2828

2929
// Load initializes a new OSV database based on the given Config
30-
func Load(config Config, offline bool, batchSize int) (DB, error) {
30+
func Load(config Config, offline bool, batchSize int, pkgNames []string) (DB, error) {
3131
switch config.Type {
3232
case "zip":
33-
return NewZippedDB(config, offline)
33+
return NewZippedDB(config, offline, pkgNames)
3434
case "api":
3535
return NewAPIDB(config, offline, batchSize)
3636
case "dir":
37-
return NewDirDB(config, offline)
37+
return NewDirDB(config, offline, pkgNames)
3838
}
3939

4040
return nil, fmt.Errorf("%w %s", ErrUnsupportedDatabaseType, config.Type)

pkg/database/dir.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ var ErrDirPathWrongProtocol = errors.New("directory path must start with \"file:
2828

2929
// load walks the filesystem starting with the working directory within the local path,
3030
// loading all OSVs found along the way.
31-
func (db *DirDB) load() error {
31+
func (db *DirDB) load(pkgNames []string) error {
3232
db.vulnerabilities = make(map[string][]OSV)
3333

3434
if !strings.HasPrefix(db.LocalPath, "file:") {
@@ -78,7 +78,7 @@ func (db *DirDB) load() error {
7878
return nil
7979
}
8080

81-
db.addVulnerability(pa)
81+
db.addVulnerability(pa, pkgNames)
8282

8383
return nil
8484
})
@@ -94,15 +94,15 @@ func (db *DirDB) load() error {
9494
return nil
9595
}
9696

97-
func NewDirDB(config Config, offline bool) (*DirDB, error) {
97+
func NewDirDB(config Config, offline bool, pkgNames []string) (*DirDB, error) {
9898
db := &DirDB{
9999
name: config.Name,
100100
identifier: config.Identifier(),
101101
LocalPath: config.URL,
102102
WorkingDirectory: config.WorkingDirectory,
103103
Offline: offline,
104104
}
105-
if err := db.load(); err != nil {
105+
if err := db.load(pkgNames); err != nil {
106106
return nil, fmt.Errorf("unable to load OSV database: %w", err)
107107
}
108108

pkg/database/dir_test.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ func TestNewDirDB(t *testing.T) {
1313
osvs := []database.OSV{
1414
withDefaultAffected("OSV-1"),
1515
withDefaultAffected("OSV-2"),
16+
{
17+
ID: "OSV-3",
18+
Affected: []database.Affected{
19+
{Package: database.Package{Ecosystem: "PyPi", Name: "mine2"}, Versions: database.Versions{}},
20+
},
21+
},
1622
{
1723
ID: "GHSA-1234",
1824
Affected: []database.Affected{
@@ -22,7 +28,7 @@ func TestNewDirDB(t *testing.T) {
2228
},
2329
}
2430

25-
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db"}, false)
31+
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db"}, false, nil)
2632

2733
if err != nil {
2834
t.Fatalf("unexpected error \"%v\"", err)
@@ -34,7 +40,7 @@ func TestNewDirDB(t *testing.T) {
3440
func TestNewDirDB_InvalidURI(t *testing.T) {
3541
t.Parallel()
3642

37-
db, err := database.NewDirDB(database.Config{URL: "file://\\"}, false)
43+
db, err := database.NewDirDB(database.Config{URL: "file://\\"}, false, nil)
3844

3945
if err == nil {
4046
t.Fatalf("NewDirDB() did not return expected error")
@@ -48,7 +54,7 @@ func TestNewDirDB_InvalidURI(t *testing.T) {
4854
func TestNewDirDB_NotFileProtocol(t *testing.T) {
4955
t.Parallel()
5056

51-
db, err := database.NewDirDB(database.Config{URL: "https://mysite.com/my.zip"}, false)
57+
db, err := database.NewDirDB(database.Config{URL: "https://mysite.com/my.zip"}, false, nil)
5258

5359
if err == nil {
5460
t.Fatalf("NewDirDB() did not return expected error")
@@ -66,7 +72,7 @@ func TestNewDirDB_NotFileProtocol(t *testing.T) {
6672
func TestNewDirDB_DoesNotExist(t *testing.T) {
6773
t.Parallel()
6874

69-
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/nowhere"}, false)
75+
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/nowhere"}, false, nil)
7076

7177
if err == nil {
7278
t.Fatalf("NewDirDB() did not return expected error")
@@ -82,11 +88,33 @@ func TestNewDirDB_WorkingDirectory(t *testing.T) {
8288

8389
osvs := []database.OSV{withDefaultAffected("OSV-1")}
8490

85-
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db", WorkingDirectory: "nested-1"}, false)
91+
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db", WorkingDirectory: "nested-1"}, false, nil)
8692

8793
if err != nil {
8894
t.Fatalf("unexpected error \"%v\"", err)
8995
}
9096

9197
expectDBToHaveOSVs(t, db, osvs)
9298
}
99+
100+
func TestNewDirDB_WithSpecificPackages(t *testing.T) {
101+
t.Parallel()
102+
103+
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db"}, false, []string{"mine", "request"})
104+
105+
if err != nil {
106+
t.Fatalf("unexpected error \"%v\"", err)
107+
}
108+
109+
expectDBToHaveOSVs(t, db, []database.OSV{
110+
withDefaultAffected("OSV-1"),
111+
withDefaultAffected("OSV-2"),
112+
{
113+
ID: "GHSA-1234",
114+
Affected: []database.Affected{
115+
{Package: database.Package{Ecosystem: "npm", Name: "request"}},
116+
{Package: database.Package{Ecosystem: "npm", Name: "@cypress/request"}},
117+
},
118+
},
119+
})
120+
}

pkg/database/load_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func TestLoad(t *testing.T) {
1717
}
1818

1919
for _, typ := range types {
20-
_, err := database.Load(database.Config{Type: typ}, false, 100)
20+
_, err := database.Load(database.Config{Type: typ}, false, 100, nil)
2121

2222
if err == nil {
2323
t.Fatalf("NewDirDB() did not return expected error")
@@ -28,7 +28,7 @@ func TestLoad(t *testing.T) {
2828
func TestLoad_BadType(t *testing.T) {
2929
t.Parallel()
3030

31-
db, err := database.Load(database.Config{Type: "file"}, false, 100)
31+
db, err := database.Load(database.Config{Type: "file"}, false, 100, nil)
3232

3333
if err == nil {
3434
t.Fatalf("NewDirDB() did not return expected error")

pkg/database/mem-check.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,15 @@ type memDB struct {
1111
VulnerabilitiesCount int
1212
}
1313

14-
func (db *memDB) addVulnerability(osv OSV) {
14+
func (db *memDB) addVulnerability(osv OSV, pkgNames []string) {
1515
db.VulnerabilitiesCount++
1616

17+
// if we have been provided a list of package names, only load advisories
18+
// that might actually affect those packages, rather than all advisories
19+
if len(pkgNames) != 0 && !mightAffectPackages(osv, pkgNames) {
20+
return
21+
}
22+
1723
for _, affected := range osv.Affected {
1824
hash := string(affected.Package.Ecosystem) + "-" + affected.Package.NormalizedName()
1925
vulns := db.vulnerabilities[hash]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"id": "OSV-3",
3+
"affected": [
4+
{
5+
"package": {
6+
"name": "mine2",
7+
"ecosystem": "PyPi"
8+
},
9+
"versions": []
10+
}
11+
]
12+
}

pkg/database/zip.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,21 @@ func (db *ZipDB) fetchZip() ([]byte, error) {
126126
return body, nil
127127
}
128128

129+
func mightAffectPackages(v OSV, names []string) bool {
130+
for _, affected := range v.Affected {
131+
for _, name := range names {
132+
if affected.Package.Name == name {
133+
return true
134+
}
135+
}
136+
}
137+
138+
return false
139+
}
140+
129141
// Loads the given zip file into the database as an OSV.
130142
// It is assumed that the file is JSON and in the working directory of the db
131-
func (db *ZipDB) loadZipFile(zipFile *zip.File) {
143+
func (db *ZipDB) loadZipFile(zipFile *zip.File, pkgNames []string) {
132144
file, err := zipFile.Open()
133145
if err != nil {
134146
_, _ = fmt.Fprintf(os.Stderr, "Could not read %s: %v\n", zipFile.Name, err)
@@ -152,7 +164,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
152164
return
153165
}
154166

155-
db.addVulnerability(osv)
167+
db.addVulnerability(osv, pkgNames)
156168
}
157169

158170
// load fetches a zip archive of the OSV database and loads known vulnerabilities
@@ -161,7 +173,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
161173
// Internally, the archive is cached along with the date that it was fetched
162174
// so that a new version of the archive is only downloaded if it has been
163175
// modified, per HTTP caching standards.
164-
func (db *ZipDB) load() error {
176+
func (db *ZipDB) load(pkgNames []string) error {
165177
db.vulnerabilities = make(map[string][]OSV)
166178

167179
body, err := db.fetchZip()
@@ -185,21 +197,21 @@ func (db *ZipDB) load() error {
185197
continue
186198
}
187199

188-
db.loadZipFile(zipFile)
200+
db.loadZipFile(zipFile, pkgNames)
189201
}
190202

191203
return nil
192204
}
193205

194-
func NewZippedDB(config Config, offline bool) (*ZipDB, error) {
206+
func NewZippedDB(config Config, offline bool, pkgNames []string) (*ZipDB, error) {
195207
db := &ZipDB{
196208
name: config.Name,
197209
identifier: config.Identifier(),
198210
ArchiveURL: config.URL,
199211
WorkingDirectory: config.WorkingDirectory,
200212
Offline: offline,
201213
}
202-
if err := db.load(); err != nil {
214+
if err := db.load(pkgNames); err != nil {
203215
return nil, fmt.Errorf("unable to fetch OSV database: %w", err)
204216
}
205217

0 commit comments

Comments
 (0)