Skip to content
113 changes: 113 additions & 0 deletions pkg/github/pullrequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,119 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc)
}
}

// updatePullRequest creates a tool to update an existing pull request.
func updatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("update_pull_request",
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")),
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
),
mcp.WithString("repo",
mcp.Required(),
mcp.Description("Repository name"),
),
mcp.WithNumber("pullNumber",
mcp.Required(),
mcp.Description("Pull request number to update"),
),
mcp.WithString("title",
mcp.Description("New title"),
),
mcp.WithString("body",
mcp.Description("New description"),
),
mcp.WithString("state",
mcp.Description("New state ('open' or 'closed')"),
mcp.Enum("open", "closed"),
),
mcp.WithString("base",
mcp.Description("New base branch name"),
),
mcp.WithBoolean("maintainer_can_modify",
mcp.Description("Allow maintainer edits"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pullNumber, err := requiredInt(request, "pullNumber")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

// Build the update struct only with provided fields
update := &github.PullRequest{}
updateNeeded := false

if title, ok, err := optionalParamOk[string](request, "title"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Title = github.Ptr(title)
updateNeeded = true
}

if body, ok, err := optionalParamOk[string](request, "body"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Body = github.Ptr(body)
updateNeeded = true
}

if state, ok, err := optionalParamOk[string](request, "state"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.State = github.Ptr(state)
updateNeeded = true
}

if base, ok, err := optionalParamOk[string](request, "base"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
updateNeeded = true
}

if maintainerCanModify, ok, err := optionalParamOk[bool](request, "maintainer_can_modify"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
updateNeeded = true
}

if !updateNeeded {
return mcp.NewToolResultError("No update parameters provided."), nil
}

pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
if err != nil {
return nil, fmt.Errorf("failed to update pull request: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
}

r, err := json.Marshal(pr)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}

return mcp.NewToolResultText(string(r)), nil
}
}

// listPullRequests creates a tool to list and filter repository pull requests.
func listPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("list_pull_requests",
Expand Down
182 changes: 182 additions & 0 deletions pkg/github/pullrequests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,188 @@ func Test_GetPullRequest(t *testing.T) {
}
}

func Test_UpdatePullRequest(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
tool, _ := updatePullRequest(mockClient, translations.NullTranslationHelper)

assert.Equal(t, "update_pull_request", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Contains(t, tool.InputSchema.Properties, "owner")
assert.Contains(t, tool.InputSchema.Properties, "repo")
assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
assert.Contains(t, tool.InputSchema.Properties, "title")
assert.Contains(t, tool.InputSchema.Properties, "body")
assert.Contains(t, tool.InputSchema.Properties, "state")
assert.Contains(t, tool.InputSchema.Properties, "base")
assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify")
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})

// Setup mock PR for success case
mockUpdatedPR := &github.PullRequest{
Number: github.Ptr(42),
Title: github.Ptr("Updated Test PR Title"),
State: github.Ptr("open"),
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
Body: github.Ptr("Updated test PR body."),
MaintainerCanModify: github.Ptr(false),
Base: &github.PullRequestBranch{
Ref: github.Ptr("develop"),
},
}

mockClosedPR := &github.PullRequest{
Number: github.Ptr(42),
Title: github.Ptr("Test PR"),
State: github.Ptr("closed"), // State updated
}

tests := []struct {
name string
mockedClient *http.Client
requestArgs map[string]interface{}
expectError bool
expectedPR *github.PullRequest
expectedErrMsg string
}{
{
name: "successful PR update (title, body, base, maintainer_can_modify)",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.PatchReposPullsByOwnerByRepoByPullNumber,
// Expect the flat string based on previous test failure output and API docs
expectRequestBody(t, map[string]interface{}{
"title": "Updated Test PR Title",
"body": "Updated test PR body.",
"base": "develop",
"maintainer_can_modify": false,
}).andThen(
mockResponse(t, http.StatusOK, mockUpdatedPR),
),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
"title": "Updated Test PR Title",
"body": "Updated test PR body.",
"base": "develop",
"maintainer_can_modify": false,
},
expectError: false,
expectedPR: mockUpdatedPR,
},
{
name: "successful PR update (state)",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.PatchReposPullsByOwnerByRepoByPullNumber,
expectRequestBody(t, map[string]interface{}{
"state": "closed",
}).andThen(
mockResponse(t, http.StatusOK, mockClosedPR),
),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
"state": "closed",
},
expectError: false,
expectedPR: mockClosedPR,
},
{
name: "no update parameters provided",
mockedClient: mock.NewMockedHTTPClient(), // No API call expected
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
// No update fields
},
expectError: false, // Error is returned in the result, not as Go error
expectedErrMsg: "No update parameters provided",
},
{
name: "PR update fails (API error)",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.PatchReposPullsByOwnerByRepoByPullNumber,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnprocessableEntity)
_, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
}),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
"title": "Invalid Title Causing Error",
},
expectError: true,
expectedErrMsg: "failed to update pull request",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
_, handler := updatePullRequest(client, translations.NullTranslationHelper)

// Create call request
request := createMCPRequest(tc.requestArgs)

// Call handler
result, err := handler(context.Background(), request)

// Verify results
if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedErrMsg)
return
}

require.NoError(t, err)

// Parse the result and get the text content
textContent := getTextResult(t, result)

// Check for expected error message within the result text
if tc.expectedErrMsg != "" {
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
return
}

// Unmarshal and verify the successful result
var returnedPR github.PullRequest
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
require.NoError(t, err)
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
if tc.expectedPR.Title != nil {
assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title)
}
if tc.expectedPR.Body != nil {
assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body)
}
if tc.expectedPR.State != nil {
assert.Equal(t, *tc.expectedPR.State, *returnedPR.State)
}
if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil {
assert.NotNil(t, returnedPR.Base)
assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref)
}
if tc.expectedPR.MaintainerCanModify != nil {
assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify)
}
})
}
}

func Test_ListPullRequests(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
Expand Down
21 changes: 21 additions & 0 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH
s.AddTool(updatePullRequestBranch(client, t))
s.AddTool(createPullRequestReview(client, t))
s.AddTool(createPullRequest(client, t))
s.AddTool(updatePullRequest(client, t))
}

// Add GitHub tools - Repositories
Expand Down Expand Up @@ -112,6 +113,26 @@ func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc
}
}

// optionalParamOk is a helper function that can be used to fetch a requested parameter from the request.
// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong.
func optionalParamOk[T any](r mcp.CallToolRequest, p string) (T, bool, error) {
var zero T

// Check if the parameter is present in the request
val, ok := r.Params.Arguments[p]
if !ok {
return zero, false, nil // Not present, return zero value, false, no error
}

// Check if the parameter is of the expected type
typedVal, ok := val.(T)
if !ok {
return zero, true, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, val) // Present but wrong type
}

return typedVal, true, nil // Present and correct type
}

// isAcceptedError checks if the error is an accepted error.
func isAcceptedError(err error) bool {
var acceptedError *github.AcceptedError
Expand Down