Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions pkg/github/repositories.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,69 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t
}
}

// ListBranches creates a tool to list branches in a GitHub repository.
func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_branches",
mcp.WithDescription(t("TOOL_LIST_BRANCHES_DESCRIPTION", "List branches in a GitHub repository")),
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
),
mcp.WithString("repo",
mcp.Required(),
mcp.Description("Repository name"),
),
WithPagination(),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pagination, err := OptionalPaginationParams(request)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

opts := &github.BranchListOptions{
ListOptions: github.ListOptions{
Page: pagination.page,
PerPage: pagination.perPage,
},
}

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}

branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts)
if err != nil {
return nil, fmt.Errorf("failed to list branches: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to list branches: %s", string(body))), nil
}

r, err := json.Marshal(branches)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository.
func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("create_or_update_file",
Expand Down
108 changes: 108 additions & 0 deletions pkg/github/repositories_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

Expand Down Expand Up @@ -1293,3 +1295,109 @@
})
}
}

func Test_ListBranches(t *testing.T) {
// Create a test server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/repos/owner/repo/branches", r.URL.Path)
assert.Equal(t, "GET", r.Method)

// Check query parameters
query := r.URL.Query()
if page := query.Get("page"); page != "" {
assert.Equal(t, "2", page)
}
if perPage := query.Get("per_page"); perPage != "" {
assert.Equal(t, "30", perPage)
}

// Return mock branches
mockBranches := []github.Branch{
{Name: github.String("main")},

Check failure on line 1316 in pkg/github/repositories_test.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: github.String is deprecated: use Ptr instead. (staticcheck)
{Name: github.String("develop")},

Check failure on line 1317 in pkg/github/repositories_test.go

View workflow job for this annotation

GitHub Actions / lint

SA1019: github.String is deprecated: use Ptr instead. (staticcheck)
}
mockResponse(t, http.StatusOK, mockBranches)(w, r)
}))
defer ts.Close()

// Create a GitHub client using the test server URL
client := github.NewClient(nil)
client.BaseURL, _ = url.Parse(ts.URL + "/")

// Create the tool
tool, handler := ListBranches(stubGetClientFn(client), translations.NullTranslationHelper)

// Test cases
tests := []struct {
name string
args map[string]interface{}
wantErr bool
errContains string
}{
{
name: "success",
args: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"page": float64(2),
},
wantErr: false,
},
{
name: "missing owner",
args: map[string]interface{}{
"repo": "repo",
},
wantErr: true,
errContains: "missing required parameter: owner",
},
{
name: "missing repo",
args: map[string]interface{}{
"owner": "owner",
},
wantErr: true,
errContains: "missing required parameter: repo",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create request
request := createMCPRequest(tt.args)

// Call handler
result, err := handler(context.Background(), request)
if tt.wantErr {
require.NoError(t, err)
textContent := getTextResult(t, result)
if tt.errContains != "" {
assert.Contains(t, textContent.Text, tt.errContains)
}
return
}

require.NoError(t, err)
require.NotNil(t, result)
textContent := getTextResult(t, result)
require.NotEmpty(t, textContent.Text)

// Verify response
var branches []github.Branch
err = json.Unmarshal([]byte(textContent.Text), &branches)
require.NoError(t, err)
assert.Len(t, branches, 2)
assert.Equal(t, "main", *branches[0].Name)
assert.Equal(t, "develop", *branches[1].Name)
})
}

// Verify tool definition
assert.Equal(t, "list_branches", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Contains(t, tool.InputSchema.Properties, "owner")
assert.Contains(t, tool.InputSchema.Properties, "repo")
assert.Contains(t, tool.InputSchema.Properties, "page")
assert.Contains(t, tool.InputSchema.Properties, "perPage")
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"})
}
1 change: 1 addition & 0 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati
s.AddTool(SearchRepositories(getClient, t))
s.AddTool(GetFileContents(getClient, t))
s.AddTool(ListCommits(getClient, t))
s.AddTool(ListBranches(getClient, t))
if !readOnly {
s.AddTool(CreateOrUpdateFile(getClient, t))
s.AddTool(CreateRepository(getClient, t))
Expand Down
Loading