@@ -20,25 +20,27 @@ type RESTClientInterface interface {
20
20
Patch (endpoint string , body io.Reader , response interface {}) error
21
21
}
22
22
23
- func CombinePRs ( ctx context. Context , graphQlClient * api. GraphQLClient , restClient RESTClientInterface , repo github. Repo , pulls github. Pulls ) error {
24
- // Define the combined branch name
23
+ // CombinePRsWithStats combines PRs and returns stats for summary output
24
+ func CombinePRsWithStats ( ctx context. Context , graphQlClient * api. GraphQLClient , restClient RESTClientInterface , repo github. Repo , pulls github. Pulls ) ( combined [] string , mergeConflicts [] string , combinedPRLink string , err error ) {
25
25
workingBranchName := combineBranchName + workingBranchSuffix
26
26
27
- // Get the default branch of the repository
28
27
repoDefaultBranch , err := getDefaultBranch (ctx , restClient , repo )
29
28
if err != nil {
30
- return fmt .Errorf ("failed to get default branch: %w" , err )
29
+ return nil , nil , "" , fmt .Errorf ("failed to get default branch: %w" , err )
31
30
}
32
31
33
32
baseBranchSHA , err := getBranchSHA (ctx , restClient , repo , repoDefaultBranch )
34
33
if err != nil {
35
- return fmt .Errorf ("failed to get SHA of main branch: %w" , err )
34
+ return nil , nil , "" , fmt .Errorf ("failed to get SHA of main branch: %w" , err )
36
35
}
36
+ // Delete any pre-existing working branch
37
37
38
- // Delete any pre-existing working branch
38
+ // Delete any pre-existing working branch
39
39
err = deleteBranch (ctx , restClient , repo , workingBranchName )
40
40
if err != nil {
41
41
Logger .Debug ("Working branch not found, continuing" , "branch" , workingBranchName )
42
+
43
+ // Delete any pre-existing combined branch
42
44
}
43
45
44
46
// Delete any pre-existing combined branch
@@ -47,60 +49,106 @@ func CombinePRs(ctx context.Context, graphQlClient *api.GraphQLClient, restClien
47
49
Logger .Debug ("Combined branch not found, continuing" , "branch" , combineBranchName )
48
50
}
49
51
50
- // Create the combined branch
51
52
err = createBranch (ctx , restClient , repo , combineBranchName , baseBranchSHA )
52
53
if err != nil {
53
- return fmt .Errorf ("failed to create combined branch: %w" , err )
54
+ return nil , nil , "" , fmt .Errorf ("failed to create combined branch: %w" , err )
54
55
}
55
-
56
- // Create the working branch
57
56
err = createBranch (ctx , restClient , repo , workingBranchName , baseBranchSHA )
58
57
if err != nil {
59
- return fmt .Errorf ("failed to create working branch: %w" , err )
58
+ return nil , nil , "" , fmt .Errorf ("failed to create working branch: %w" , err )
60
59
}
61
60
62
- // Merge all PR branches into the working branch
63
- var combinedPRs []string
64
- var mergeFailedPRs []string
65
61
for _ , pr := range pulls {
66
62
err := mergeBranch (ctx , restClient , repo , workingBranchName , pr .Head .Ref )
67
63
if err != nil {
68
- // Check if the error is a 409 merge conflict
69
64
if isMergeConflictError (err ) {
70
- // Log merge conflicts at DEBUG level
71
65
Logger .Debug ("Merge conflict" , "branch" , pr .Head .Ref , "error" , err )
72
66
} else {
73
- // Log other errors at WARN level
74
67
Logger .Warn ("Failed to merge branch" , "branch" , pr .Head .Ref , "error" , err )
75
68
}
76
- mergeFailedPRs = append (mergeFailedPRs , fmt .Sprintf ("#%d" , pr .Number ))
69
+ mergeConflicts = append (mergeConflicts , fmt .Sprintf ("#%d" , pr .Number ))
77
70
} else {
78
71
Logger .Debug ("Merged branch" , "branch" , pr .Head .Ref )
79
- combinedPRs = append (combinedPRs , fmt .Sprintf ("#%d - %s" , pr .Number , pr .Title ))
72
+ combined = append (combined , fmt .Sprintf ("#%d - %s" , pr .Number , pr .Title ))
80
73
}
81
74
}
82
75
83
- // Update the combined branch to the latest commit of the working branch
84
76
err = updateRef (ctx , restClient , repo , combineBranchName , workingBranchName )
85
77
if err != nil {
86
- return fmt .Errorf ("failed to update combined branch: %w" , err )
78
+ return combined , mergeConflicts , "" , fmt .Errorf ("failed to update combined branch: %w" , err )
87
79
}
88
-
89
- // Delete the temporary working branch
90
80
err = deleteBranch (ctx , restClient , repo , workingBranchName )
91
81
if err != nil {
92
82
Logger .Warn ("Failed to delete working branch" , "branch" , workingBranchName , "error" , err )
93
83
}
94
84
95
- // Create the combined PR
96
- prBody := generatePRBody (combinedPRs , mergeFailedPRs )
85
+ prBody := generatePRBody (combined , mergeConflicts )
97
86
prTitle := "Combined PRs"
98
- err = createPullRequest (ctx , restClient , repo , prTitle , combineBranchName , repoDefaultBranch , prBody , addLabels , addAssignees )
87
+ prNumber , prErr := createPullRequestWithNumber (ctx , restClient , repo , prTitle , combineBranchName , repoDefaultBranch , prBody , addLabels , addAssignees )
88
+ if prErr != nil {
89
+ return combined , mergeConflicts , "" , fmt .Errorf ("failed to create combined PR: %w" , prErr )
90
+ }
91
+ if prNumber > 0 {
92
+ combinedPRLink = fmt .Sprintf ("https://github.com/%s/%s/pull/%d" , repo .Owner , repo .Repo , prNumber )
93
+ }
94
+
95
+ return combined , mergeConflicts , combinedPRLink , nil
96
+ }
97
+
98
+ // createPullRequestWithNumber creates a PR and returns its number
99
+ func createPullRequestWithNumber (ctx context.Context , client RESTClientInterface , repo github.Repo , title , head , base , body string , labels , assignees []string ) (int , error ) {
100
+ endpoint := fmt .Sprintf ("repos/%s/%s/pulls" , repo .Owner , repo .Repo )
101
+ payload := map [string ]interface {}{
102
+ "title" : title ,
103
+ "head" : head ,
104
+ "base" : base ,
105
+ "body" : body ,
106
+ }
107
+
108
+ requestBody , err := encodePayload (payload )
99
109
if err != nil {
100
- return fmt .Errorf ("failed to create combined PR : %w" , err )
110
+ return 0 , fmt .Errorf ("failed to encode payload : %w" , err )
101
111
}
102
112
103
- return nil
113
+ var prResponse struct {
114
+ Number int `json:"number"`
115
+ }
116
+ err = client .Post (endpoint , requestBody , & prResponse )
117
+ if err != nil {
118
+ return 0 , fmt .Errorf ("failed to create pull request: %w" , err )
119
+ }
120
+
121
+ if len (labels ) > 0 {
122
+ labelsEndpoint := fmt .Sprintf ("repos/%s/%s/issues/%d/labels" , repo .Owner , repo .Repo , prResponse .Number )
123
+ labelsPayload , err := encodePayload (map [string ][]string {"labels" : labels })
124
+ if err != nil {
125
+ return prResponse .Number , fmt .Errorf ("failed to encode labels payload: %w" , err )
126
+ }
127
+ err = client .Post (labelsEndpoint , labelsPayload , nil )
128
+ if err != nil {
129
+ return prResponse .Number , fmt .Errorf ("failed to add labels: %w" , err )
130
+ }
131
+ }
132
+
133
+ if len (assignees ) > 0 {
134
+ assigneesEndpoint := fmt .Sprintf ("repos/%s/%s/issues/%d/assignees" , repo .Owner , repo .Repo , prResponse .Number )
135
+ assigneesPayload , err := encodePayload (map [string ][]string {"assignees" : assignees })
136
+ if err != nil {
137
+ return prResponse .Number , fmt .Errorf ("failed to encode assignees payload: %w" , err )
138
+ }
139
+ err = client .Post (assigneesEndpoint , assigneesPayload , nil )
140
+ if err != nil {
141
+ return prResponse .Number , fmt .Errorf ("failed to add assignees: %w" , err )
142
+ }
143
+ }
144
+
145
+ return prResponse .Number , nil
146
+ }
147
+
148
+ // Keep CombinePRs for backward compatibility
149
+ func CombinePRs (ctx context.Context , graphQlClient * api.GraphQLClient , restClient RESTClientInterface , repo github.Repo , pulls github.Pulls ) error {
150
+ _ , _ , _ , err := CombinePRsWithStats (ctx , graphQlClient , restClient , repo , pulls )
151
+ return err
104
152
}
105
153
106
154
// isMergeConflictError checks if the error is a 409 Merge Conflict
0 commit comments