Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 97 additions & 3 deletions internal/repository/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/elliotchance/pie/v2"
"github.com/google/go-github/v68/github"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
)

Expand All @@ -25,7 +26,11 @@ type githubService struct {

// newGithubRepo creates a new GitHub repository service
func New(token string) githubService {
client := github.NewClient(nil)
ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: token},
)
tc := oauth2.NewClient(context.Background(), ts)
client := github.NewClient(tc)
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
Expand Down Expand Up @@ -70,12 +75,101 @@ func (s githubService) GetProjectList(paths []string) (projects []repository.Pro

// CloseVulnerabilityIssue closes the vulnerability issue for the given project
func (s githubService) CloseVulnerabilityIssue(project repository.Project) (err error) {
return errors.New("CloseVulnerabilityIssue not yet implemented") // TODO #9 Add github support
issue, err := s.getVulnerabilityIssue(project.GroupOrOwner, project.Name)
if err != nil {
return fmt.Errorf("failed to fetch current list of issues: %w", err)
}
if issue == nil {
log.Info().Str("project", project.Path).Msg("No issue to close, nothing to do")
return nil
}
if issue.GetState() == "closed" {
log.Info().Str("project", project.Path).Msg("Issue already closed")
return nil
}
state := "closed"
_, _, err = s.client.UpdateIssue(project.GroupOrOwner, project.Name, issue.GetNumber(), &github.IssueRequest{
State: &state,
})
if err != nil {
return fmt.Errorf("failed to update issue: %w", err)
}
log.Info().Str("project", project.Path).Msg("Issue closed")
return nil
}

// OpenVulnerabilityIssue opens or updates the vulnerability issue for the given project
func (s githubService) OpenVulnerabilityIssue(project repository.Project, report string) (issue *repository.Issue, err error) {
return nil, errors.New("OpenVulnerabilityIssue not yet implemented") // TODO #9 Add github support
vulnTitle := repository.VulnerabilityIssueTitle
ghIssue, err := s.getVulnerabilityIssue(project.GroupOrOwner, project.Name)
if err != nil {
return nil, fmt.Errorf("[%v] Failed to fetch current list of issues: %w", project.Path, err)
}
if ghIssue == nil {
log.Info().Str("project", project.Path).Msg("Creating new issue")
newIssue := &github.IssueRequest{
Title: &vulnTitle,
Body: &report,
}
created, _, err := s.client.CreateIssue(project.GroupOrOwner, project.Name, newIssue)
if err != nil {
return nil, fmt.Errorf("[%v] failed to create new issue: %w", project.Path, err)
}
return mapGithubIssuePtr(created), nil
}
log.Info().Str("project", project.Path).Int("issue", ghIssue.GetNumber()).Msg("Updating existing issue")
state := "open"
updatedIssue := &github.IssueRequest{
Body: &report,
State: &state,
}
edited, _, err := s.client.UpdateIssue(project.GroupOrOwner, project.Name, ghIssue.GetNumber(), updatedIssue)
if err != nil {
return nil, fmt.Errorf("[%v] Failed to update issue: %w", project.Path, err)
}
if edited.GetState() != "open" {
return nil, errors.New("failed to reopen issue")
}
return mapGithubIssuePtr(edited), nil
}

// getVulnerabilityIssue returns the vulnerability issue for the given repo (by title)
func (s githubService) getVulnerabilityIssue(owner, repo string) (*github.Issue, error) {
opts := &github.IssueListByRepoOptions{
State: "all",
ListOptions: github.ListOptions{PerPage: 100},
}
vulnTitle := repository.VulnerabilityIssueTitle
for {
issues, resp, err := s.client.ListRepositoryIssues(owner, repo, opts)
if err != nil {
return nil, err
}
for _, issue := range issues {
if issue != nil && issue.GetTitle() == vulnTitle {
return issue, nil
}
}
if resp.NextPage == 0 {
break
}
opts.Page = resp.NextPage
}
return nil, nil
}
func mapGithubIssue(i github.Issue) repository.Issue {
return repository.Issue{
Title: i.GetTitle(),
WebURL: i.GetHTMLURL(),
}
}

func mapGithubIssuePtr(i *github.Issue) *repository.Issue {
if i == nil {
return nil
}
issue := mapGithubIssue(*i)
return &issue
}

func (s githubService) Download(project repository.Project, dir string) (err error) {
Expand Down
40 changes: 35 additions & 5 deletions internal/repository/github/github_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,67 @@ package github
import (
"context"
"net/url"
"time"

"github.com/google/go-github/v68/github"
)

const defaultTimeout = 30 * time.Second

// This client is a thin wrapper around the go-github library. It provides an interface to the GitHub client
// The main purpose of this client is to provide an interface to the GitHub client which can be mocked in tests.
// As such this MUST be as thin as possible and MUST not contain any business logic, since it is not testable.

type iGithubClient interface {
GetRepository(owner string, repo string) (*github.Repository, *github.Response, error)
GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error)
GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error)
GetArchiveLink(owner string, repo string, archiveFormat github.ArchiveFormat, opts *github.RepositoryContentGetOptions) (*url.URL, *github.Response, error)
ListRepositoryIssues(owner string, repo string, opts *github.IssueListByRepoOptions) ([]*github.Issue, *github.Response, error)
CreateIssue(owner string, repo string, issue *github.IssueRequest) (*github.Issue, *github.Response, error)
UpdateIssue(owner string, repo string, number int, issue *github.IssueRequest) (*github.Issue, *github.Response, error)
}

type githubClient struct {
client *github.Client
}

func (c *githubClient) ListRepositoryIssues(owner, repo string, opts *github.IssueListByRepoOptions) ([]*github.Issue, *github.Response, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Issues.ListByRepo(ctx, owner, repo, opts)
}

func (c *githubClient) CreateIssue(owner, repo string, issue *github.IssueRequest) (*github.Issue, *github.Response, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Issues.Create(ctx, owner, repo, issue)
}

func (c *githubClient) UpdateIssue(owner, repo string, number int, issue *github.IssueRequest) (*github.Issue, *github.Response, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Issues.Edit(ctx, owner, repo, number, issue)
}
func (c *githubClient) GetRepository(owner string, repo string) (*github.Repository, *github.Response, error) {
return c.client.Repositories.Get(context.Background(), owner, repo)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Repositories.Get(ctx, owner, repo)
}

func (c *githubClient) GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error) {
return c.client.Repositories.ListByOrg(context.Background(), org, opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Repositories.ListByOrg(ctx, org, opts)
}

func (c *githubClient) GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error) {
return c.client.Repositories.ListByUser(context.Background(), user, opts)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Repositories.ListByUser(ctx, user, opts)
}

func (c *githubClient) GetArchiveLink(owner string, repo string, archiveFormat github.ArchiveFormat, opts *github.RepositoryContentGetOptions) (*url.URL, *github.Response, error) {
return c.client.Repositories.GetArchiveLink(context.Background(), owner, repo, archiveFormat, opts, 3)
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
return c.client.Repositories.GetArchiveLink(ctx, owner, repo, archiveFormat, opts, 3)
}
150 changes: 110 additions & 40 deletions internal/repository/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,46 +109,6 @@ func TestGetProjectListWithNextPage(t *testing.T) {
mockService.AssertExpectations(t)
}

type mockService struct {
mock.Mock
}

func (c *mockService) GetRepository(owner string, repo string) (*github.Repository, *github.Response, error) {
args := c.Called(owner, repo)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).(*github.Repository), r, args.Error(2)
}

func (c *mockService) GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error) {
args := c.Called(org, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).([]*github.Repository), r, args.Error(2)
}

func (c *mockService) GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error) {
args := c.Called(user, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).([]*github.Repository), r, args.Error(2)
}

func (c *mockService) GetArchiveLink(owner string, repo string, archiveFormat github.ArchiveFormat, opts *github.RepositoryContentGetOptions) (*url.URL, *github.Response, error) {
args := c.Called(owner, repo, archiveFormat, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).(*url.URL), r, args.Error(2)
}

func TestDownload(t *testing.T) {
// Create temporary directory for testing
tempDir, err := os.MkdirTemp("", "sheriff-clone-test-")
Expand Down Expand Up @@ -211,3 +171,113 @@ func TestDownload(t *testing.T) {
_, err = os.Stat(filepath.Join(tempDir, "src"))
assert.NoError(t, err, "src directory should exist")
}

func TestOpenVulnerabilityIssue(t *testing.T) {
title := repository.VulnerabilityIssueTitle
mockClient := mockService{}
mockClient.On("ListRepositoryIssues", mock.Anything, mock.Anything, mock.Anything).Return([]*github.Issue{}, &github.Response{}, nil)
mockClient.On("CreateIssue", mock.Anything, mock.Anything, mock.Anything).Return(&github.Issue{Title: &title}, &github.Response{}, nil)

svc := githubService{client: &mockClient}

i, err := svc.OpenVulnerabilityIssue(repository.Project{GroupOrOwner: "group", Name: "repo"}, "report")
assert.Nil(t, err)
assert.NotNil(t, i)
assert.Equal(t, repository.VulnerabilityIssueTitle, i.Title)
mockClient.AssertExpectations(t)
}

func TestCloseVulnerabilityIssue(t *testing.T) {
title := repository.VulnerabilityIssueTitle
state := "open"
mockClient := mockService{}
mockClient.On("ListRepositoryIssues", mock.Anything, mock.Anything, mock.Anything).Return([]*github.Issue{{Title: &title, State: &state}}, &github.Response{}, nil)
mockClient.On("UpdateIssue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&github.Issue{}, &github.Response{}, nil)

svc := githubService{client: &mockClient}

err := svc.CloseVulnerabilityIssue(repository.Project{GroupOrOwner: "group", Name: "repo"})
assert.Nil(t, err)
mockClient.AssertExpectations(t)
}

func TestCloseVulnerabilityIssueNoIssue(t *testing.T) {
mockClient := mockService{}
mockClient.On("ListRepositoryIssues", mock.Anything, mock.Anything, mock.Anything).Return(nil, &github.Response{}, nil)

svc := githubService{client: &mockClient}

err := svc.CloseVulnerabilityIssue(repository.Project{GroupOrOwner: "group", Name: "repo"})
assert.Nil(t, err)
mockClient.AssertExpectations(t)
}

type mockService struct {
mock.Mock
}

func (c *mockService) GetRepository(owner string, repo string) (*github.Repository, *github.Response, error) {
args := c.Called(owner, repo)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).(*github.Repository), r, args.Error(2)
}

func (c *mockService) GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error) {
args := c.Called(org, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).([]*github.Repository), r, args.Error(2)
}

func (c *mockService) GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error) {
args := c.Called(user, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).([]*github.Repository), r, args.Error(2)
}

func (c *mockService) GetArchiveLink(owner string, repo string, archiveFormat github.ArchiveFormat, opts *github.RepositoryContentGetOptions) (*url.URL, *github.Response, error) {
args := c.Called(owner, repo, archiveFormat, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).(*url.URL), r, args.Error(2)
}

func (c *mockService) ListRepositoryIssues(owner string, repo string, opts *github.IssueListByRepoOptions) ([]*github.Issue, *github.Response, error) {
args := c.Called(owner, repo, opts)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
if args.Get(0) == nil {
return nil, r, args.Error(2)
}
return args.Get(0).([]*github.Issue), r, args.Error(2)
}

func (c *mockService) CreateIssue(owner string, repo string, issue *github.IssueRequest) (*github.Issue, *github.Response, error) {
args := c.Called(owner, repo, issue)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).(*github.Issue), r, args.Error(2)
}

func (c *mockService) UpdateIssue(owner string, repo string, number int, issue *github.IssueRequest) (*github.Issue, *github.Response, error) {
args := c.Called(owner, repo, number, issue)
var r *github.Response
if resp := args.Get(1); resp != nil {
r = args.Get(1).(*github.Response)
}
return args.Get(0).(*github.Issue), r, args.Error(2)
}
Loading