Skip to content

Commit 65a4558

Browse files
committed
feat: add update_pull_request tool
1 parent 270bbf7 commit 65a4558

File tree

3 files changed

+316
-0
lines changed

3 files changed

+316
-0
lines changed

pkg/github/pullrequests.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,119 @@ func getPullRequest(client *github.Client, t translations.TranslationHelperFunc)
6767
}
6868
}
6969

70+
// updatePullRequest creates a tool to update an existing pull request.
71+
func updatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
72+
return mcp.NewTool("update_pull_request",
73+
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")),
74+
mcp.WithString("owner",
75+
mcp.Required(),
76+
mcp.Description("Repository owner"),
77+
),
78+
mcp.WithString("repo",
79+
mcp.Required(),
80+
mcp.Description("Repository name"),
81+
),
82+
mcp.WithNumber("pullNumber",
83+
mcp.Required(),
84+
mcp.Description("Pull request number to update"),
85+
),
86+
mcp.WithString("title",
87+
mcp.Description("New title"),
88+
),
89+
mcp.WithString("body",
90+
mcp.Description("New description"),
91+
),
92+
mcp.WithString("state",
93+
mcp.Description("New state ('open' or 'closed')"),
94+
mcp.Enum("open", "closed"),
95+
),
96+
mcp.WithString("base",
97+
mcp.Description("New base branch name"),
98+
),
99+
mcp.WithBoolean("maintainer_can_modify",
100+
mcp.Description("Allow maintainer edits"),
101+
),
102+
),
103+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
104+
owner, err := requiredParam[string](request, "owner")
105+
if err != nil {
106+
return mcp.NewToolResultError(err.Error()), nil
107+
}
108+
repo, err := requiredParam[string](request, "repo")
109+
if err != nil {
110+
return mcp.NewToolResultError(err.Error()), nil
111+
}
112+
pullNumber, err := requiredInt(request, "pullNumber")
113+
if err != nil {
114+
return mcp.NewToolResultError(err.Error()), nil
115+
}
116+
117+
// Build the update struct only with provided fields
118+
update := &github.PullRequest{}
119+
updateNeeded := false
120+
121+
if title, ok, err := optionalParamOk[string](request, "title"); err != nil {
122+
return mcp.NewToolResultError(err.Error()), nil
123+
} else if ok {
124+
update.Title = github.Ptr(title)
125+
updateNeeded = true
126+
}
127+
128+
if body, ok, err := optionalParamOk[string](request, "body"); err != nil {
129+
return mcp.NewToolResultError(err.Error()), nil
130+
} else if ok {
131+
update.Body = github.Ptr(body)
132+
updateNeeded = true
133+
}
134+
135+
if state, ok, err := optionalParamOk[string](request, "state"); err != nil {
136+
return mcp.NewToolResultError(err.Error()), nil
137+
} else if ok {
138+
update.State = github.Ptr(state)
139+
updateNeeded = true
140+
}
141+
142+
if base, ok, err := optionalParamOk[string](request, "base"); err != nil {
143+
return mcp.NewToolResultError(err.Error()), nil
144+
} else if ok {
145+
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
146+
updateNeeded = true
147+
}
148+
149+
if maintainerCanModify, ok, err := optionalParamOk[bool](request, "maintainer_can_modify"); err != nil {
150+
return mcp.NewToolResultError(err.Error()), nil
151+
} else if ok {
152+
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
153+
updateNeeded = true
154+
}
155+
156+
if !updateNeeded {
157+
return mcp.NewToolResultError("No update parameters provided."), nil
158+
}
159+
160+
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
161+
if err != nil {
162+
return nil, fmt.Errorf("failed to update pull request: %w", err)
163+
}
164+
defer func() { _ = resp.Body.Close() }()
165+
166+
if resp.StatusCode != http.StatusOK {
167+
body, err := io.ReadAll(resp.Body)
168+
if err != nil {
169+
return nil, fmt.Errorf("failed to read response body: %w", err)
170+
}
171+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
172+
}
173+
174+
r, err := json.Marshal(pr)
175+
if err != nil {
176+
return nil, fmt.Errorf("failed to marshal response: %w", err)
177+
}
178+
179+
return mcp.NewToolResultText(string(r)), nil
180+
}
181+
}
182+
70183
// listPullRequests creates a tool to list and filter repository pull requests.
71184
func listPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
72185
return mcp.NewTool("list_pull_requests",

pkg/github/pullrequests_test.go

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,188 @@ func Test_GetPullRequest(t *testing.T) {
126126
}
127127
}
128128

129+
func Test_UpdatePullRequest(t *testing.T) {
130+
// Verify tool definition once
131+
mockClient := github.NewClient(nil)
132+
tool, _ := updatePullRequest(mockClient, translations.NullTranslationHelper)
133+
134+
assert.Equal(t, "update_pull_request", tool.Name)
135+
assert.NotEmpty(t, tool.Description)
136+
assert.Contains(t, tool.InputSchema.Properties, "owner")
137+
assert.Contains(t, tool.InputSchema.Properties, "repo")
138+
assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
139+
assert.Contains(t, tool.InputSchema.Properties, "title")
140+
assert.Contains(t, tool.InputSchema.Properties, "body")
141+
assert.Contains(t, tool.InputSchema.Properties, "state")
142+
assert.Contains(t, tool.InputSchema.Properties, "base")
143+
assert.Contains(t, tool.InputSchema.Properties, "maintainer_can_modify")
144+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})
145+
146+
// Setup mock PR for success case
147+
mockUpdatedPR := &github.PullRequest{
148+
Number: github.Ptr(42),
149+
Title: github.Ptr("Updated Test PR Title"),
150+
State: github.Ptr("open"),
151+
HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"),
152+
Body: github.Ptr("Updated test PR body."),
153+
MaintainerCanModify: github.Ptr(false),
154+
Base: &github.PullRequestBranch{
155+
Ref: github.Ptr("develop"),
156+
},
157+
}
158+
159+
mockClosedPR := &github.PullRequest{
160+
Number: github.Ptr(42),
161+
Title: github.Ptr("Test PR"),
162+
State: github.Ptr("closed"), // State updated
163+
}
164+
165+
tests := []struct {
166+
name string
167+
mockedClient *http.Client
168+
requestArgs map[string]interface{}
169+
expectError bool
170+
expectedPR *github.PullRequest
171+
expectedErrMsg string
172+
}{
173+
{
174+
name: "successful PR update (title, body, base, maintainer_can_modify)",
175+
mockedClient: mock.NewMockedHTTPClient(
176+
mock.WithRequestMatchHandler(
177+
mock.PatchReposPullsByOwnerByRepoByPullNumber,
178+
// Expect the flat string based on previous test failure output and API docs
179+
expectRequestBody(t, map[string]interface{}{
180+
"title": "Updated Test PR Title",
181+
"body": "Updated test PR body.",
182+
"base": "develop",
183+
"maintainer_can_modify": false,
184+
}).andThen(
185+
mockResponse(t, http.StatusOK, mockUpdatedPR),
186+
),
187+
),
188+
),
189+
requestArgs: map[string]interface{}{
190+
"owner": "owner",
191+
"repo": "repo",
192+
"pullNumber": float64(42),
193+
"title": "Updated Test PR Title",
194+
"body": "Updated test PR body.",
195+
"base": "develop",
196+
"maintainer_can_modify": false,
197+
},
198+
expectError: false,
199+
expectedPR: mockUpdatedPR,
200+
},
201+
{
202+
name: "successful PR update (state)",
203+
mockedClient: mock.NewMockedHTTPClient(
204+
mock.WithRequestMatchHandler(
205+
mock.PatchReposPullsByOwnerByRepoByPullNumber,
206+
expectRequestBody(t, map[string]interface{}{
207+
"state": "closed",
208+
}).andThen(
209+
mockResponse(t, http.StatusOK, mockClosedPR),
210+
),
211+
),
212+
),
213+
requestArgs: map[string]interface{}{
214+
"owner": "owner",
215+
"repo": "repo",
216+
"pullNumber": float64(42),
217+
"state": "closed",
218+
},
219+
expectError: false,
220+
expectedPR: mockClosedPR,
221+
},
222+
{
223+
name: "no update parameters provided",
224+
mockedClient: mock.NewMockedHTTPClient(), // No API call expected
225+
requestArgs: map[string]interface{}{
226+
"owner": "owner",
227+
"repo": "repo",
228+
"pullNumber": float64(42),
229+
// No update fields
230+
},
231+
expectError: false, // Error is returned in the result, not as Go error
232+
expectedErrMsg: "No update parameters provided",
233+
},
234+
{
235+
name: "PR update fails (API error)",
236+
mockedClient: mock.NewMockedHTTPClient(
237+
mock.WithRequestMatchHandler(
238+
mock.PatchReposPullsByOwnerByRepoByPullNumber,
239+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
240+
w.WriteHeader(http.StatusUnprocessableEntity)
241+
_, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
242+
}),
243+
),
244+
),
245+
requestArgs: map[string]interface{}{
246+
"owner": "owner",
247+
"repo": "repo",
248+
"pullNumber": float64(42),
249+
"title": "Invalid Title Causing Error",
250+
},
251+
expectError: true,
252+
expectedErrMsg: "failed to update pull request",
253+
},
254+
}
255+
256+
for _, tc := range tests {
257+
t.Run(tc.name, func(t *testing.T) {
258+
// Setup client with mock
259+
client := github.NewClient(tc.mockedClient)
260+
_, handler := updatePullRequest(client, translations.NullTranslationHelper)
261+
262+
// Create call request
263+
request := createMCPRequest(tc.requestArgs)
264+
265+
// Call handler
266+
result, err := handler(context.Background(), request)
267+
268+
// Verify results
269+
if tc.expectError {
270+
require.Error(t, err)
271+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
272+
return
273+
}
274+
275+
require.NoError(t, err)
276+
277+
// Parse the result and get the text content
278+
textContent := getTextResult(t, result)
279+
280+
// Check for expected error message within the result text
281+
if tc.expectedErrMsg != "" {
282+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
283+
return
284+
}
285+
286+
// Unmarshal and verify the successful result
287+
var returnedPR github.PullRequest
288+
err = json.Unmarshal([]byte(textContent.Text), &returnedPR)
289+
require.NoError(t, err)
290+
assert.Equal(t, *tc.expectedPR.Number, *returnedPR.Number)
291+
if tc.expectedPR.Title != nil {
292+
assert.Equal(t, *tc.expectedPR.Title, *returnedPR.Title)
293+
}
294+
if tc.expectedPR.Body != nil {
295+
assert.Equal(t, *tc.expectedPR.Body, *returnedPR.Body)
296+
}
297+
if tc.expectedPR.State != nil {
298+
assert.Equal(t, *tc.expectedPR.State, *returnedPR.State)
299+
}
300+
if tc.expectedPR.Base != nil && tc.expectedPR.Base.Ref != nil {
301+
assert.NotNil(t, returnedPR.Base)
302+
assert.Equal(t, *tc.expectedPR.Base.Ref, *returnedPR.Base.Ref)
303+
}
304+
if tc.expectedPR.MaintainerCanModify != nil {
305+
assert.Equal(t, *tc.expectedPR.MaintainerCanModify, *returnedPR.MaintainerCanModify)
306+
}
307+
})
308+
}
309+
}
310+
129311
func Test_ListPullRequests(t *testing.T) {
130312
// Verify tool definition once
131313
mockClient := github.NewClient(nil)

pkg/github/server.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH
5353
s.AddTool(updatePullRequestBranch(client, t))
5454
s.AddTool(createPullRequestReview(client, t))
5555
s.AddTool(createPullRequest(client, t))
56+
s.AddTool(updatePullRequest(client, t))
5657
}
5758

5859
// Add GitHub tools - Repositories
@@ -112,6 +113,26 @@ func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc
112113
}
113114
}
114115

116+
// optionalParamOk is a helper function that can be used to fetch a requested parameter from the request.
117+
// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong.
118+
func optionalParamOk[T any](r mcp.CallToolRequest, p string) (T, bool, error) {
119+
var zero T
120+
121+
// Check if the parameter is present in the request
122+
val, ok := r.Params.Arguments[p]
123+
if !ok {
124+
return zero, false, nil // Not present, return zero value, false, no error
125+
}
126+
127+
// Check if the parameter is of the expected type
128+
typedVal, ok := val.(T)
129+
if !ok {
130+
return zero, true, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, val) // Present but wrong type
131+
}
132+
133+
return typedVal, true, nil // Present and correct type
134+
}
135+
115136
// isAcceptedError checks if the error is an accepted error.
116137
func isAcceptedError(err error) bool {
117138
var acceptedError *github.AcceptedError

0 commit comments

Comments
 (0)