@@ -161,6 +161,7 @@ func Test_UpdatePullRequest(t *testing.T) {
161161 HTMLURL : github .Ptr ("https://github.com/owner/repo/pull/42" ),
162162 Body : github .Ptr ("Updated test PR body." ),
163163 MaintainerCanModify : github .Ptr (false ),
164+ Draft : github .Ptr (false ),
164165 Base : & github.PullRequestBranch {
165166 Ref : github .Ptr ("develop" ),
166167 },
@@ -237,6 +238,31 @@ func Test_UpdatePullRequest(t *testing.T) {
237238 expectError : false ,
238239 expectedPR : mockClosedPR ,
239240 },
241+ {
242+ name : "successful PR update (title only)" ,
243+ mockedClient : mock .NewMockedHTTPClient (
244+ mock .WithRequestMatchHandler (
245+ mock .PatchReposPullsByOwnerByRepoByPullNumber ,
246+ expectRequestBody (t , map [string ]interface {}{
247+ "title" : "Updated Test PR Title" ,
248+ }).andThen (
249+ mockResponse (t , http .StatusOK , mockUpdatedPR ),
250+ ),
251+ ),
252+ mock .WithRequestMatch (
253+ mock .GetReposPullsByOwnerByRepoByPullNumber ,
254+ mockUpdatedPR ,
255+ ),
256+ ),
257+ requestArgs : map [string ]interface {}{
258+ "owner" : "owner" ,
259+ "repo" : "repo" ,
260+ "pullNumber" : float64 (42 ),
261+ "title" : "Updated Test PR Title" ,
262+ },
263+ expectError : false ,
264+ expectedPR : mockUpdatedPR ,
265+ },
240266 {
241267 name : "no update parameters provided" ,
242268 mockedClient : mock .NewMockedHTTPClient (), // No API call expected
@@ -325,6 +351,191 @@ func Test_UpdatePullRequest(t *testing.T) {
325351 }
326352}
327353
354+ func Test_UpdatePullRequest_Draft (t * testing.T ) {
355+ // Setup mock PR for success case
356+ mockUpdatedPR := & github.PullRequest {
357+ Number : github .Ptr (42 ),
358+ Title : github .Ptr ("Test PR Title" ),
359+ State : github .Ptr ("open" ),
360+ HTMLURL : github .Ptr ("https://github.com/owner/repo/pull/42" ),
361+ Body : github .Ptr ("Test PR body." ),
362+ MaintainerCanModify : github .Ptr (false ),
363+ Draft : github .Ptr (false ), // Updated to ready for review
364+ Base : & github.PullRequestBranch {
365+ Ref : github .Ptr ("main" ),
366+ },
367+ }
368+
369+ tests := []struct {
370+ name string
371+ mockedClient * http.Client
372+ requestArgs map [string ]interface {}
373+ expectError bool
374+ expectedPR * github.PullRequest
375+ expectedErrMsg string
376+ }{
377+ {
378+ name : "successful draft update to ready for review" ,
379+ mockedClient : githubv4mock .NewMockedHTTPClient (
380+ githubv4mock .NewQueryMatcher (
381+ struct {
382+ Repository struct {
383+ PullRequest struct {
384+ ID githubv4.ID
385+ IsDraft githubv4.Boolean
386+ } `graphql:"pullRequest(number: $prNum)"`
387+ } `graphql:"repository(owner: $owner, name: $repo)"`
388+ }{},
389+ map [string ]any {
390+ "owner" : githubv4 .String ("owner" ),
391+ "repo" : githubv4 .String ("repo" ),
392+ "prNum" : githubv4 .Int (42 ),
393+ },
394+ githubv4mock .DataResponse (map [string ]any {
395+ "repository" : map [string ]any {
396+ "pullRequest" : map [string ]any {
397+ "id" : "PR_kwDOA0xdyM50BPaO" ,
398+ "isDraft" : true , // Current state is draft
399+ },
400+ },
401+ }),
402+ ),
403+ githubv4mock .NewMutationMatcher (
404+ struct {
405+ MarkPullRequestReadyForReview struct {
406+ PullRequest struct {
407+ ID githubv4.ID
408+ IsDraft githubv4.Boolean
409+ }
410+ } `graphql:"markPullRequestReadyForReview(input: $input)"`
411+ }{},
412+ githubv4.MarkPullRequestReadyForReviewInput {
413+ PullRequestID : "PR_kwDOA0xdyM50BPaO" ,
414+ },
415+ nil ,
416+ githubv4mock .DataResponse (map [string ]any {
417+ "markPullRequestReadyForReview" : map [string ]any {
418+ "pullRequest" : map [string ]any {
419+ "id" : "PR_kwDOA0xdyM50BPaO" ,
420+ "isDraft" : false ,
421+ },
422+ },
423+ }),
424+ ),
425+ ),
426+ requestArgs : map [string ]interface {}{
427+ "owner" : "owner" ,
428+ "repo" : "repo" ,
429+ "pullNumber" : float64 (42 ),
430+ "draft" : false ,
431+ },
432+ expectError : false ,
433+ expectedPR : mockUpdatedPR ,
434+ },
435+ {
436+ name : "successful convert pull request to draft" ,
437+ mockedClient : githubv4mock .NewMockedHTTPClient (
438+ githubv4mock .NewQueryMatcher (
439+ struct {
440+ Repository struct {
441+ PullRequest struct {
442+ ID githubv4.ID
443+ IsDraft githubv4.Boolean
444+ } `graphql:"pullRequest(number: $prNum)"`
445+ } `graphql:"repository(owner: $owner, name: $repo)"`
446+ }{},
447+ map [string ]any {
448+ "owner" : githubv4 .String ("owner" ),
449+ "repo" : githubv4 .String ("repo" ),
450+ "prNum" : githubv4 .Int (42 ),
451+ },
452+ githubv4mock .DataResponse (map [string ]any {
453+ "repository" : map [string ]any {
454+ "pullRequest" : map [string ]any {
455+ "id" : "PR_kwDOA0xdyM50BPaO" ,
456+ "isDraft" : false , // Current state is draft
457+ },
458+ },
459+ }),
460+ ),
461+ githubv4mock .NewMutationMatcher (
462+ struct {
463+ ConvertPullRequestToDraft struct {
464+ PullRequest struct {
465+ ID githubv4.ID
466+ IsDraft githubv4.Boolean
467+ }
468+ } `graphql:"convertPullRequestToDraft(input: $input)"`
469+ }{},
470+ githubv4.ConvertPullRequestToDraftInput {
471+ PullRequestID : "PR_kwDOA0xdyM50BPaO" ,
472+ },
473+ nil ,
474+ githubv4mock .DataResponse (map [string ]any {
475+ "convertPullRequestToDraft" : map [string ]any {
476+ "pullRequest" : map [string ]any {
477+ "id" : "PR_kwDOA0xdyM50BPaO" ,
478+ "isDraft" : true ,
479+ },
480+ },
481+ }),
482+ ),
483+ ),
484+ requestArgs : map [string ]interface {}{
485+ "owner" : "owner" ,
486+ "repo" : "repo" ,
487+ "pullNumber" : float64 (42 ),
488+ "draft" : true ,
489+ },
490+ expectError : false ,
491+ expectedPR : mockUpdatedPR ,
492+ },
493+ }
494+
495+ for _ , tc := range tests {
496+ t .Run (tc .name , func (t * testing.T ) {
497+ // For draft-only tests, we need to mock both GraphQL and the final REST GET call
498+ restClient := github .NewClient (mock .NewMockedHTTPClient (
499+ mock .WithRequestMatch (
500+ mock .GetReposPullsByOwnerByRepoByPullNumber ,
501+ mockUpdatedPR ,
502+ ),
503+ ))
504+ gqlClient := githubv4 .NewClient (tc .mockedClient )
505+
506+ _ , handler := UpdatePullRequest (stubGetClientFn (restClient ), stubGetGQLClientFn (gqlClient ), translations .NullTranslationHelper )
507+
508+ request := createMCPRequest (tc .requestArgs )
509+
510+ result , err := handler (context .Background (), request )
511+
512+ if tc .expectError || tc .expectedErrMsg != "" {
513+ require .NoError (t , err )
514+ require .True (t , result .IsError )
515+ errorContent := getErrorResult (t , result )
516+ if tc .expectedErrMsg != "" {
517+ assert .Contains (t , errorContent .Text , tc .expectedErrMsg )
518+ }
519+ return
520+ }
521+
522+ require .NoError (t , err )
523+ require .False (t , result .IsError )
524+
525+ textContent := getTextResult (t , result )
526+
527+ // Unmarshal and verify the successful result
528+ var returnedPR github.PullRequest
529+ err = json .Unmarshal ([]byte (textContent .Text ), & returnedPR )
530+ require .NoError (t , err )
531+ assert .Equal (t , * tc .expectedPR .Number , * returnedPR .Number )
532+ if tc .expectedPR .Draft != nil {
533+ assert .Equal (t , * tc .expectedPR .Draft , * returnedPR .Draft )
534+ }
535+ })
536+ }
537+ }
538+
328539func Test_ListPullRequests (t * testing.T ) {
329540 // Verify tool definition once
330541 mockClient := github .NewClient (nil )
0 commit comments