Skip to content

Commit be2f36f

Browse files
committed
resolveGitReference helper
1 parent 3269af4 commit be2f36f

File tree

2 files changed

+153
-34
lines changed

2 files changed

+153
-34
lines changed

pkg/github/repositories.go

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"io"
99
"net/http"
1010
"net/url"
11-
"strconv"
1211
"strings"
1312

1413
ghErrors "github.com/github/github-mcp-server/pkg/errors"
@@ -495,33 +494,18 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t
495494
return mcp.NewToolResultError(err.Error()), nil
496495
}
497496

498-
rawOpts := &raw.ContentOpts{}
499-
500-
if strings.HasPrefix(ref, "refs/pull/") {
501-
prNumber := strings.TrimSuffix(strings.TrimPrefix(ref, "refs/pull/"), "/head")
502-
if len(prNumber) > 0 {
503-
// fetch the PR from the API to get the latest commit and use SHA
504-
githubClient, err := getClient(ctx)
505-
if err != nil {
506-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
507-
}
508-
prNum, err := strconv.Atoi(prNumber)
509-
if err != nil {
510-
return nil, fmt.Errorf("invalid pull request number: %w", err)
511-
}
512-
pr, _, err := githubClient.PullRequests.Get(ctx, owner, repo, prNum)
513-
if err != nil {
514-
return nil, fmt.Errorf("failed to get pull request: %w", err)
515-
}
516-
sha = pr.GetHead().GetSHA()
517-
ref = ""
518-
}
497+
client, err := getClient(ctx)
498+
if err != nil {
499+
return mcp.NewToolResultError("failed to get GitHub client"), nil
519500
}
520501

521-
rawOpts.SHA = sha
522-
rawOpts.Ref = ref
502+
rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha)
503+
if err != nil {
504+
return mcp.NewToolResultError(err.Error()), nil
505+
}
523506

524-
// If the path is (most likely) not to be a directory, we will first try to get the raw content from the GitHub raw content API.
507+
// If the path is (most likely) not to be a directory, we will
508+
// first try to get the raw content from the GitHub raw content API.
525509
if path != "" && !strings.HasSuffix(path, "/") {
526510

527511
rawClient, err := getRawClient(ctx)
@@ -580,13 +564,8 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t
580564
}
581565
}
582566

583-
client, err := getClient(ctx)
584-
if err != nil {
585-
return mcp.NewToolResultError("failed to get GitHub client"), nil
586-
}
587-
588-
if sha != "" {
589-
ref = sha
567+
if rawOpts.SHA != "" {
568+
ref = rawOpts.SHA
590569
}
591570
if strings.HasSuffix(path, "/") {
592571
opts := &github.RepositoryContentGetOptions{Ref: ref}
@@ -632,7 +611,7 @@ func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t t
632611
if err != nil {
633612
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal matching files: %s", err)), nil
634613
}
635-
return mcp.NewToolResultText(fmt.Sprintf("Provided path did not match a file or directory, but possible matches are: %s", matchingFilesJSON)), nil
614+
return mcp.NewToolResultText(fmt.Sprintf("Provided path did not match a file or directory, but resolved ref to %s with possible path matches: %s", ref, matchingFilesJSON)), nil
636615
}
637616

638617
return mcp.NewToolResultError("Failed to get file contents. The path does not point to a file or directory, or the file does not exist in the repository."), nil
@@ -1337,3 +1316,35 @@ func filterPaths(entries []*github.TreeEntry, path string, maxResults int) []str
13371316
}
13381317
return matchedPaths
13391318
}
1319+
1320+
// resolveGitReference resolves git references with the following logic:
1321+
// 1. If SHA is provided, it takes precedence
1322+
// 2. If neither is provided, use the default branch as ref
1323+
// 3. Get SHA from the ref
1324+
// Refs can look like `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head`
1325+
// The function returns the resolved ref, SHA and any error.
1326+
func resolveGitReference(ctx context.Context, githubClient *github.Client, owner, repo, ref, sha string) (*raw.ContentOpts, error) {
1327+
// 1. If SHA is provided, use it directly
1328+
if sha != "" {
1329+
return &raw.ContentOpts{Ref: "", SHA: sha}, nil
1330+
}
1331+
1332+
// 2. If neither provided, use the default branch as ref
1333+
if ref == "" {
1334+
repoInfo, _, err := githubClient.Repositories.Get(ctx, owner, repo)
1335+
if err != nil {
1336+
return nil, fmt.Errorf("failed to get repository info: %w", err)
1337+
}
1338+
ref = fmt.Sprintf("refs/heads/%s", repoInfo.GetDefaultBranch())
1339+
}
1340+
1341+
// 3. Get the SHA from the ref
1342+
reference, _, err := githubClient.Git.GetRef(ctx, owner, repo, ref)
1343+
if err != nil {
1344+
return nil, fmt.Errorf("failed to get reference for default branch: %w", err)
1345+
}
1346+
sha = reference.GetObject().GetSHA()
1347+
1348+
// Use provided ref, or it will be empty which defaults to the default branch
1349+
return &raw.ContentOpts{Ref: ref, SHA: sha}, nil
1350+
}

pkg/github/repositories_test.go

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ func Test_GetFileContents(t *testing.T) {
6969
{
7070
name: "successful text content fetch",
7171
mockedClient: mock.NewMockedHTTPClient(
72+
mock.WithRequestMatchHandler(
73+
mock.GetReposGitRefByOwnerByRepoByRef,
74+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
75+
w.WriteHeader(http.StatusOK)
76+
_, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`))
77+
}),
78+
),
7279
mock.WithRequestMatchHandler(
7380
raw.GetRawReposContentsByOwnerByRepoByBranchByPath,
7481
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -93,6 +100,13 @@ func Test_GetFileContents(t *testing.T) {
93100
{
94101
name: "successful file blob content fetch",
95102
mockedClient: mock.NewMockedHTTPClient(
103+
mock.WithRequestMatchHandler(
104+
mock.GetReposGitRefByOwnerByRepoByRef,
105+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
106+
w.WriteHeader(http.StatusOK)
107+
_, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`))
108+
}),
109+
),
96110
mock.WithRequestMatchHandler(
97111
raw.GetRawReposContentsByOwnerByRepoByBranchByPath,
98112
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -117,6 +131,20 @@ func Test_GetFileContents(t *testing.T) {
117131
{
118132
name: "successful directory content fetch",
119133
mockedClient: mock.NewMockedHTTPClient(
134+
mock.WithRequestMatchHandler(
135+
mock.GetReposByOwnerByRepo,
136+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
137+
w.WriteHeader(http.StatusOK)
138+
_, _ = w.Write([]byte(`{"name": "repo", "default_branch": "main"}`))
139+
}),
140+
),
141+
mock.WithRequestMatchHandler(
142+
mock.GetReposGitRefByOwnerByRepoByRef,
143+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
144+
w.WriteHeader(http.StatusOK)
145+
_, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`))
146+
}),
147+
),
120148
mock.WithRequestMatchHandler(
121149
mock.GetReposContentsByOwnerByRepoByPath,
122150
expectQueryParams(t, map[string]string{}).andThen(
@@ -143,6 +171,13 @@ func Test_GetFileContents(t *testing.T) {
143171
{
144172
name: "content fetch fails",
145173
mockedClient: mock.NewMockedHTTPClient(
174+
mock.WithRequestMatchHandler(
175+
mock.GetReposGitRefByOwnerByRepoByRef,
176+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
177+
w.WriteHeader(http.StatusOK)
178+
_, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`))
179+
}),
180+
),
146181
mock.WithRequestMatchHandler(
147182
mock.GetReposContentsByOwnerByRepoByPath,
148183
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -203,7 +238,7 @@ func Test_GetFileContents(t *testing.T) {
203238
textContent := getTextResult(t, result)
204239
var returnedContents []*github.RepositoryContent
205240
err = json.Unmarshal([]byte(textContent.Text), &returnedContents)
206-
require.NoError(t, err)
241+
require.NoError(t, err, "Failed to unmarshal directory content result: %v", textContent.Text)
207242
assert.Len(t, returnedContents, len(expected))
208243
for i, content := range returnedContents {
209244
assert.Equal(t, *expected[i].Name, *content.Name)
@@ -2049,3 +2084,76 @@ func Test_GetTag(t *testing.T) {
20492084
})
20502085
}
20512086
}
2087+
2088+
func Test_ResolveGitReference(t *testing.T) {
2089+
2090+
ctx := context.Background()
2091+
owner := "owner"
2092+
repo := "repo"
2093+
mockedClient := mock.NewMockedHTTPClient(
2094+
mock.WithRequestMatchHandler(
2095+
mock.GetReposByOwnerByRepo,
2096+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
2097+
w.WriteHeader(http.StatusOK)
2098+
_, _ = w.Write([]byte(`{"name": "repo", "default_branch": "main"}`))
2099+
}),
2100+
),
2101+
mock.WithRequestMatchHandler(
2102+
mock.GetReposGitRefByOwnerByRepoByRef,
2103+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
2104+
w.WriteHeader(http.StatusOK)
2105+
_, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": "123sha456"}}`))
2106+
}),
2107+
),
2108+
)
2109+
2110+
tests := []struct {
2111+
name string
2112+
ref string
2113+
sha string
2114+
expectedOutput *raw.ContentOpts
2115+
}{
2116+
{
2117+
name: "sha takes precedence over ref",
2118+
ref: "refs/heads/main",
2119+
sha: "123sha456",
2120+
expectedOutput: &raw.ContentOpts{
2121+
SHA: "123sha456",
2122+
},
2123+
},
2124+
{
2125+
name: "use default branch if ref and sha both empty",
2126+
ref: "",
2127+
sha: "",
2128+
expectedOutput: &raw.ContentOpts{
2129+
Ref: "refs/heads/main",
2130+
SHA: "123sha456",
2131+
},
2132+
},
2133+
{
2134+
name: "get SHA from ref",
2135+
ref: "refs/heads/main",
2136+
sha: "",
2137+
expectedOutput: &raw.ContentOpts{
2138+
Ref: "refs/heads/main",
2139+
SHA: "123sha456",
2140+
},
2141+
},
2142+
}
2143+
2144+
for _, tc := range tests {
2145+
t.Run(tc.name, func(t *testing.T) {
2146+
// Setup client with mock
2147+
client := github.NewClient(mockedClient)
2148+
opts, err := resolveGitReference(ctx, client, owner, repo, tc.ref, tc.sha)
2149+
require.NoError(t, err)
2150+
2151+
if tc.expectedOutput.SHA != "" {
2152+
assert.Equal(t, tc.expectedOutput.SHA, opts.SHA)
2153+
}
2154+
if tc.expectedOutput.Ref != "" {
2155+
assert.Equal(t, tc.expectedOutput.Ref, opts.Ref)
2156+
}
2157+
})
2158+
}
2159+
}

0 commit comments

Comments
 (0)