Skip to content

Commit fb3d067

Browse files
committed
internal/ghsa: refactor client
Instead of initializing a new client in every function call, move the functions in internal/ghsa to methods on a client. This will make it easier to add unit tests in a follow up CL. Change-Id: Ifdd7ee04e884822a94d489d4f6fde3035441f152 Reviewed-on: https://go-review.googlesource.com/c/vulndb/+/458202 TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Tatiana Bradley <[email protected]> Run-TryBot: Julie Qiu <[email protected]> Reviewed-by: Julie Qiu <[email protected]>
1 parent 9700ce4 commit fb3d067

File tree

6 files changed

+84
-93
lines changed

6 files changed

+84
-93
lines changed

cmd/issue/main.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ func main() {
5151
log.Fatal(err)
5252
}
5353
c := issues.NewClient(&issues.Config{Owner: owner, Repo: repoName, Token: *githubToken})
54+
ghsaClient := ghsa.NewClient(ctx, *githubToken)
5455
switch cmd {
5556
case "triage":
56-
err = createIssueToTriage(ctx, c, *githubToken, filename)
57+
err = createIssueToTriage(ctx, c, ghsaClient, filename)
5758
case "excluded":
58-
err = createExcluded(ctx, c, *githubToken, filename)
59+
err = createExcluded(ctx, c, ghsaClient, filename)
5960
default:
6061
err = fmt.Errorf("unsupported command: %q", cmd)
6162
}
@@ -64,42 +65,42 @@ func main() {
6465
}
6566
}
6667

67-
func createIssueToTriage(ctx context.Context, c *issues.Client, ghToken, filename string) (err error) {
68+
func createIssueToTriage(ctx context.Context, c *issues.Client, ghsaClient *ghsa.Client, filename string) (err error) {
6869
aliases, err := parseAliases(filename)
6970
if err != nil {
7071
return err
7172
}
7273
for _, alias := range aliases {
73-
if err := constructIssue(ctx, c, alias, ghToken, []string{"NeedsTriage"}); err != nil {
74+
if err := constructIssue(ctx, c, ghsaClient, alias, []string{"NeedsTriage"}); err != nil {
7475
return err
7576
}
7677
}
7778
return nil
7879
}
7980

80-
func createExcluded(ctx context.Context, c *issues.Client, ghToken, filename string) (err error) {
81+
func createExcluded(ctx context.Context, c *issues.Client, ghsaClient *ghsa.Client, filename string) (err error) {
8182
records, err := parseExcluded(filename)
8283
if err != nil {
8384
return err
8485
}
8586
for _, r := range records {
86-
if err := constructIssue(ctx, c, r.identifier, ghToken, []string{fmt.Sprintf("excluded: %s", r.category)}); err != nil {
87+
if err := constructIssue(ctx, c, ghsaClient, r.identifier, []string{fmt.Sprintf("excluded: %s", r.category)}); err != nil {
8788
return err
8889
}
8990
}
9091
return nil
9192
}
9293

93-
func constructIssue(ctx context.Context, c *issues.Client, alias, ghToken string, labels []string) (err error) {
94+
func constructIssue(ctx context.Context, c *issues.Client, ghsaClient *ghsa.Client, alias string, labels []string) (err error) {
9495
var ghsas []*ghsa.SecurityAdvisory
9596
if strings.HasPrefix(alias, "GHSA") {
96-
sa, err := ghsa.FetchGHSA(ctx, ghToken, alias)
97+
sa, err := ghsaClient.FetchGHSA(ctx, alias)
9798
if err != nil {
9899
return err
99100
}
100101
ghsas = append(ghsas, sa)
101102
} else if strings.HasPrefix(alias, "CVE") {
102-
ghsas, err = ghsa.ListForCVE(ctx, ghToken, alias)
103+
ghsas, err = ghsaClient.ListForCVE(ctx, alias)
103104
if err != nil {
104105
return err
105106
}

cmd/vulnreport/main.go

Lines changed: 38 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -80,55 +80,56 @@ func main() {
8080
*githubToken = os.Getenv("VULN_GITHUB_ACCESS_TOKEN")
8181
}
8282

83-
cmd := flag.Arg(0)
84-
85-
// Create-excluded has no args, so it is separated form the other commands.
86-
if cmd == "create-excluded" {
87-
_, cfg, err := setupCreate(ctx, nil)
88-
if err != nil {
89-
log.Fatal(err)
83+
var (
84+
args []string
85+
cmd = flag.Arg(0)
86+
)
87+
if cmd != "create-excluded" {
88+
if flag.NArg() < 2 {
89+
flag.Usage()
90+
log.Fatal("not enough arguments")
9091
}
91-
if err = createExcluded(ctx, cfg); err != nil {
92-
log.Fatal(err)
93-
}
94-
return
92+
args = flag.Args()[1:]
9593
}
9694

97-
if flag.NArg() < 2 {
98-
flag.Usage()
99-
log.Fatal("not enough arguments")
100-
}
101-
102-
args := flag.Args()[1:]
103-
104-
// Create operates on github issue IDs instead of filenames, so it is
105-
// separated from the other commands.
106-
if cmd == "create" {
95+
// setupCreate clones the CVEList repo and can be very slow,
96+
// so commands that require this functionality are separated from other
97+
// commands.
98+
if cmd == "create-excluded" || cmd == "create" {
10799
githubIDs, cfg, err := setupCreate(ctx, args)
108100
if err != nil {
109101
log.Fatal(err)
110102
}
111-
for _, githubID := range githubIDs {
112-
if err := create(ctx, githubID, cfg); err != nil {
113-
fmt.Printf("skipped: %s\n", err)
103+
switch cmd {
104+
case "create-excluded":
105+
if err = createExcluded(ctx, cfg); err != nil {
106+
log.Fatal(err)
107+
}
108+
case "create":
109+
// Unlike commands below, create operates on github issue IDs
110+
// instead of filenames.
111+
for _, githubID := range githubIDs {
112+
if err := create(ctx, githubID, cfg); err != nil {
113+
fmt.Printf("skipped: %s\n", err)
114+
}
114115
}
115116
}
116-
return
117117
}
118118

119+
ghsaClient := ghsa.NewClient(ctx, *githubToken)
119120
var cmdFunc func(string) error
120121
switch cmd {
121122
case "lint":
122123
cmdFunc = lint
123124
case "commit":
124-
cmdFunc = func(name string) error { return commit(ctx, name, *githubToken) }
125+
cmdFunc = func(name string) error { return commit(ctx, name, ghsaClient) }
125126
case "cve":
126127
cmdFunc = func(name string) error { return cveCmd(ctx, name) }
127128
//TODO: (https://github.com/golang/go/issues/56356): Deprecate this command once CVE JSON 5.0 publishing is available
128129
case "cve4":
129130
cmdFunc = func(name string) error { return cve4Cmd(ctx, name, *indent) }
130131
case "fix":
131-
cmdFunc = func(name string) error { return fix(ctx, name, *githubToken) }
132+
cmdFunc = func(name string) error { return fix(ctx, name, ghsaClient) }
132133
case "osv":
133134
cmdFunc = osvCmd
134135
case "set-dates":
@@ -233,8 +234,8 @@ func parseArgsToGithubIDs(args []string, existingByIssue map[int]*report.Report)
233234
}
234235

235236
type createCfg struct {
236-
ghToken string
237237
repo *git.Repository
238+
ghsaClient *ghsa.Client
238239
issuesClient *issues.Client
239240
existingByFile map[string]*report.Report
240241
existingByIssue map[int]*report.Report
@@ -270,9 +271,9 @@ func setupCreate(ctx context.Context, args []string) ([]int, *createCfg, error)
270271
return nil, nil, err
271272
}
272273
return githubIDs, &createCfg{
273-
ghToken: *githubToken,
274274
repo: repo,
275275
issuesClient: issues.NewClient(&issues.Config{Owner: owner, Repo: repoName, Token: *githubToken}),
276+
ghsaClient: ghsa.NewClient(ctx, *githubToken),
276277
existingByFile: existingByFile,
277278
existingByIssue: existingByIssue,
278279
allowClosed: *closedOk,
@@ -287,7 +288,7 @@ func createReport(ctx context.Context, cfg *createCfg, iss *issues.Issue) (r *re
287288
}
288289
if len(parsed.ghsas) == 0 && len(parsed.cves) > 0 {
289290
for _, cve := range parsed.cves {
290-
sas, err := ghsa.ListForCVE(ctx, cfg.ghToken, cve)
291+
sas, err := cfg.ghsaClient.ListForCVE(ctx, cve)
291292
if err != nil {
292293
return nil, err
293294
}
@@ -427,7 +428,7 @@ func newReport(ctx context.Context, cfg *createCfg, parsed *parsedIssue) (*repor
427428
var r *report.Report
428429
switch {
429430
case len(parsed.ghsas) > 0:
430-
ghsa, err := ghsa.FetchGHSA(ctx, cfg.ghToken, parsed.ghsas[0])
431+
ghsa, err := cfg.ghsaClient.FetchGHSA(ctx, parsed.ghsas[0])
431432
if err != nil {
432433
return nil, err
433434
}
@@ -597,7 +598,7 @@ func lint(filename string) (err error) {
597598
return nil
598599
}
599600

600-
func fix(ctx context.Context, filename string, accessToken string) (err error) {
601+
func fix(ctx context.Context, filename string, ghsaClient *ghsa.Client) (err error) {
601602
defer derrors.Wrap(&err, "fix(%q)", filename)
602603
r, err := report.Read(filename)
603604
if err != nil {
@@ -611,7 +612,7 @@ func fix(ctx context.Context, filename string, accessToken string) (err error) {
611612
return err
612613
}
613614
}
614-
if err := fixGHSAs(ctx, r, accessToken); err != nil {
615+
if err := fixGHSAs(ctx, r, ghsaClient); err != nil {
615616
return err
616617
}
617618
// Write unconditionally in order to format.
@@ -842,12 +843,12 @@ func irun(name string, arg ...string) error {
842843
return cmd.Run()
843844
}
844845

845-
func commit(ctx context.Context, filename, accessToken string) (err error) {
846+
func commit(ctx context.Context, filename string, ghsaClient *ghsa.Client) (err error) {
846847
defer derrors.Wrap(&err, "commit(%q)", filename)
847848

848849
// Ignore errors. If anything is really wrong with the report, we'll
849850
// detect it on re-linting below.
850-
_ = fix(ctx, filename, accessToken)
851+
_ = fix(ctx, filename, ghsaClient)
851852

852853
r, err := report.ReadAndLint(filename)
853854
if err != nil {
@@ -1027,39 +1028,15 @@ func setDates(filename string, dates map[string]gitrepo.Dates) (err error) {
10271028
return r.Write(filename)
10281029
}
10291030

1030-
// loadGHSAsByCVE returns a map from CVE ID to GHSA IDs.
1031-
// It does this by using the GitHub API to list all Go security
1032-
// advisories.
1033-
func loadGHSAsByCVE(ctx context.Context, accessToken string) (_ map[string][]string, err error) {
1034-
defer derrors.Wrap(&err, "loadGHSAsByCVE")
1035-
1036-
sas, err := ghsa.List(ctx, accessToken, time.Time{})
1037-
if err != nil {
1038-
return nil, err
1039-
}
1040-
m := map[string][]string{}
1041-
for _, sa := range sas {
1042-
for _, id := range sa.Identifiers {
1043-
if id.Type == "CVE" {
1044-
m[id.Value] = append(m[id.Value], sa.ID)
1045-
}
1046-
}
1047-
}
1048-
return m, nil
1049-
}
1050-
10511031
// fixGHSAs replaces r.GHSAs with a sorted list of GitHub Security
10521032
// Advisory IDs that correspond to the CVEs.
1053-
func fixGHSAs(ctx context.Context, r *report.Report, accessToken string) error {
1054-
if accessToken == "" {
1055-
return nil
1056-
}
1033+
func fixGHSAs(ctx context.Context, r *report.Report, ghsaClient *ghsa.Client) error {
10571034
if len(r.GHSAs) > 0 && !*alwaysFixGHSA {
10581035
return nil
10591036
}
10601037
m := map[string]struct{}{}
10611038
for _, cid := range r.CVEs {
1062-
sas, err := ghsa.ListForCVE(ctx, accessToken, cid)
1039+
sas, err := ghsaClient.ListForCVE(ctx, cid)
10631040
if err != nil {
10641041
return err
10651042
}

cmd/worker/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,9 @@ func updateCommand(ctx context.Context, commitHash string) error {
213213
fmt.Printf("Missing GitHub access token; not updating GH security advisories.\n")
214214
return nil
215215
}
216+
ghsaClient := ghsa.NewClient(ctx, cfg.GitHubAccessToken)
216217
listSAs := func(ctx context.Context, since time.Time) ([]*ghsa.SecurityAdvisory, error) {
217-
return ghsa.List(ctx, cfg.GitHubAccessToken, since)
218+
return ghsaClient.List(ctx, since)
218219
}
219220
_, err = worker.UpdateGHSAs(ctx, listSAs, cfg.Store)
220221
return err

internal/ghsa/ghsa.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,25 @@ func (sa *gqlSecurityAdvisory) securityAdvisory() (*SecurityAdvisory, error) {
119119
return s, nil
120120
}
121121

122-
func newGitHubClient(ctx context.Context, accessToken string) *githubv4.Client {
122+
// Client is a client that can fetch data about GitHub security advisories.
123+
type Client struct {
124+
client *githubv4.Client
125+
token string
126+
}
127+
128+
// NewClient creates a new client for making requests to the GHSA API.
129+
func NewClient(ctx context.Context, accessToken string) *Client {
123130
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken})
124131
tc := oauth2.NewClient(ctx, ts)
125-
return githubv4.NewClient(tc)
132+
return &Client{
133+
client: githubv4.NewClient(tc),
134+
token: accessToken,
135+
}
126136
}
127137

128138
// List returns all SecurityAdvisories that affect Go,
129139
// published or updated since the given time.
130-
func List(ctx context.Context, accessToken string, since time.Time) ([]*SecurityAdvisory, error) {
131-
client := newGitHubClient(ctx, accessToken)
132-
140+
func (c *Client) List(ctx context.Context, since time.Time) ([]*SecurityAdvisory, error) {
133141
var query struct { // the GraphQL query
134142
SAs struct {
135143
Nodes []gqlSecurityAdvisory
@@ -149,7 +157,7 @@ func List(ctx context.Context, accessToken string, since time.Time) ([]*Security
149157
// We need a loop to page through the list. The GitHub API limits us to 100
150158
// values per call.
151159
for {
152-
if err := client.Query(ctx, &query, vars); err != nil {
160+
if err := c.client.Query(ctx, &query, vars); err != nil {
153161
return nil, err
154162
}
155163
for _, sa := range query.SAs.Nodes {
@@ -170,9 +178,7 @@ func List(ctx context.Context, accessToken string, since time.Time) ([]*Security
170178
return sas, nil
171179
}
172180

173-
func ListForCVE(ctx context.Context, accessToken string, cve string) ([]*SecurityAdvisory, error) {
174-
client := newGitHubClient(ctx, accessToken)
175-
181+
func (c *Client) ListForCVE(ctx context.Context, cve string) ([]*SecurityAdvisory, error) {
176182
var query struct { // The GraphQL query
177183
SAs struct {
178184
Nodes []gqlSecurityAdvisory
@@ -190,7 +196,7 @@ func ListForCVE(ctx context.Context, accessToken string, cve string) ([]*Securit
190196
"go": githubv4.SecurityAdvisoryEcosystemGo,
191197
}
192198

193-
if err := client.Query(ctx, &query, vars); err != nil {
199+
if err := c.client.Query(ctx, &query, vars); err != nil {
194200
return nil, err
195201
}
196202
if query.SAs.PageInfo.HasNextPage {
@@ -223,9 +229,7 @@ func ListForCVE(ctx context.Context, accessToken string, cve string) ([]*Securit
223229

224230
// FetchGHSA returns the SecurityAdvisory for the given Github Security
225231
// Advisory ID.
226-
func FetchGHSA(ctx context.Context, accessToken, ghsaID string) (_ *SecurityAdvisory, err error) {
227-
client := newGitHubClient(ctx, accessToken)
228-
232+
func (c *Client) FetchGHSA(ctx context.Context, ghsaID string) (_ *SecurityAdvisory, err error) {
229233
var query struct {
230234
SA gqlSecurityAdvisory `graphql:"securityAdvisory(ghsaId: $id)"`
231235
}
@@ -234,9 +238,8 @@ func FetchGHSA(ctx context.Context, accessToken, ghsaID string) (_ *SecurityAdvi
234238
"go": githubv4.SecurityAdvisoryEcosystemGo,
235239
}
236240

237-
if err := client.Query(ctx, &query, vars); err != nil {
241+
if err := c.client.Query(ctx, &query, vars); err != nil {
238242
return nil, err
239243
}
240-
241244
return query.SA.securityAdvisory()
242245
}

0 commit comments

Comments
 (0)