Skip to content

Commit ae89361

Browse files
committed
initial impl of pull request draft state update
1 parent efef8ae commit ae89361

File tree

4 files changed

+136
-29
lines changed

4 files changed

+136
-29
lines changed

pkg/github/__toolsnaps__/update_pull_request.snap

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
"description": "New description",
1515
"type": "string"
1616
},
17+
"draft": {
18+
"description": "Mark pull request as draft (true) or ready for review (false)",
19+
"type": "boolean"
20+
},
1721
"maintainer_can_modify": {
1822
"description": "Allow maintainer edits",
1923
"type": "boolean"

pkg/github/pullrequests.go

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
203203
}
204204

205205
// UpdatePullRequest creates a tool to update an existing pull request.
206-
func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
206+
func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) {
207207
return mcp.NewTool("update_pull_request",
208208
mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")),
209209
mcp.WithToolAnnotation(mcp.ToolAnnotation{
@@ -232,6 +232,9 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
232232
mcp.Description("New state"),
233233
mcp.Enum("open", "closed"),
234234
),
235+
mcp.WithBoolean("draft",
236+
mcp.Description("Mark pull request as draft (true) or ready for review (false)"),
237+
),
235238
mcp.WithString("base",
236239
mcp.Description("New base branch name"),
237240
),
@@ -253,74 +256,165 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu
253256
return mcp.NewToolResultError(err.Error()), nil
254257
}
255258

256-
// Build the update struct only with provided fields
259+
draftProvided := request.GetArguments()["draft"] != nil
260+
var draftValue bool
261+
if draftProvided {
262+
draftValue, err = OptionalParam[bool](request, "draft")
263+
if err != nil {
264+
return nil, err
265+
}
266+
}
267+
257268
update := &github.PullRequest{}
258-
updateNeeded := false
269+
restUpdateNeeded := false
259270

260271
if title, ok, err := OptionalParamOK[string](request, "title"); err != nil {
261272
return mcp.NewToolResultError(err.Error()), nil
262273
} else if ok {
263274
update.Title = github.Ptr(title)
264-
updateNeeded = true
275+
restUpdateNeeded = true
265276
}
266277

267278
if body, ok, err := OptionalParamOK[string](request, "body"); err != nil {
268279
return mcp.NewToolResultError(err.Error()), nil
269280
} else if ok {
270281
update.Body = github.Ptr(body)
271-
updateNeeded = true
282+
restUpdateNeeded = true
272283
}
273284

274285
if state, ok, err := OptionalParamOK[string](request, "state"); err != nil {
275286
return mcp.NewToolResultError(err.Error()), nil
276287
} else if ok {
277288
update.State = github.Ptr(state)
278-
updateNeeded = true
289+
restUpdateNeeded = true
279290
}
280291

281292
if base, ok, err := OptionalParamOK[string](request, "base"); err != nil {
282293
return mcp.NewToolResultError(err.Error()), nil
283294
} else if ok {
284295
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
285-
updateNeeded = true
296+
restUpdateNeeded = true
286297
}
287298

288299
if maintainerCanModify, ok, err := OptionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
289300
return mcp.NewToolResultError(err.Error()), nil
290301
} else if ok {
291302
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
292-
updateNeeded = true
303+
restUpdateNeeded = true
293304
}
294305

295-
if !updateNeeded {
306+
if !restUpdateNeeded && !draftProvided {
296307
return mcp.NewToolResultError("No update parameters provided."), nil
297308
}
298309

310+
if restUpdateNeeded {
311+
client, err := getClient(ctx)
312+
if err != nil {
313+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
314+
}
315+
316+
_, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
317+
if err != nil {
318+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
319+
"failed to update pull request",
320+
resp,
321+
err,
322+
), nil
323+
}
324+
defer func() { _ = resp.Body.Close() }()
325+
326+
if resp.StatusCode != http.StatusOK {
327+
body, err := io.ReadAll(resp.Body)
328+
if err != nil {
329+
return nil, fmt.Errorf("failed to read response body: %w", err)
330+
}
331+
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
332+
}
333+
}
334+
335+
if draftProvided {
336+
gqlClient, err := getGQLClient(ctx)
337+
if err != nil {
338+
return nil, fmt.Errorf("failed to get GitHub GraphQL client: %w", err)
339+
}
340+
341+
var prQuery struct {
342+
Repository struct {
343+
PullRequest struct {
344+
ID githubv4.ID
345+
IsDraft githubv4.Boolean
346+
} `graphql:"pullRequest(number: $prNum)"`
347+
} `graphql:"repository(owner: $owner, name: $repo)"`
348+
}
349+
350+
err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{
351+
"owner": githubv4.String(owner),
352+
"repo": githubv4.String(repo),
353+
"prNum": githubv4.Int(pullNumber),
354+
})
355+
if err != nil {
356+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find pull request", err), nil
357+
}
358+
359+
currentIsDraft := bool(prQuery.Repository.PullRequest.IsDraft)
360+
361+
if currentIsDraft != draftValue {
362+
if draftValue {
363+
// Convert to draft
364+
var mutation struct {
365+
ConvertPullRequestToDraft struct {
366+
PullRequest struct {
367+
ID githubv4.ID
368+
IsDraft githubv4.Boolean
369+
}
370+
} `graphql:"convertPullRequestToDraft(input: $input)"`
371+
}
372+
373+
err = gqlClient.Mutate(ctx, &mutation, githubv4.ConvertPullRequestToDraftInput{
374+
PullRequestID: prQuery.Repository.PullRequest.ID,
375+
}, nil)
376+
if err != nil {
377+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to convert pull request to draft", err), nil
378+
}
379+
} else {
380+
// Mark as ready for review
381+
var mutation struct {
382+
MarkPullRequestReadyForReview struct {
383+
PullRequest struct {
384+
ID githubv4.ID
385+
IsDraft githubv4.Boolean
386+
}
387+
} `graphql:"markPullRequestReadyForReview(input: $input)"`
388+
}
389+
390+
err = gqlClient.Mutate(ctx, &mutation, githubv4.MarkPullRequestReadyForReviewInput{
391+
PullRequestID: prQuery.Repository.PullRequest.ID,
392+
}, nil)
393+
if err != nil {
394+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to mark pull request ready for review", err), nil
395+
}
396+
}
397+
}
398+
}
399+
299400
client, err := getClient(ctx)
300401
if err != nil {
301-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
402+
return nil, err
302403
}
303-
pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update)
404+
405+
finalPR, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber)
304406
if err != nil {
305-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
306-
"failed to update pull request",
307-
resp,
308-
err,
309-
), nil
407+
return ghErrors.NewGitHubAPIErrorResponse(ctx, "Failed to get pull request", resp, err), nil
310408
}
311-
defer func() { _ = resp.Body.Close() }()
312-
313-
if resp.StatusCode != http.StatusOK {
314-
body, err := io.ReadAll(resp.Body)
315-
if err != nil {
316-
return nil, fmt.Errorf("failed to read response body: %w", err)
409+
defer func() {
410+
if resp != nil && resp.Body != nil {
411+
_ = resp.Body.Close()
317412
}
318-
return mcp.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(body))), nil
319-
}
413+
}()
320414

321-
r, err := json.Marshal(pr)
415+
r, err := json.Marshal(finalPR)
322416
if err != nil {
323-
return nil, fmt.Errorf("failed to marshal response: %w", err)
417+
return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to marshal response: %v", err), nil
324418
}
325419

326420
return mcp.NewToolResultText(string(r)), nil

pkg/github/pullrequests_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,15 @@ func Test_GetPullRequest(t *testing.T) {
137137
func Test_UpdatePullRequest(t *testing.T) {
138138
// Verify tool definition once
139139
mockClient := github.NewClient(nil)
140-
tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper)
140+
tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper)
141141
require.NoError(t, toolsnaps.Test(tool.Name, tool))
142142

143143
assert.Equal(t, "update_pull_request", tool.Name)
144144
assert.NotEmpty(t, tool.Description)
145145
assert.Contains(t, tool.InputSchema.Properties, "owner")
146146
assert.Contains(t, tool.InputSchema.Properties, "repo")
147147
assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
148+
assert.Contains(t, tool.InputSchema.Properties, "draft")
148149
assert.Contains(t, tool.InputSchema.Properties, "title")
149150
assert.Contains(t, tool.InputSchema.Properties, "body")
150151
assert.Contains(t, tool.InputSchema.Properties, "state")
@@ -194,6 +195,10 @@ func Test_UpdatePullRequest(t *testing.T) {
194195
mockResponse(t, http.StatusOK, mockUpdatedPR),
195196
),
196197
),
198+
mock.WithRequestMatch(
199+
mock.GetReposPullsByOwnerByRepoByPullNumber,
200+
mockUpdatedPR,
201+
),
197202
),
198203
requestArgs: map[string]interface{}{
199204
"owner": "owner",
@@ -218,6 +223,10 @@ func Test_UpdatePullRequest(t *testing.T) {
218223
mockResponse(t, http.StatusOK, mockClosedPR),
219224
),
220225
),
226+
mock.WithRequestMatch(
227+
mock.GetReposPullsByOwnerByRepoByPullNumber,
228+
mockClosedPR,
229+
),
221230
),
222231
requestArgs: map[string]interface{}{
223232
"owner": "owner",
@@ -266,7 +275,7 @@ func Test_UpdatePullRequest(t *testing.T) {
266275
t.Run(tc.name, func(t *testing.T) {
267276
// Setup client with mock
268277
client := github.NewClient(tc.mockedClient)
269-
_, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper)
278+
_, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper)
270279

271280
// Create call request
272281
request := createMCPRequest(tc.requestArgs)

pkg/github/tools.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG
8787
toolsets.NewServerTool(MergePullRequest(getClient, t)),
8888
toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)),
8989
toolsets.NewServerTool(CreatePullRequest(getClient, t)),
90-
toolsets.NewServerTool(UpdatePullRequest(getClient, t)),
90+
toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)),
9191
toolsets.NewServerTool(RequestCopilotReview(getClient, t)),
9292

9393
// Reviews

0 commit comments

Comments
 (0)