Skip to content

Commit 3aebde8

Browse files
authored
fix(librarian): support SSH remotes (#1898)
Fixes: #1791
1 parent cedfb49 commit 3aebde8

File tree

3 files changed

+91
-13
lines changed

3 files changed

+91
-13
lines changed

internal/github/github.go

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,41 @@ type PullRequestMetadata struct {
8686
Number int
8787
}
8888

89-
// ParseURL parses a GitHub URL (anything to do with a repository) to determine
89+
// ParseRemote parses a GitHub remote (anything to do with a repository) to determine
9090
// the GitHub repo details (owner and name).
91-
func ParseURL(remoteURL string) (*Repository, error) {
92-
if !strings.HasPrefix(remoteURL, "https://github.com/") {
93-
return nil, fmt.Errorf("remote '%s' is not a GitHub remote", remoteURL)
91+
func ParseRemote(remote string) (*Repository, error) {
92+
if strings.HasPrefix(remote, "https://github.com/") {
93+
return parseHTTPRemote(remote)
9494
}
95-
remotePath := remoteURL[len("https://github.com/"):]
95+
if strings.HasPrefix(remote, "git@") {
96+
return parseSSHRemote(remote)
97+
}
98+
return nil, fmt.Errorf("remote '%s' is not a GitHub remote", remote)
99+
}
100+
101+
func parseHTTPRemote(remote string) (*Repository, error) {
102+
remotePath := remote[len("https://github.com/"):]
96103
pathParts := strings.Split(remotePath, "/")
97104
organization := pathParts[0]
98105
repoName := pathParts[1]
99106
repoName = strings.TrimSuffix(repoName, ".git")
100107
return &Repository{Owner: organization, Name: repoName}, nil
101108
}
102109

110+
func parseSSHRemote(remote string) (*Repository, error) {
111+
pathParts := strings.Split(remote, ":")
112+
if len(pathParts) != 2 {
113+
return nil, fmt.Errorf("remote %q is not a GitHub remote", remote)
114+
}
115+
orgRepo := strings.Split(pathParts[1], "/")
116+
if len(orgRepo) != 2 {
117+
return nil, fmt.Errorf("remote %q is not a GitHub remote", remote)
118+
}
119+
organization := orgRepo[0]
120+
repoName := strings.TrimSuffix(orgRepo[1], ".git")
121+
return &Repository{Owner: organization, Name: repoName}, nil
122+
}
123+
103124
// GetRawContent fetches the raw content of a file within a repository repo,
104125
// identifying the file by path, at a specific commit/tag/branch of ref.
105126
func (c *Client) GetRawContent(ctx context.Context, path, ref string) ([]byte, error) {
@@ -181,11 +202,9 @@ func FetchGitHubRepoFromRemote(repo gitrepo.Repository) (*Repository, error) {
181202
for _, remote := range remotes {
182203
if remote.Config().Name == "origin" {
183204
urls := remote.Config().URLs
184-
if len(urls) > 0 && strings.HasPrefix(urls[0], "https://github.com/") {
185-
return ParseURL(urls[0])
205+
if len(urls) > 0 {
206+
return ParseRemote(urls[0])
186207
}
187-
// If 'origin' exists but is not a GitHub remote, we stop.
188-
break
189208
}
190209
}
191210

internal/github/github_test.go

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func TestFetchGitHubRepoFromRemote(t *testing.T) {
182182
"origin": {"https://gitlab.com/owner/repo.git"},
183183
},
184184
wantErr: true,
185-
wantErrSubstr: "could not find an 'origin' remote",
185+
wantErrSubstr: "is not a GitHub remote",
186186
},
187187
{
188188
name: "upstream is GitHub, but no origin",
@@ -208,7 +208,7 @@ func TestFetchGitHubRepoFromRemote(t *testing.T) {
208208
"upstream": {"https://github.com/gh-owner/gh-repo.git"},
209209
},
210210
wantErr: true,
211-
wantErrSubstr: "could not find an 'origin' remote",
211+
wantErrSubstr: "is not a GitHub remote",
212212
},
213213
{
214214
name: "origin has multiple URLs, first is GitHub",
@@ -278,7 +278,7 @@ func TestParseURL(t *testing.T) {
278278
} {
279279
t.Run(test.name, func(t *testing.T) {
280280
t.Parallel()
281-
repo, err := ParseURL(test.remoteURL)
281+
repo, err := ParseRemote(test.remoteURL)
282282

283283
if test.wantErr {
284284
if err == nil {
@@ -298,6 +298,65 @@ func TestParseURL(t *testing.T) {
298298
}
299299
}
300300

301+
func TestParseSSHRemote(t *testing.T) {
302+
t.Parallel()
303+
for _, test := range []struct {
304+
name string
305+
remote string
306+
wantRepo *Repository
307+
wantErr bool
308+
wantErrSubstr string
309+
}{
310+
{
311+
name: "Valid SSH URL with .git",
312+
remote: "[email protected]:owner/repo.git",
313+
wantRepo: &Repository{Owner: "owner", Name: "repo"},
314+
},
315+
{
316+
name: "Valid SSH URL without .git",
317+
remote: "[email protected]:owner/repo",
318+
wantRepo: &Repository{Owner: "owner", Name: "repo"},
319+
},
320+
{
321+
name: "Invalid remote, no git@ prefix",
322+
remote: "https://github.com/owner/repo.git",
323+
wantErr: true,
324+
wantErrSubstr: "not a GitHub remote",
325+
},
326+
{
327+
name: "Invalid remote, no colon",
328+
remote: "[email protected]/repo.git",
329+
wantErr: true,
330+
wantErrSubstr: "not a GitHub remote",
331+
},
332+
{
333+
name: "Invalid remote, no slash",
334+
remote: "[email protected]:owner-repo.git",
335+
wantErr: true,
336+
wantErrSubstr: "not a GitHub remote",
337+
},
338+
} {
339+
t.Run(test.name, func(t *testing.T) {
340+
t.Parallel()
341+
repo, err := parseSSHRemote(test.remote)
342+
if test.wantErr {
343+
if err == nil {
344+
t.Errorf("ParseSSHRemote() err = nil, want error containing %q", test.wantErrSubstr)
345+
} else if !strings.Contains(err.Error(), test.wantErrSubstr) {
346+
t.Errorf("ParseSSHRemote() err = %v, want error containing %q", err, test.wantErrSubstr)
347+
}
348+
} else {
349+
if err != nil {
350+
t.Errorf("ParseSSHRemote() err = %v, want nil", err)
351+
}
352+
if diff := cmp.Diff(test.wantRepo, repo); diff != "" {
353+
t.Errorf("ParseSSHRemote() repo mismatch (-want +got): %s", diff)
354+
}
355+
}
356+
})
357+
}
358+
}
359+
301360
func TestCreatePullRequest(t *testing.T) {
302361
t.Parallel()
303362
for _, test := range []struct {

internal/librarian/command.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func newCommandRunner(cfg *config.Config) (*commandRunner, error) {
100100

101101
var gitRepo *github.Repository
102102
if isURL(cfg.Repo) {
103-
gitRepo, err = github.ParseURL(cfg.Repo)
103+
gitRepo, err = github.ParseRemote(cfg.Repo)
104104
if err != nil {
105105
return nil, fmt.Errorf("failed to parse repo url: %w", err)
106106
}

0 commit comments

Comments
 (0)