-
Notifications
You must be signed in to change notification settings - Fork 9
perf: do api queries in parallel using errgroup
#347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
89cdf7b
c4f2ee0
48a6084
e61cb2e
a2d2c56
3a0f396
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ import ( | |
| "path" | ||
|
|
||
| "github.com/g-rath/osv-detector/internal" | ||
| "golang.org/x/sync/errgroup" | ||
| ) | ||
|
|
||
| func (db APIDB) buildAPIPayload(pkg internal.PackageDetails) apiQuery { | ||
|
|
@@ -171,15 +172,38 @@ func findOrDefault(vulns Vulnerabilities, def OSV) OSV { | |
| func (db APIDB) Check(pkgs []internal.PackageDetails) ([]Vulnerabilities, error) { | ||
| batches := batchPkgs(pkgs, db.BatchSize) | ||
|
|
||
| vulnerabilities := make([]Vulnerabilities, 0, len(pkgs)) | ||
| var eg errgroup.Group | ||
|
|
||
| for _, batch := range batches { | ||
| results, err := db.checkBatch(batch) | ||
| // use a sensible upper limit so it's not possible to have inf. operations going | ||
| // even though it's very unlikely there will be more than a couple of batches | ||
| eg.SetLimit(100) | ||
|
|
||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| batchResults := make([][][]ObjectWithID, len(batches)) | ||
|
|
||
| for i, batch := range batches { | ||
| eg.Go(func() error { | ||
| results, err := db.checkBatch(batch) | ||
|
|
||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| batchResults[i] = results | ||
|
Comment on lines
+184
to
+191
|
||
|
|
||
| return nil | ||
| }) | ||
| } | ||
|
|
||
| err := eg.Wait() | ||
|
|
||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| vulnerabilities := make([]Vulnerabilities, 0, len(pkgs)) | ||
|
|
||
| // todo: pretty sure some of these loops and slices can be merged and simplified | ||
| for _, results := range batchResults { | ||
| for _, withIDs := range results { | ||
| vulns := make(Vulnerabilities, 0, len(withIDs)) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -553,11 +553,15 @@ func TestAPIDB_Check_Batches(t *testing.T) { | |
| mux.HandleFunc("/querybatch", func(w http.ResponseWriter, r *http.Request) { | ||
| requestCount++ | ||
|
||
|
|
||
| if requestCount > 2 { | ||
| t.Errorf("unexpected number of requests (%d)", requestCount) | ||
| } | ||
|
|
||
| var expectedPayload []apiQuery | ||
| var batchResponse []objectsWithIDs | ||
|
|
||
| switch requestCount { | ||
| case 1: | ||
| // strictly speaking not the best of checks, but it should be good enough | ||
| if r.ContentLength > 100 { | ||
| expectedPayload = []apiQuery{ | ||
| { | ||
| Version: "1.0.0", | ||
|
|
@@ -569,16 +573,14 @@ func TestAPIDB_Check_Batches(t *testing.T) { | |
| }, | ||
| } | ||
| batchResponse = []objectsWithIDs{{}, {}} | ||
| case 2: | ||
| } else if r.ContentLength > 50 { | ||
| expectedPayload = []apiQuery{ | ||
| { | ||
| Version: "2.3.1", | ||
| Package: apiPackage{Name: "their-package", Ecosystem: lockfile.NpmEcosystem}, | ||
| }, | ||
| } | ||
| batchResponse = []objectsWithIDs{{}} | ||
| default: | ||
| t.Errorf("unexpected number of requests (%d)", requestCount) | ||
| } | ||
|
|
||
| expectRequestPayload(t, r, expectedPayload) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,65 +2,32 @@ package database | |
|
|
||
| import ( | ||
| "sort" | ||
| ) | ||
|
|
||
| // a struct to hold the result from each request including an index | ||
| // which will be used for sorting the results after they come in | ||
| type result struct { | ||
| index int | ||
| res OSV | ||
| err error | ||
| } | ||
| "golang.org/x/sync/errgroup" | ||
| ) | ||
|
|
||
| func (db APIDB) FetchAll(ids []string) Vulnerabilities { | ||
| conLimit := 200 | ||
| var eg errgroup.Group | ||
|
|
||
| var osvs Vulnerabilities | ||
|
|
||
| if len(ids) == 0 { | ||
| return osvs | ||
| } | ||
| eg.SetLimit(200) | ||
|
|
||
| // buffered channel which controls the number of concurrent operations | ||
| semaphoreChan := make(chan struct{}, conLimit) | ||
| resultsChan := make(chan *result) | ||
|
|
||
| defer func() { | ||
| close(semaphoreChan) | ||
| close(resultsChan) | ||
| }() | ||
| osvs := make(Vulnerabilities, len(ids)) | ||
|
|
||
| for i, id := range ids { | ||
| go func(i int, id string) { | ||
| // read from the buffered semaphore channel, which will block if we're | ||
| // already got as many goroutines as our concurrency limit allows | ||
| // | ||
| // when one of those routines finish they'll read from this channel, | ||
| // freeing up a slot to unblock this send | ||
| semaphoreChan <- struct{}{} | ||
|
|
||
| eg.Go(func() error { | ||
| // if we error, still report the vulnerability as hopefully the ID should be | ||
| // enough to manually look up the details - in future we should ideally warn | ||
| // the user too, but for now we just silently eat the error | ||
| osv, _ := db.Fetch(id) | ||
| result := &result{i, osv, nil} | ||
|
|
||
| resultsChan <- result | ||
| osvs[i] = osv | ||
|
Comment on lines
+17
to
+23
|
||
|
|
||
| // read from the buffered semaphore to free up a slot to allow | ||
| // another goroutine to start, since this one is wrapping up | ||
| <-semaphoreChan | ||
| }(i, id) | ||
| return nil | ||
| }) | ||
| } | ||
|
|
||
| for { | ||
| result := <-resultsChan | ||
| osvs = append(osvs, result.res) | ||
|
|
||
| if len(osvs) == len(ids) { | ||
| break | ||
| } | ||
| } | ||
| // errors are handled within the go routines | ||
| _ = eg.Wait() | ||
|
|
||
| sort.Slice(osvs, func(i, j int) bool { | ||
| return osvs[i].ID < osvs[j].ID | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing concurrency limit for errgroup
Unlike `FetchAll()`, which sets a limit of 200 concurrent operations with `eg.SetLimit(200)` (line 12 in api-fetch-all.go), this errgroup has no limit. This could potentially spawn a large number of goroutines if there are many batches, which may overwhelm system resources or the API server.
Suggestion: Consider adding a concurrency limit consistent with the existing pattern: