diff --git a/go.mod b/go.mod index b74e3afc..8865617b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/google/osv-scalibr v0.3.5-0.20251002191929-de9496dc5aa2 github.com/tidwall/jsonc v0.3.2 golang.org/x/mod v0.30.0 + golang.org/x/sync v0.16.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 55ee1743..de00b132 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= diff --git a/pkg/database/api-check.go b/pkg/database/api-check.go index 9307f368..3ff2cd92 100644 --- a/pkg/database/api-check.go +++ b/pkg/database/api-check.go @@ -12,6 +12,7 @@ import ( "path" "github.com/g-rath/osv-detector/internal" + "golang.org/x/sync/errgroup" ) func (db APIDB) buildAPIPayload(pkg internal.PackageDetails) apiQuery { @@ -171,15 +172,38 @@ func findOrDefault(vulns Vulnerabilities, def OSV) OSV { func (db APIDB) Check(pkgs []internal.PackageDetails) ([]Vulnerabilities, error) { batches := batchPkgs(pkgs, db.BatchSize) - vulnerabilities := make([]Vulnerabilities, 0, len(pkgs)) + var eg errgroup.Group - for _, batch := range batches { - results, err := db.checkBatch(batch) + // use a sensible upper limit so it's not possible to have inf. 
operations going + // even though it's very unlikely there will be more than a couple of batches + eg.SetLimit(100) - if err != nil { - return nil, err - } + batchResults := make([][][]ObjectWithID, len(batches)) + + for i, batch := range batches { + eg.Go(func() error { + results, err := db.checkBatch(batch) + + if err != nil { + return err + } + + batchResults[i] = results + + return nil + }) + } + + err := eg.Wait() + + if err != nil { + return nil, err + } + + vulnerabilities := make([]Vulnerabilities, 0, len(pkgs)) + // todo: pretty sure some of these loops and slices can be merged and simplified + for _, results := range batchResults { for _, withIDs := range results { vulns := make(Vulnerabilities, 0, len(withIDs)) diff --git a/pkg/database/api-check_test.go b/pkg/database/api-check_test.go index f874f20b..f861481a 100644 --- a/pkg/database/api-check_test.go +++ b/pkg/database/api-check_test.go @@ -553,11 +553,15 @@ func TestAPIDB_Check_Batches(t *testing.T) { mux.HandleFunc("/querybatch", func(w http.ResponseWriter, r *http.Request) { requestCount++ + if requestCount > 2 { + t.Errorf("unexpected number of requests (%d)", requestCount) + } + var expectedPayload []apiQuery var batchResponse []objectsWithIDs - switch requestCount { - case 1: + // strictly speaking not the best of checks, but it should be good enough + if r.ContentLength > 100 { expectedPayload = []apiQuery{ { Version: "1.0.0", @@ -569,7 +573,7 @@ func TestAPIDB_Check_Batches(t *testing.T) { }, } batchResponse = []objectsWithIDs{{}, {}} - case 2: + } else if r.ContentLength > 50 { expectedPayload = []apiQuery{ { Version: "2.3.1", @@ -577,8 +581,6 @@ func TestAPIDB_Check_Batches(t *testing.T) { }, } batchResponse = []objectsWithIDs{{}} - default: - t.Errorf("unexpected number of requests (%d)", requestCount) } expectRequestPayload(t, r, expectedPayload) diff --git a/pkg/database/api-fetch-all.go b/pkg/database/api-fetch-all.go index ebb6cc9a..98d3ec9d 100644 --- a/pkg/database/api-fetch-all.go 
+++ b/pkg/database/api-fetch-all.go @@ -2,65 +2,32 @@ package database import ( "sort" -) -// a struct to hold the result from each request including an index -// which will be used for sorting the results after they come in -type result struct { - index int - res OSV - err error -} + "golang.org/x/sync/errgroup" +) func (db APIDB) FetchAll(ids []string) Vulnerabilities { - conLimit := 200 + var eg errgroup.Group - var osvs Vulnerabilities - - if len(ids) == 0 { - return osvs - } + eg.SetLimit(200) - // buffered channel which controls the number of concurrent operations - semaphoreChan := make(chan struct{}, conLimit) - resultsChan := make(chan *result) - - defer func() { - close(semaphoreChan) - close(resultsChan) - }() + osvs := make(Vulnerabilities, len(ids)) for i, id := range ids { - go func(i int, id string) { - // read from the buffered semaphore channel, which will block if we're - // already got as many goroutines as our concurrency limit allows - // - // when one of those routines finish they'll read from this channel, - // freeing up a slot to unblock this send - semaphoreChan <- struct{}{} - + eg.Go(func() error { // if we error, still report the vulnerability as hopefully the ID should be // enough to manually look up the details - in future we should ideally warn // the user too, but for now we just silently eat the error osv, _ := db.Fetch(id) - result := &result{i, osv, nil} - resultsChan <- result + osvs[i] = osv - // read from the buffered semaphore to free up a slot to allow - // another goroutine to start, since this one is wrapping up - <-semaphoreChan - }(i, id) + return nil + }) } - for { - result := <-resultsChan - osvs = append(osvs, result.res) - - if len(osvs) == len(ids) { - break - } - } + // errors are handled within the goroutines + _ = eg.Wait() sort.Slice(osvs, func(i, j int) bool { return osvs[i].ID < osvs[j].ID